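Convert the question to a SQL query: the model translates the user input into a SQL query; execute the SQL query: the query is run against the database; answer the question: the model responds to the user input using the query results.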
curl -s https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql | sqlite3 Chinook.db
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]
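For a quick sanity check that does not depend on LangChain, the table list can also be read straight from SQLite's catalog. The sketch below is only a stand-in: it builds a throwaway in-memory database with two empty tables named after Chinook tables instead of opening Chinook.db, and `list_tables` is a hypothetical helper, not part of `SQLDatabase`.

```python
import sqlite3


def list_tables(conn):
    """Return user table names in alphabetical order, read from sqlite_master."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


# Stand-in database; replace ":memory:" with "Chinook.db" to inspect the real file.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")

tables = list_tables(conn)
print(tables)
```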
Convert the question into a SQL query; execute the query; use the result to answer the original question.
from typing_extensions import TypedDict
class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str
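# Assumed import: ChatOllama from the langchain-ollama package.
from langchain_ollama import ChatOllama
# base_conf is the author's own configuration object (not shown); the tool output
# later in this walkthrough shows a qwen2.5:7b model served by Ollama.
model = ChatOllama(
    base_url=base_conf.base_url,
    model=base_conf.model_name,
    temperature=0,
)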
from langchain import hub
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")
assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()
================================ System Message ================================
Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Only use the following tables:
{table_info}
Question: {input}
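The placeholders in curly braces are filled in when the template is invoked. As a rough illustration using only Python's built-in string formatting (not LangChain's prompt classes), the sketch below fills a shortened stand-in for the prompt. The abbreviated template text and the sample `table_info` value are assumptions made for brevity.

```python
# Shortened stand-in for the hub prompt; only the placeholder names match the real one.
TEMPLATE = (
    "Given an input question, create a syntactically correct {dialect} query. "
    "Limit your query to at most {top_k} results.\n"
    "Only use the following tables:\n{table_info}\n"
    "Question: {input}"
)

filled = TEMPLATE.format(
    dialect="sqlite",
    top_k=10,
    table_info='CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL)',  # sample value
    input="How many Employees are there?",
)
print(filled)
```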
from typing_extensions import Annotated
# Schema for the structured output
class QueryOutput(TypedDict):
    """Generated SQL query."""
    query: Annotated[str, ..., "Syntactically valid SQL query."]
def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        # Fill in the sql-query-system-prompt template
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    # Bind the structured-output schema to the model
    structured_llm = model.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # Return only the query field of QueryOutput
    return {"query": result["query"]}
write_query({"question": "How many Employees are there?"})
{'query': 'SELECT COUNT(*) FROM Employee'}
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}
print(execute_query({"query": "SELECT COUNT(*) FROM Employee;"}))
{'result': '[(8,)]'}
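Note that the tool returns the rows as a string rather than as Python objects. If the value is needed programmatically, it can be parsed with `ast.literal_eval`. A minimal sketch, assuming the result is exactly the literal string shown above:

```python
import ast

raw_result = "[(8,)]"  # the string returned by execute_query above

rows = ast.literal_eval(raw_result)  # parse into a list of row tuples
employee_count = rows[0][0]          # first column of the first row
print(employee_count)
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for text that originates from a database.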
def generate_answer(state: State):
    """Answer question using retrieved information as context."""
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f'Question: {state["question"]}\n'
        f'SQL Query: {state["query"]}\n'
        f'SQL Result: {state["result"]}'
    )
    response = model.invoke(prompt)
    return {"answer": response.content}
Orchestrating with LangGraph
from langgraph.graph import START, StateGraph
graph_builder = StateGraph(State).add_sequence(
    [write_query, execute_query, generate_answer]
)
graph_builder.add_edge(START, "write_query")
graph = graph_builder.compile()
for step in graph.stream(
    {"question": "How many employees are there?"}, stream_mode="updates"
):
    print(step)
{'write_query': {'query': 'SELECT COUNT(*) FROM Employee'}}
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': "Based on the SQL result provided, the number of employees is 8. Therefore, to directly answer the user's question:\n\nThere are 8 employees."}}
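Conceptually, the compiled sequence runs each node in order, merges the dictionary it returns into the shared state, and, with `stream_mode="updates"`, emits one `{node_name: update}` item per step. The following plain-Python sketch mimics that behaviour with stub nodes. It is not how LangGraph is implemented, and the stubs return hard-coded values instead of calling a model or a database; `stream_updates` is a hypothetical helper.

```python
from typing import Callable, Dict, Iterator, List

State = Dict[str, str]


# Stub nodes standing in for the real write_query / execute_query / generate_answer.
def write_query(state: State) -> State:
    return {"query": "SELECT COUNT(*) FROM Employee"}


def execute_query(state: State) -> State:
    return {"result": "[(8,)]"}


def generate_answer(state: State) -> State:
    return {"answer": f"Result for {state['question']!r}: {state['result']}"}


def stream_updates(nodes: List[Callable[[State], State]], state: State) -> Iterator[dict]:
    """Run nodes in order, merging each partial update into the state."""
    state = dict(state)  # copy so the caller's input is not mutated
    for node in nodes:
        update = node(state)
        state.update(update)
        yield {node.__name__: update}


steps = list(
    stream_updates(
        [write_query, execute_query, generate_answer],
        {"question": "How many employees are there?"},
    )
)
for step in steps:
    print(step)
```

Each node only returns the keys it changes, which is why the streamed items contain a single field each.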
from langgraph.checkpoint.memory import MemorySaver
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory, interrupt_before=["execute_query"])
# Now that we are using persistence, we need to specify a thread ID
# so that we can resume the run after review.
config = {"configurable": {"thread_id": "1"}}
for step in graph.stream(
    {"question": "How many employees are there?"},
    config,
    stream_mode="updates",
):
    print(step)
try:
    # Ask the user whether to continue
    user_approval = input("Do you want to go to execute query? (yes/no): ")
except Exception:
    user_approval = "no"
if user_approval.lower() == "yes":
    # If approved, resume the graph from the interrupt
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
else:
    print("Operation cancelled by user.")
{'write_query': {'query': 'SELECT COUNT(*) FROM Employee'}}
{'__interrupt__': ()}
Do you want to go to execute query? (yes/no): yes
{'execute_query': {'result': '[(8,)]'}}
{'generate_answer': {'answer': "Based on the SQL result provided, the number of employees is 8. Therefore, to directly answer the user's question:\n\nThere are 8 employees."}}
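Manual approval can be complemented by a database-level safeguard. One option for SQLite, sketched below under the assumption that the application only ever needs to read, is to switch the connection to query-only mode so that write statements fail even if they are approved by mistake. The in-memory table and its rows are stand-ins, and the sketch uses the standard `sqlite3` module directly rather than the `SQLDatabase` wrapper used above.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])
conn.commit()

# From here on, any statement that would modify the database is rejected.
conn.execute("PRAGMA query_only = ON")

# Reads still work.
count = conn.execute("SELECT COUNT(*) FROM Employee").fetchone()[0]

# Writes raise an error instead of changing data.
try:
    conn.execute("DELETE FROM Employee")
    write_blocked = False
except sqlite3.OperationalError:
    write_blocked = True

print(count, write_blocked)
```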
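Agents can query the database as many times as needed to answer the user's question. They can recover from errors by running a generated query, catching the traceback, and regenerating the query correctly. They can answer questions based on the database schema as well as on its contents (for example, describing a specific table).
Create and execute queries; check query syntax; retrieve table descriptions; ... and so on.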
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=model)
tools = toolkit.get_tools()
print(tools)
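[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11f3ea6d0>),
InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11f3ea6d0>),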
ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11f3ea6d0>),
QuerySQLCheckerTool(description='Use this tool to double check if your query is correct before executing it. Always use this tool before executing a query with sql_db_query!', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x11f3ea6d0>, llm=ChatOllama(model='qwen2.5:7b', temperature=0.0, base_url='http://192.168.0.111:11434'), llm_chain=LLMChain(verbose=False, prompt=PromptTemplate(input_variables=['dialect', 'query'], input_types={}, partial_variables={}, template='\n{query}\nDouble check the {dialect} query above for common mistakes, including:\n- Using NOT IN with NULL values\n- Using UNION when UNION ALL should have been used\n- Using BETWEEN for exclusive ranges\n- Data type mismatch in predicates\n- Properly quoting identifiers\n- Using the correct number of arguments for functions\n- Casting to the correct data type\n- Using the proper columns for joins\n\nIf there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n\nOutput the final SQL query only.\n\nSQL Query: '), llm=ChatOllama(model='qwen2.5:7b', temperature=0.0, base_url='http://192.168.0.111:11434'), output_parser=StrOutputParser(), llm_kwargs={}))]
from langchain import hub
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
assert len(prompt_template.messages) == 1
prompt_template.messages[0].pretty_print()
system_message = prompt_template.format(dialect="SQLite", top_k=5)
================================ System Message ================================
You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
To start you should ALWAYS look at the tables in the database to see what you can query.
Do NOT skip this step.
Then you should query the schema of the most relevant tables.
from langchain_core.messages import HumanMessage
from langgraph.prebuilt import create_react_agent
agent_executor = create_react_agent(model, tools, state_modifier=system_message)
# The question is asked in Chinese: "Which country's customers spent the most?"
question = "哪个国家的顾客花费最多?"
for step in agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================
哪个国家的顾客花费最多?
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (2fc3f61b-c577-4a70-926c-cc57e21560d0)
Call ID: 2fc3f61b-c577-4a70-926c-cc57e21560d0
Args:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (9b6946be-2e19-42da-b2b7-101ac487db69)
Call ID: 9b6946be-2e19-42da-b2b7-101ac487db69
Args:
table_names: Customer, Invoice
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2021-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2021-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2021-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query_checker (0a223a92-4bc5-44bb-a394-155827558e8c)
Call ID: 0a223a92-4bc5-44bb-a394-155827558e8c
Args:
query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 5
================================= Tool Message =================================
Name: sql_db_query_checker
The provided SQL query does not contain any of the common mistakes listed. Here is the original query reproduced:
```sql
SELECT c.Country, SUM(i.Total) AS TotalSpent
FROM Invoice i
JOIN Customer c ON i.CustomerId = c.CustomerId
GROUP BY c.Country
ORDER BY TotalSpent DESC
LIMIT 5;
```
If you need any further adjustments or have additional requirements, please let me know!
================================== Ai Message ==================================
Tool Calls:
sql_db_query (088e0ec1-8e06-40ea-8a5b-da030662b172)
Call ID: 088e0ec1-8e06-40ea-8a5b-da030662b172
Args:
query: SELECT c.Country, SUM(i.Total) AS TotalSpent FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSpent DESC LIMIT 5
================================= Tool Message =================================
Name: sql_db_query
[('USA', 523.06), ('Canada', 303.96), ('France', 195.1), ('Brazil', 190.1), ('Germany', 156.48)]
================================== Ai Message ==================================
根据查询结果,花费最多的国家及其总消费金额如下:
1. 美国 (USA):$523.06
2. 加拿大 (Canada):$303.96
3. 法国 (France):$195.10
4. 巴西 (Brazil):$190.10
5. 德国 (Germany):$156.48
因此,花费最多的国家是美国。
List the available tables; retrieve the schemas of two tables (Customer and Invoice); query multiple tables with a JOIN.
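The shape of the query the agent settled on can be tried out in isolation. The sketch below runs the same JOIN and aggregation with the standard `sqlite3` module against a tiny in-memory database. The three customers and four invoice totals are invented for illustration and are not Chinook data.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
    CREATE TABLE Invoice (
        InvoiceId INTEGER PRIMARY KEY,
        CustomerId INTEGER NOT NULL REFERENCES Customer (CustomerId),
        Total NUMERIC NOT NULL
    );
    -- Invented sample rows
    INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
    INSERT INTO Invoice VALUES (1, 1, 5), (2, 2, 8), (3, 3, 4), (4, 2, 2);
    """
)

# Same structure as the agent's query: join, group by country, sort by total.
rows = conn.execute(
    """
    SELECT c.Country, SUM(i.Total) AS TotalSpent
    FROM Invoice i
    JOIN Customer c ON i.CustomerId = c.CustomerId
    GROUP BY c.Country
    ORDER BY TotalSpent DESC
    LIMIT 5
    """
).fetchall()
print(rows)
```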
# The question is asked in Chinese: "Describe the playlisttrack table in Chinese"
question = "使用中文描述 playlisttrack 表"
for step in agent_executor.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================
使用中文描述 playlisttrack 表
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (ba3d7d38-551c-4092-982e-36356e0af18c)
Call ID: ba3d7d38-551c-4092-982e-36356e0af18c
Args:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (167c1c9c-3c3e-481f-8843-ff40ef94c03e)
Call ID: 167c1c9c-3c3e-481f-8843-ff40ef94c03e
Args:
table_names: PlaylistTrack
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)
/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/
================================== Ai Message ==================================
根据 `PlaylistTrack` 表的结构,该表用于关联播放列表(Playlist)和歌曲(Track)。查询结果展示了三个示例记录,表示哪些歌曲被添加到了哪个播放列表中。
以下是 `PlaylistTrack` 表的 SQL 查询语句:
```sql
SELECT PlaylistId, TrackId
FROM PlaylistTrack
LIMIT 5;
```
这个查询将返回最多 5 条记录,并按 `PlaylistId` 和 `TrackId` 排序。
import ast
import re
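# The function `query_as_list` takes two arguments, `db` and `query`, and does the following:
#
# 1. Runs the SQL query with `db.run(query)` and stores the result string in `res`.
# 2. Parses that string into Python objects with `ast.literal_eval(res)`, then flattens the nested rows with a list comprehension and drops empty values.
# 3. Removes standalone numbers from each string with `re.sub(r"\b\d+\b", "", string)` and strips surrounding whitespace.
# 4. Removes duplicates with `list(set(res))` and returns a list of unique values.
#
# The goal is to extract the unique strings, with standalone numbers removed, from a database query result.
def query_as_list(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return list(set(res))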
artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")
print(albums[:5])
['Sibelius: Finlandia',
'Demorou...',
'War',
'Pearl Jam',
'The Cream Of Clapton']
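To see what each step of `query_as_list` does without a database, the sketch below applies the same transformations to a hard-coded string in the format that `db.run` returns. The sample rows are made up, `clean_values` is a hypothetical name for this database-free variant, and the result is sorted only to make the output deterministic.

```python
import ast
import re


def clean_values(raw):
    """Flatten rows, drop empty values, strip standalone numbers, deduplicate."""
    values = [el for row in ast.literal_eval(raw) for el in row if el]
    values = [re.sub(r"\b\d+\b", "", value).strip() for value in values]
    return sorted(set(values))


# Made-up result string: a duplicate, a NULL, and a title containing a number.
raw = "[('AC/DC',), ('Accept',), ('AC/DC',), (None,), ('Greatest Hits 2',)]"
cleaned = clean_values(raw)
print(cleaned)
```

Because standalone numbers are removed, titles that contain one lose it, which is presumably why an opus number is missing from one of the album titles in the retriever output further on.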
from langchain.agents.agent_toolkits import create_retriever_tool
# vector_store comes from the author's own project module (not shown here).
from rag.component import vector_store
_ = vector_store.add_texts(artists + albums)
retriever = vector_store.as_retriever(search_kwargs={"k": 5})
description = (
    "Use to look up values to filter on. Input is an approximate spelling "
    "of the proper noun, output is valid proper nouns. Use the noun most "
    "similar to the search."
)
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)
print(retriever_tool.invoke("Alice Chains"))
Alice In Chains
The Singles
The Posies
Holst: The Planets, Op. & Vaughan Williams: Fantasies
Slaves And Masters
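The retriever above relies on embedding similarity. For readers without a vector store at hand, the idea of mapping an approximate spelling to a stored proper noun can be approximated with the standard library's `difflib`. This is a different, purely character-based technique, and the candidate list is a small hand-picked subset of the artist names shown earlier. `lookup_proper_noun` is a hypothetical helper.

```python
import difflib

CANDIDATES = ["AC/DC", "Accept", "Aerosmith", "Alice In Chains", "Audioslave"]


def lookup_proper_noun(text, candidates=CANDIDATES, cutoff=0.6):
    """Return the closest candidate by character similarity, or None."""
    by_lower = {name.lower(): name for name in candidates}  # case-insensitive match
    matches = difflib.get_close_matches(text.lower(), list(by_lower), n=1, cutoff=cutoff)
    return by_lower[matches[0]] if matches else None


best = lookup_proper_noun("alis in chain")
print(best)
```

Character-based matching handles misspellings like this one, but unlike embeddings it cannot relate names that are spelled very differently.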
# Append an extra instruction to the system message
suffix = (
    "you must ALWAYS first look up the filter value using the 'search_proper_nouns' tool! "
    "Do not try to guess at the proper name - use this function to find similar ones."
)
system = f"{system_message}\n\n{suffix}"
tools.append(retriever_tool)
agent = create_react_agent(model, tools, state_modifier=system)
# How many albums does Alice In Chains have? (the band name is misspelled to exercise search_proper_nouns)
question = "How many albums does alis in chain have?"
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()
================================ Human Message =================================
How many albums does alis in chain have?
================================== Ai Message ==================================
Tool Calls:
search_proper_nouns (a901e4a7-4340-4c83-99d5-79bd53b95dce)
Call ID: a901e4a7-4340-4c83-99d5-79bd53b95dce
Args:
query: alis in chain
================================= Tool Message =================================
Name: search_proper_nouns
Alice In Chains
Da Lama Ao Caos
Corinne Bailey Rae
Kid Abelha
Ace Of Spades
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (511a0843-5589-4c63-b841-35c9f255c8e8)
Call ID: 511a0843-5589-4c63-b841-35c9f255c8e8
Args:
table_names: albums
================================= Tool Message =================================
Name: sql_db_schema
Error: table_names {'albums'} not found in database
================================== Ai Message ==================================
Tool Calls:
sql_db_list_tables (fac0656f-7d3d-4d50-8210-850e23fd6c90)
Call ID: fac0656f-7d3d-4d50-8210-850e23fd6c90
Args:
================================= Tool Message =================================
Name: sql_db_list_tables
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (390d1cf7-e4cb-4cf1-aab9-c591ceeddc9c)
Call ID: 390d1cf7-e4cb-4cf1-aab9-c591ceeddc9c
Args:
table_names: Artist
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Artist" (
"ArtistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("ArtistId")
)
/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_schema (9266f48a-228d-47ca-8bde-b905c85d1db9)
Call ID: 9266f48a-228d-47ca-8bde-b905c85d1db9
Args:
table_names: Album
================================= Tool Message =================================
Name: sql_db_schema
CREATE TABLE "Album" (
"AlbumId" INTEGER NOT NULL,
"Title" NVARCHAR(160) NOT NULL,
"ArtistId" INTEGER NOT NULL,
PRIMARY KEY ("AlbumId"),
FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)
/*
3 rows from Album table:
AlbumId Title ArtistId
1 For Those About To Rock We Salute You 1
2 Balls to the Wall 2
3 Restless and Wild 2
*/
================================== Ai Message ==================================
Tool Calls:
sql_db_query_checker (626449ff-f0c9-42aa-92e7-33d6d94fe6cb)
Call ID: 626449ff-f0c9-42aa-92e7-33d6d94fe6cb
Args:
query: SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name LIKE 'Alice In Chains%');
================================= Tool Message =================================
Name: sql_db_query_checker
The provided SQLite query appears to be correctly written and does not contain any of the common mistakes you've listed. Here is the query again for reference:
```sql
SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name LIKE 'Alice In Chains%');
```
This query selects the count of records from the `Album` table where the `ArtistId` matches any `ArtistId` in the `Artist` table whose `Name` starts with 'Alice In Chains'. There are no apparent issues with the query based on the criteria you provided.
If there is a specific context or additional information that might have led to concerns, please provide more details. Otherwise, the original query can be used as-is.
================================== Ai Message ==================================
Tool Calls:
sql_db_query (d3712aaf-ab7c-4532-992f-bf8b8c1c0528)
Call ID: d3712aaf-ab7c-4532-992f-bf8b8c1c0528
Args:
query: SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name LIKE 'Alice In Chains%');
================================= Tool Message =================================
Name: sql_db_query
[(1,)]
================================== Ai Message ==================================
The SQL query executed successfully and returned the result. Alice in Chains has 1 album in the database.
The answer to your question is: **Alice in Chains have 1 album**.
SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name LIKE 'Alice In Chains%')