SQL Agent with Cohere and LangChain (i-5O Case Study)
This tutorial demonstrates how to create a SQL agent using Cohere and LangChain. The agent translates natural language questions from users into SQL and executes them against a database, so users can work with the data without writing SQL themselves.
Key topics covered:
- Setting up the necessary libraries and environment
- Connecting to a SQLite database
- Configuring the LangChain SQL Toolkit
- Creating a custom prompt template with few-shot examples
- Building and running the SQL agent
- Adding memory to the agent to keep track of historical messages
By the end of this tutorial, you’ll have a functional SQL agent that can answer questions about your data using natural language.
This tutorial uses mock data from a manufacturing environment in which each product item's production is tracked across multiple stations. The data supports analysis of production efficiency, station performance, and individual item progress through the manufacturing process. It is modelled after a real customer use case.
The database contains two tables:
- The product_tracking table records the movement of items through different zones in manufacturing stations, including start and end times, station names, and product IDs.
- The status table logs the operational status of stations, including timestamps, station names, and whether they are productive or in downtime.
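As a rough sketch of what these two tables look like, the snippet below builds them in an in-memory SQLite database with Python's built-in sqlite3 module. The column names and the two sample rows are copied from the schema and query output shown in this tutorial; defining the tables inline is an assumption made for illustration, since the tutorial itself loads them from product_tracking.sql and status.sql.

```python
import sqlite3

# Every column is TEXT, matching the schema the agent reports in this tutorial.
SCHEMA = """
CREATE TABLE product_tracking (
    timestamp_start TEXT, timestamp_end TEXT, timezone TEXT, date TEXT,
    hour TEXT, station_name TEXT, zone TEXT, product_id TEXT, duration TEXT
);
CREATE TABLE status (
    timestamp_event TEXT, timezone TEXT, date TEXT, hour TEXT,
    minute TEXT, station_name TEXT, station_status TEXT
);
"""

conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)

# One sample row per table, taken from the query output in this tutorial.
conn.execute(
    "INSERT INTO product_tracking VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
    ("2024-05-27 17:22:00", "2024-05-27 17:57:00", "Canada/Toronto",
     "2024-05-27", "17", "stn2", "wip", "187", "35"),
)
conn.execute(
    "INSERT INTO status VALUES (?, ?, ?, ?, ?, ?, ?)",
    ("2024-05-09 19:28:00", "Canada/Toronto", "2024-05-09",
     "19", "28", "stn3", "downtime"),
)

# List the tables that now exist in the database.
tables = [row[0] for row in conn.execute(
    "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
)]
print(tables)
```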
Import the required libraries
First, let’s import the necessary libraries for creating a SQL agent using Cohere and LangChain. These libraries enable natural language interaction with databases and provide tools for building AI-powered agents.
import os

os.environ["COHERE_API_KEY"] = "<cohere-api-key>"
! pip install langchain-core langchain-cohere langchain-community faiss-cpu -qq
from langchain_cohere import create_sql_agent
from langchain_cohere.chat_models import ChatCohere
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_cohere import CohereEmbeddings
from datetime import datetime
Load the database
Next, we load the database for our manufacturing data.
Download the SQL files from the link below to create the database.
We create an in-memory SQLite database using SQL scripts for the product_tracking and status tables. You can get the SQL tables here.
We then create a SQLDatabase instance, which will be used by our LangChain tools and agents to interact with the data.
import sqlite3

from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool

def get_engine_for_manufacturing_db():
    """Create an in-memory database with the manufacturing data tables."""
    connection = sqlite3.connect(":memory:", check_same_thread=False)

    # Read and execute the SQL scripts
    for sql_file in ['product_tracking.sql', 'status.sql']:
        with open(sql_file, 'r') as file:
            sql_script = file.read()
            connection.executescript(sql_script)

    # StaticPool reuses the single in-memory connection for every request
    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

# Create the engine
engine = get_engine_for_manufacturing_db()

# Create the SQLDatabase instance
db = SQLDatabase(engine)

# Now you can use this db instance with your LangChain tools and agents
# Test the connection
db.run("SELECT * FROM status LIMIT 5;")
"[('2024-05-09 19:28:00', 'Canada/Toronto', '2024-05-09', '19', '28', 'stn3', 'downtime'), ('2024-04-21 06:57:00', 'Canada/Toronto', '2024-04-21', '6', '57', 'stn3', 'productive'), ('2024-04-11 23:52:00', 'Canada/Toronto', '2024-04-11', '23', '52', 'stn4', 'productive'), ('2024-04-03 21:52:00', 'Canada/Toronto', '2024-04-03', '21', '52', 'stn2', 'downtime'), ('2024-04-30 05:01:00', 'Canada/Toronto', '2024-04-30', '5', '1', 'stn4', 'productive')]"
# Test the connection
db.run("SELECT * FROM product_tracking LIMIT 5;")
"[('2024-05-27 17:22:00', '2024-05-27 17:57:00', 'Canada/Toronto', '2024-05-27', '17', 'stn2', 'wip', '187', '35'), ('2024-04-26 15:56:00', '2024-04-26 17:56:00', 'Canada/Toronto', '2024-04-26', '15', 'stn4', 'wip', '299', '120'), ('2024-04-12 04:36:00', '2024-04-12 05:12:00', 'Canada/Toronto', '2024-04-12', '4', 'stn3', 'wip', '60', '36'), ('2024-04-19 15:15:00', '2024-04-19 15:22:00', 'Canada/Toronto', '2024-04-19', '15', 'stn4', 'wait', '227', '7'), ('2024-04-24 19:10:00', '2024-04-24 21:07:00', 'Canada/Toronto', '2024-04-24', '19', 'stn4', 'wait', '169', '117')]"
Set up the LangChain SQL Toolkit
Next, we initialize the language model with Cohere's LLM and pass it, together with the database, to the LangChain SQL Toolkit. This prepares the components needed to query the SQL database using natural language.
## Define model to use
import os

MODEL = "command-a-03-2025"

llm = ChatCohere(
    model=MODEL,
    temperature=0.1,
    verbose=True
)


toolkit = SQLDatabaseToolkit(db=db, llm=llm)
context = toolkit.get_context()
tools = toolkit.get_tools()

print("**List of pre-defined Langchain Tools**")
print([tool.name for tool in tools])
**List of pre-defined Langchain Tools**
['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']
Create a prompt template
Next, we create a prompt template. We start with a simple system message, and in a later section improve the prompt with few-shot examples. The system message communicates instructions and context to the model at the beginning of a conversation.
In this case, the system message tells the model which SQL dialect to use and how many results to return at most, among other instructions.
from langchain_core.prompts import (
    PromptTemplate,
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    MessagesPlaceholder
)

system_message = """You are an agent designed to interact with a SQL database.
You are an expert at answering questions about manufacturing data.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Always start with checking the schema of the available tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

The current date is {date}.

For questions regarding productive time, downtime, productive or productivity, use minutes as units.

For questions regarding productive time, downtime, productive or productivity use the status table.

For questions regarding processing time and average processing time, use minutes as units.

For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.

If the question does not seem related to the database, just return "I don't know" as the answer."""

system_prompt = PromptTemplate.from_template(system_message)
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=system_prompt),
        MessagesPlaceholder(variable_name='chat_history', optional=True),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
prompt_val = full_prompt.invoke({
    "input": "What was the productive time for all stations today?",
    "top_k": 5,
    "dialect": "SQLite",
    "date": datetime.now(),
    "agent_scratchpad": [],
})
print(prompt_val.to_string())
System: You are an agent designed to interact with a SQL database.
You are an expert at answering questions about manufacturing data.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Always start with checking the schema of the available tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

The current date is 2025-03-13 09:21:55.403450.

For questions regarding productive time, downtime, productive or productivity, use minutes as units.

For questions regarding productive time, downtime, productive or productivity use the status table.

For questions regarding processing time and average processing time, use minutes as units.

For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.

If the question does not seem related to the database, just return "I don't know" as the answer.
Human: What was the productive time for all stations today?
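The {dialect}, {top_k}, and {date} placeholders in the system message are ordinary template variables that are filled in when the prompt is invoked. As a minimal sketch of that substitution, the snippet below uses Python's plain str.format on a shortened stand-in for the system message; it is not LangChain's PromptTemplate, and the fixed date is an assumption chosen so the result is reproducible.

```python
from datetime import datetime

# Shortened stand-in for the tutorial's system message (assumption: only
# three of its lines are kept so the substitution is easy to see).
template = (
    "Create a syntactically correct {dialect} query to run.\n"
    "Always limit your query to at most {top_k} results.\n"
    "The current date is {date}."
)

filled = template.format(
    dialect="SQLite",
    top_k=5,
    date=datetime(2025, 3, 13, 9, 21, 55),  # fixed date instead of datetime.now()
)
print(filled)
```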
Create a few-shot prompt template
In the previous step, we created a simple system prompt. In this section, we improve it with few-shot examples, which give the model additional context and improve its performance on specific tasks. Here, we prepare pairs of natural language questions and their corresponding SQL queries to help the model generate accurate SQL statements for our database.
We use SemanticSimilarityExampleSelector to select, out of all the available examples, the top k examples that are most similar to an input query.
examples = [
    {
        "input": "What was the average processing time for all stations on April 3rd 2024?",
        "query": "SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY station_name;",
    },
    {
        "input": "What was the average processing time for all stations on April 3rd 2024 between 4pm and 6pm?",
        "query": "SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND CAST(hour AS INTEGER) BETWEEN 16 AND 18 AND zone = 'wip' GROUP BY station_name ORDER BY station_name;",
    },
    {
        "input": "What was the average processing time for stn4 on April 3rd 2024?",
        "query": "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND station_name = 'stn4' AND zone = 'wip';",
    },
    {
        "input": "How much downtime did stn2 have on April 3rd 2024?",
        "query": "SELECT COUNT(*) AS downtime_count FROM status WHERE date = '2024-04-03' AND station_name = 'stn2' AND station_status = 'downtime';",
    },
    {
        "input": "What were the productive time and downtime numbers for all stations on April 3rd 2024?",
        "query": "SELECT station_name, station_status, COUNT(*) as total_time FROM status WHERE date = '2024-04-03' GROUP BY station_name, station_status;",
    },
    {
        "input": "What was the bottleneck station on April 3rd 2024?",
        "query": "SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY avg_processing_time DESC LIMIT 1;",
    },
    {
        "input": "Which percentage of the time was stn5 down in the last week of May?",
        "query": "SELECT SUM(CASE WHEN station_status = 'downtime' THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS percentage_downtime FROM status WHERE station_name = 'stn5' AND date >= '2024-05-25' AND date <= '2024-05-31';",
    },
]
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    CohereEmbeddings(
        cohere_api_key=os.getenv("COHERE_API_KEY"), model="embed-english-v3.0"
    ),
    FAISS,
    k=5,
    input_keys=["input"],
)
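To illustrate the idea behind similarity-based example selection, here is a toy sketch that ranks examples by cosine similarity over word counts. It is only a stand-in: the tutorial uses Cohere embeddings with a FAISS index rather than bag-of-words vectors, and embed, cosine, and select_examples are hypothetical helpers introduced for this illustration.

```python
import math
import re
from collections import Counter

def embed(text):
    """Toy embedding: lowercase word counts (stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))

def cosine(a, b):
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[w] * b[w] for w in a if w in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def select_examples(examples, query, k):
    """Return the k examples whose 'input' is most similar to the query."""
    q = embed(query)
    ranked = sorted(
        examples, key=lambda ex: cosine(embed(ex["input"]), q), reverse=True
    )
    return ranked[:k]

toy_examples = [
    {"input": "How much downtime did stn2 have on April 3rd 2024?"},
    {"input": "What was the bottleneck station on April 3rd 2024?"},
    {"input": "What was the average processing time for stn4 on April 3rd 2024?"},
]

top = select_examples(toy_examples, "How much downtime did stn5 have yesterday?", k=1)
print(top[0]["input"])
```

A real embedding model captures meaning beyond shared words, which is why the tutorial relies on one; the ranking step, however, follows the same pattern of scoring every example against the query and keeping the best k.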
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
You are an expert at answering questions about manufacturing data.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Always start with checking the schema of the available tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

The current date is {date}.

For questions regarding productive time, downtime, productive or productivity, use minutes as units.

For questions regarding productive time, downtime, productive or productivity use the status table.

For questions regarding processing time and average processing time, use minutes as units.

For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:
"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k", "date"],
    prefix=system_prefix,
    suffix="",
)
full_prompt = ChatPromptTemplate.from_messages(
    [
        # In the previous section, this was system_prompt instead without the few shot examples.
        # We can use either prompting style as required
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
# Example formatted prompt
prompt_val = full_prompt.invoke(
    {
        "input": "What was the productive time for all stations today?",
        "top_k": 5,
        "dialect": "SQLite",
        "date": datetime.now(),
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())
System: You are an agent designed to interact with a SQL database.
You are an expert at answering questions about manufacturing data.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Always start with checking the schema of the available tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

The current date is 2025-03-13 09:22:22.275098.

For questions regarding productive time, downtime, productive or productivity, use minutes as units.

For questions regarding productive time, downtime, productive or productivity use the status table.

For questions regarding processing time and average processing time, use minutes as units.

For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:


User input: What were the productive time and downtime numbers for all stations on April 3rd 2024?
SQL query: SELECT station_name, station_status, COUNT(*) as total_time FROM status WHERE date = '2024-04-03' GROUP BY station_name, station_status;

User input: What was the average processing time for all stations on April 3rd 2024?
SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY station_name;

User input: What was the average processing time for all stations on April 3rd 2024 between 4pm and 6pm?
SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND CAST(hour AS INTEGER) BETWEEN 16 AND 18 AND zone = 'wip' GROUP BY station_name ORDER BY station_name;

User input: What was the bottleneck station on April 3rd 2024?
SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY avg_processing_time DESC LIMIT 1;

User input: What was the average processing time for stn4 on April 3rd 2024?
SQL query: SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND station_name = 'stn4' AND zone = 'wip';
Human: What was the productive time for all stations today?
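The layout of the formatted few-shot prompt above (prefix, then one block per selected example) can be approximated in a few lines of plain Python. This is a sketch under assumptions, not LangChain's implementation: build_few_shot_prompt is a hypothetical helper, and joining the parts with a blank line is inferred from the printed output rather than taken from the library's source.

```python
# Same per-example template string as the one passed to the example prompt.
EXAMPLE_TEMPLATE = "User input: {input}\nSQL query: {query}"

def build_few_shot_prompt(prefix, examples, suffix="", separator="\n\n"):
    """Join the prefix, one formatted block per example, and the suffix."""
    blocks = [prefix]
    blocks += [EXAMPLE_TEMPLATE.format(**example) for example in examples]
    blocks.append(suffix)
    return separator.join(blocks)

# One example pair copied from the tutorial's examples list.
selected = [
    {
        "input": "How much downtime did stn2 have on April 3rd 2024?",
        "query": (
            "SELECT COUNT(*) AS downtime_count FROM status "
            "WHERE date = '2024-04-03' AND station_name = 'stn2' "
            "AND station_status = 'downtime';"
        ),
    },
]

prompt_text = build_few_shot_prompt(
    "Here are some examples of user inputs and their corresponding SQL queries:",
    selected,
)
print(prompt_text)
```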
Create the agent
Next, we create the SQL agent with create_sql_agent, which we imported from langchain_cohere.
This agent will be capable of interpreting natural language queries, converting them into SQL queries, and executing them against our database. The agent uses the LLM we defined earlier, along with the SQL toolkit and the custom prompt we created.
agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    prompt=full_prompt,
    verbose=True
)
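The agent typically works through the database in a fixed order: list the tables, inspect the schema, then run a query. The sketch below approximates those three steps by hand with plain sqlite3. The function names are hypothetical stand-ins named after the LangChain tools, not the toolkit's actual implementation, and the status table is reduced to three columns for brevity.

```python
import sqlite3

# Reduced status table (assumption: only three of its columns are kept).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE status (date TEXT, station_name TEXT, station_status TEXT)"
)
conn.execute("INSERT INTO status VALUES ('2024-05-09', 'stn3', 'downtime')")

def list_tables(conn):
    """Stand-in for sql_db_list_tables: names of all tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    )
    return [name for (name,) in rows]

def table_schema(conn, table):
    """Stand-in for sql_db_schema: the CREATE TABLE statement of one table."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
        (table,),
    ).fetchone()
    return row[0]

def run_query(conn, query):
    """Stand-in for sql_db_query: execute a query and return all rows."""
    return conn.execute(query).fetchall()

print(list_tables(conn))
print(table_schema(conn, "status"))
print(run_query(
    conn, "SELECT station_name FROM status WHERE station_status = 'downtime'"
))
```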
Run the agent
Now, we can run the agent and test it with a few different queries.
# %%time
output = agent.invoke({
    "input": "Which stations had some downtime in the month of May 2024?",
    "date": datetime.now()
})
print(output['output'])

# Answer: stn2, stn3 and stn5 had some downtime in the month of May 2024.
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tablessql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will query the connected SQL database to find the stations that had some downtime in the month of May 2024.

sql_db_list_tablessql_db_list_tables is not a valid tool, try one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].
Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will query the connected SQL database to find the stations that had some downtime in the month of May 2024.

product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
responded: I have found the following tables: product_tracking and status. I will now query the schema of these tables to understand their structure.


CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/


CREATE TABLE status (
    timestamp_event TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    minute TEXT,
    station_name TEXT,
    station_status TEXT
)

/*
3 rows from status table:
timestamp_event timezone date hour minute station_name station_status
2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
*/
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
responded: I have found the following tables: product_tracking and status. I will now query the schema of these tables to understand their structure.


CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/


CREATE TABLE status (
    timestamp_event TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    minute TEXT,
    station_name TEXT,
    station_status TEXT
)

/*
3 rows from status table:
timestamp_event timezone date hour minute station_name station_status
2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT DISTINCT station_name\r\nFROM status\r\nWHERE station_status = 'downtime'\r\n AND SUBSTR(date, 1, 7) = '2024-05';"}`
responded: I have found that the status table contains the relevant information, including the station_status column which contains the values 'productive' and 'downtime'. I will now query the database to find the stations that had some downtime in the month of May 2024.

```sql
SELECT DISTINCT station_name
FROM status
WHERE station_status = 'downtime'
 AND SUBSTR(date, 1, 7) = '2024-05';
```
Invoking: `sql_db_query_checker` with `{'query': "SELECT DISTINCT station_name\r\nFROM status\r\nWHERE station_status = 'downtime'\r\n AND SUBSTR(date, 1, 7) = '2024-05';"}`
responded: I have found that the status table contains the relevant information, including the station_status column which contains the values 'productive' and 'downtime'. I will now query the database to find the stations that had some downtime in the month of May 2024.

```sql
SELECT DISTINCT station_name
FROM status
WHERE station_status = 'downtime'
 AND SUBSTR(date, 1, 7) = '2024-05';
```
Invoking: `sql_db_query_checker` with `{'query': "SELECT DISTINCT station_name\r\nFROM status\r\nWHERE station_status = 'downtime'\r\n AND SUBSTR(date, 1, 7) = '2024-05';"}`
responded: I have found that the status table contains the relevant information, including the station_status column which contains the values 'productive' and 'downtime'. I will now query the database to find the stations that had some downtime in the month of May 2024.

```sql
SELECT DISTINCT station_name
FROM status
WHERE station_status = 'downtime'
 AND SUBSTR(date, 1, 7) = '2024-05';
```
Invoking: `sql_db_query_checker` with `{'query': "SELECT DISTINCT station_name\r\nFROM status\r\nWHERE station_status = 'downtime'\r\n AND SUBSTR(date, 1, 7) = '2024-05';"}`
responded: I have found that the status table contains the relevant information, including the station_status column which contains the values 'productive' and 'downtime'. I will now query the database to find the stations that had some downtime in the month of May 2024.

```sql
SELECT DISTINCT station_name
FROM status
WHERE station_status = 'downtime'
 AND SUBSTR(date, 1, 7) = '2024-05';
```
Invoking: `sql_db_query` with `{'query': "SELECT DISTINCT station_name\r\nFROM status\r\nWHERE station_status = 'downtime'\r\n AND SUBSTR(date, 1, 7) = '2024-05';"}`
responded: I have checked the query and it is correct. I will now execute the query to find the stations that had some downtime in the month of May 2024.

[('stn3',), ('stn5',), ('stn2',)]
Invoking: `sql_db_query` with `{'query': "SELECT DISTINCT station_name\r\nFROM status\r\nWHERE station_status = 'downtime'\r\n AND SUBSTR(date, 1, 7) = '2024-05';"}`
responded: I have checked the query and it is correct. I will now execute the query to find the stations that had some downtime in the month of May 2024.

[('stn3',), ('stn5',), ('stn2',)]
The stations that had some downtime in the month of May 2024 are:
- stn3
- stn5
- stn2

> Finished chain.
The stations that had some downtime in the month of May 2024 are:
- stn3
- stn5
- stn2
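The query the agent settled on can also be run directly, which is a useful way to sanity-check an agent's answer. The sketch below executes the same SQL with sqlite3 against only the five sample status rows shown earlier in this tutorial, reduced to three columns. Because that is a small subset of the data (an assumption made to keep the example self-contained), it returns fewer stations than the agent found in the full database.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE status (date TEXT, station_name TEXT, station_status TEXT)"
)
# The five sample rows from the connection test, reduced to three columns.
conn.executemany(
    "INSERT INTO status VALUES (?, ?, ?)",
    [
        ("2024-05-09", "stn3", "downtime"),
        ("2024-04-21", "stn3", "productive"),
        ("2024-04-11", "stn4", "productive"),
        ("2024-04-03", "stn2", "downtime"),
        ("2024-04-30", "stn4", "productive"),
    ],
)

# Same query as in the agent log: SUBSTR(date, 1, 7) keeps the 'YYYY-MM' part.
query = """
SELECT DISTINCT station_name
FROM status
WHERE station_status = 'downtime'
  AND SUBSTR(date, 1, 7) = '2024-05';
"""
may_downtime = conn.execute(query).fetchall()
print(may_downtime)
```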
output = agent.invoke({
    "input": "What is the average processing duration at stn5 in the wip zone?",
    "date": datetime.now()
})
print(output['output'])

# Answer: 39.17 minutes
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tablessql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will write a query to find the average processing duration at stn5 in the wip zone.

sql_db_list_tablessql_db_list_tables is not a valid tool, try one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].
Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will write a query to find the average processing duration at stn5 in the wip zone.

product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: I will use the product_tracking table to find the average processing duration at stn5 in the wip zone.


CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: I will use the product_tracking table to find the average processing duration at stn5 in the wip zone.


CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration\nFROM product_tracking\nWHERE station_name = 'stn5' AND zone = 'wip';"}`
responded: I will use the product_tracking table to find the average processing duration at stn5 in the wip zone. The relevant columns are station_name, zone and duration.

```sql
SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration
FROM product_tracking
WHERE station_name = 'stn5' AND zone = 'wip';
```
Invoking: `sql_db_query_checker` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration\nFROM product_tracking\nWHERE station_name = 'stn5' AND zone = 'wip';"}`
responded: I will use the product_tracking table to find the average processing duration at stn5 in the wip zone. The relevant columns are station_name, zone and duration.

```sql
SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration
FROM product_tracking
WHERE station_name = 'stn5' AND zone = 'wip';
```
Invoking: `sql_db_query` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration\nFROM product_tracking\nWHERE station_name = 'stn5' AND zone = 'wip';"}`
responded: I will now execute the query to find the average processing duration at stn5 in the wip zone.

[(39.166666666666664,)]
Invoking: `sql_db_query` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration\nFROM product_tracking\nWHERE station_name = 'stn5' AND zone = 'wip';"}`
responded: I will now execute the query to find the average processing duration at stn5 in the wip zone.

[(39.166666666666664,)]
The average processing duration at stn5 in the wip zone is 39.17 minutes.

> Finished chain.
The average processing duration at stn5 in the wip zone is 39.17 minutes.
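The generated queries wrap duration in CAST(... AS INTEGER) because every column in product_tracking is declared as TEXT. The sketch below shows one reason this matters in SQLite: comparisons on a TEXT column are lexicographic, so MAX over the raw column and MAX over the cast column can disagree. It uses only the three wip durations from the sample rows shown in this tutorial, in a table reduced to two columns (an assumption made for brevity).

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE product_tracking (zone TEXT, duration TEXT)")
# Durations of the three 'wip' sample rows, stored as text like in the tutorial.
conn.executemany(
    "INSERT INTO product_tracking VALUES (?, ?)",
    [("wip", "35"), ("wip", "120"), ("wip", "36")],
)

# Text comparison: '36' sorts after '120' because '3' > '1'.
text_max = conn.execute(
    "SELECT MAX(duration) FROM product_tracking"
).fetchone()[0]

# Numeric comparison after the cast gives the real maximum.
int_max = conn.execute(
    "SELECT MAX(CAST(duration AS INTEGER)) FROM product_tracking"
).fetchone()[0]

# Average in minutes over the cast values, as in the agent's query.
avg_minutes = conn.execute(
    "SELECT AVG(CAST(duration AS INTEGER)) FROM product_tracking WHERE zone = 'wip'"
).fetchone()[0]

print(text_max, int_max, round(avg_minutes, 2))
```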
output = agent.invoke({
    "input": "Which station had the highest total duration in the wait zone?",
    "date": datetime.now()
})
print(output['output'])

# Answer: stn4 - 251 minutes
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tablessql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will write a query to find the station with the highest total duration in the wait zone.

sql_db_list_tablessql_db_list_tables is not a valid tool, try one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].
Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will write a query to find the station with the highest total duration in the wait zone.

product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: I will use the product_tracking table to find the station with the highest total duration in the wait zone. I will group the results by station_name and zone, and filter for the zone 'wait'. I will then order the results by total duration in descending order and limit the results to 1.

CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: I will use the product_tracking table to find the station with the highest total duration in the wait zone. I will group the results by station_name and zone, and filter for the zone 'wait'. I will then order the results by total duration in descending order and limit the results to 1.

CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration\r\nFROM product_tracking\r\nWHERE zone = 'wait'\r\nGROUP BY station_name\r\nORDER BY total_duration DESC\r\nLIMIT 1;"}`
responded: I will use the product_tracking table to find the station with the highest total duration in the wait zone. I will group the results by station_name and zone, and filter for the zone 'wait'. I will then order the results by total duration in descending order and limit the results to 1.

```sql
SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration
FROM product_tracking
WHERE zone = 'wait'
GROUP BY station_name
ORDER BY total_duration DESC
LIMIT 1;
```
Invoking: `sql_db_query` with `{'query': "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration\nFROM product_tracking\nWHERE zone = 'wait'\nGROUP BY station_name\nORDER BY total_duration DESC\nLIMIT 1;"}`
responded: I will now execute the query to find the station with the highest total duration in the wait zone.

[('stn4', 251)]
The station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.

> Finished chain.
The station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.
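To see what the generated query computes, it can be run directly with Python's built-in sqlite3 module. The following is a minimal sketch on a handful of made-up rows (not the tutorial's dataset); only the column names and the query itself come from the trace above.

```python
import sqlite3

# In-memory database with a reduced version of the product_tracking schema.
# The rows below are invented for illustration only.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE product_tracking (station_name TEXT, zone TEXT, duration TEXT)"
)
conn.executemany(
    "INSERT INTO product_tracking VALUES (?, ?, ?)",
    [
        ("stn1", "wait", "9"),
        ("stn1", "wait", "5"),
        ("stn2", "wait", "120"),
        ("stn2", "wip", "300"),  # excluded by the zone filter
    ],
)

# Same query shape as the one the agent generated.
query = """
SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration
FROM product_tracking
WHERE zone = 'wait'
GROUP BY station_name
ORDER BY total_duration DESC
LIMIT 1;
"""
top_station = conn.execute(query).fetchone()
print(top_station)  # ('stn2', 120)
```

Because duration is stored as TEXT, the explicit CAST makes the intent of the aggregation clear: the values are summed as integers per station, and only rows in the 'wait' zone contribute.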
Memory in the SQL agent
We may want the agent to remember previous messages so that follow-up questions can build on earlier ones. In this section, we add memory to the agent to make this possible.
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field
from typing import List
In the code snippet below, we create a class that stores the chat history in memory. It can be customised to persist the messages in a database or any other suitable data store.
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """In memory implementation of chat message history."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Add a list of messages to the store"""
        self.messages.extend(messages)

    def clear(self) -> None:
        self.messages = []
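As one possible customisation, the same two operations (add_messages and clear) can be backed by a database. The sketch below uses Python's built-in sqlite3 module and stores each message as a (role, content) pair. The class name SQLiteHistory, the table layout, and the tuple representation are assumptions made for illustration. It deliberately does not subclass BaseChatMessageHistory so that it runs without LangChain installed; a real integration would need to convert to and from BaseMessage objects.

```python
import sqlite3
from typing import List, Tuple


class SQLiteHistory:
    """Hypothetical chat history that persists (role, content) pairs in SQLite."""

    def __init__(self, db_path: str, session_id: str) -> None:
        self.session_id = session_id
        self.conn = sqlite3.connect(db_path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS messages ("
            "id INTEGER PRIMARY KEY AUTOINCREMENT, "
            "session_id TEXT, role TEXT, content TEXT)"
        )

    @property
    def messages(self) -> List[Tuple[str, str]]:
        """Return this session's messages in insertion order."""
        rows = self.conn.execute(
            "SELECT role, content FROM messages WHERE session_id = ? ORDER BY id",
            (self.session_id,),
        )
        return list(rows)

    def add_messages(self, messages: List[Tuple[str, str]]) -> None:
        """Append a list of (role, content) pairs to the store."""
        self.conn.executemany(
            "INSERT INTO messages (session_id, role, content) VALUES (?, ?, ?)",
            [(self.session_id, role, content) for role, content in messages],
        )
        self.conn.commit()

    def clear(self) -> None:
        """Delete all messages belonging to this session."""
        self.conn.execute(
            "DELETE FROM messages WHERE session_id = ?", (self.session_id,)
        )
        self.conn.commit()


# Example messages are invented for illustration.
history = SQLiteHistory(":memory:", session_id="foo")
history.add_messages([("human", "Which station is slowest?"), ("ai", "stn4")])
print(history.messages)
```

Keying rows by session_id keeps separate conversations isolated in one table, which mirrors the per-session lookup used later in this section.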
In the code snippet below, we use the RunnableWithMessageHistory abstraction to wrap the agent we created above, so that the message history is supplied to the agent on every call. We can then chat with agent_with_chat_history as shown below.
store = {}


def get_by_session_id(session_id: str):
    if session_id not in store:
        store[session_id] = InMemoryHistory()
    return store[session_id]


agent_with_chat_history = RunnableWithMessageHistory(
    agent, get_by_session_id, history_messages_key="chat_history"
)

output = agent_with_chat_history.invoke(
    {
        "input": "What station had the longest duration on 27th May 2024?",
        "date": datetime.now(),
    },
    config={"configurable": {"session_id": "foo"}},
)
print(output["output"])

# Answer: stn2, with duration of 35 mins.
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tablessql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will query the product_tracking table to find the station with the longest duration on 27th May 2024.

sql_db_list_tablessql_db_list_tables is not a valid tool, try one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].
Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will query the product_tracking table to find the station with the longest duration on 27th May 2024.

product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: I will query the product_tracking table to find the station with the longest duration on 27th May 2024.

CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: I will query the product_tracking table to find the station with the longest duration on 27th May 2024.

CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT station_name, duration\r\nFROM product_tracking\r\nWHERE date = '2024-05-27'\r\nORDER BY duration DESC\r\nLIMIT 1;"}`
responded: I will query the product_tracking table to find the station with the longest duration on 27th May 2024. I will filter the data for the date '2024-05-27' and order the results by duration in descending order to find the station with the longest duration.

```sql
SELECT station_name, duration
FROM product_tracking
WHERE date = '2024-05-27'
ORDER BY duration DESC
LIMIT 1;
```
Invoking: `sql_db_query` with `{'query': "SELECT station_name, duration\nFROM product_tracking\nWHERE date = '2024-05-27'\nORDER BY duration DESC\nLIMIT 1;"}`
responded: I will now execute the SQL query to find the station with the longest duration on 27th May 2024.

[('stn2', '35')]
The station with the longest duration on 27th May 2024 was stn2 with a duration of 35 minutes.

> Finished chain.
The station with the longest duration on 27th May 2024 was stn2 with a duration of 35 minutes.
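One detail worth checking in the trace above: the generated query sorts by duration without a cast, while the schema declares duration as TEXT. SQLite compares TEXT values as strings, so '9' sorts after '120'. The sketch below, using invented rows and Python's built-in sqlite3 module, shows the difference a CAST makes. Whether it changes the answer for the tutorial's data depends on the actual rows.

```python
import sqlite3

# Invented rows for illustration; only the column names come from the schema.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE product_tracking (station_name TEXT, duration TEXT)")
conn.executemany(
    "INSERT INTO product_tracking VALUES (?, ?)",
    [("stn1", "9"), ("stn2", "35"), ("stn3", "120")],
)

# Ordering on the raw TEXT column compares strings character by character.
text_order = conn.execute(
    "SELECT station_name, duration FROM product_tracking "
    "ORDER BY duration DESC LIMIT 1"
).fetchone()

# Casting to INTEGER first gives a numeric ordering.
numeric_order = conn.execute(
    "SELECT station_name, duration FROM product_tracking "
    "ORDER BY CAST(duration AS INTEGER) DESC LIMIT 1"
).fetchone()

print(text_order)     # ('stn1', '9')
print(numeric_order)  # ('stn3', '120')
```

A few-shot example in the prompt that casts duration before sorting is one way to nudge the agent toward the numeric form.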
output = agent_with_chat_history.invoke(
    {
        "input": "Can you tell me when this station had downtime on 2024-04-03?",
        "date": datetime.now(),
    },
    config={"configurable": {"session_id": "foo"}},
)
print(output["output"])

# Answer: 21:52:00
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tablessql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will query the database to find out when the station had downtime on 2024-04-03.

sql_db_list_tablessql_db_list_tables is not a valid tool, try one of [sql_db_query, sql_db_schema, sql_db_list_tables, sql_db_query_checker].
Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will query the database to find out when the station had downtime on 2024-04-03.

product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
responded: I have found that there are two tables: product_tracking and status. I will now check the schema of these tables.

CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/


CREATE TABLE status (
    timestamp_event TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    minute TEXT,
    station_name TEXT,
    station_status TEXT
)

/*
3 rows from status table:
timestamp_event timezone date hour minute station_name station_status
2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
*/
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
responded: I have found that there are two tables: product_tracking and status. I will now check the schema of these tables.

CREATE TABLE product_tracking (
    timestamp_start TEXT,
    timestamp_end TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    station_name TEXT,
    zone TEXT,
    product_id TEXT,
    duration TEXT
)

/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/


CREATE TABLE status (
    timestamp_event TEXT,
    timezone TEXT,
    date TEXT,
    hour TEXT,
    minute TEXT,
    station_name TEXT,
    station_status TEXT
)

/*
3 rows from status table:
timestamp_event timezone date hour minute station_name station_status
2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
*/
Invoking: `sql_db_query` with `{'query': "SELECT timestamp_event, station_name\r\nFROM status\r\nWHERE date = '2024-04-03'\r\nAND station_status = 'downtime';"}`
responded: I have found that the status table contains information about downtime. I will now query the database to find out when the station had downtime on 2024-04-03.

[('2024-04-03 21:52:00', 'stn2')]
Invoking: `sql_db_query` with `{'query': "SELECT timestamp_event, station_name\r\nFROM status\r\nWHERE date = '2024-04-03'\r\nAND station_status = 'downtime';"}`
responded: I have found that the status table contains information about downtime. I will now query the database to find out when the station had downtime on 2024-04-03.

[('2024-04-03 21:52:00', 'stn2')]
The station stn2 had downtime at 21:52 on 2024-04-03.

> Finished chain.
The station stn2 had downtime at 21:52 on 2024-04-03.
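The output above shows that the agent answers with respect to ‘stn2’ even though the follow-up question only says "this station", so we do not have to name the station again. This allows us to have more coherent conversations with the agent. Note that the generated query in this trace filters on the date and status only, not on station_name, so it is worth reading the SQL in the trace when a follow-up question depends on earlier context.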
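Conceptually, the history wrapper does three things on each call: it looks up the history for the given session_id, passes it to the agent under the chat_history key, and then appends the new input and output to that history. The following dependency-free sketch mimics that flow. Here fake_agent and invoke_with_history are hypothetical stand-ins written for illustration, not LangChain APIs, and the real RunnableWithMessageHistory handles more cases than this.

```python
from typing import Callable, Dict, List, Tuple

Message = Tuple[str, str]  # (role, content)
store: Dict[str, List[Message]] = {}


def get_by_session_id(session_id: str) -> List[Message]:
    """Return the message list for a session, creating it on first use."""
    return store.setdefault(session_id, [])


def fake_agent(inputs: dict) -> dict:
    """Stand-in for the SQL agent: reports how many prior messages it received."""
    return {"output": f"seen {len(inputs['chat_history'])} earlier messages"}


def invoke_with_history(
    agent: Callable[[dict], dict], inputs: dict, session_id: str
) -> dict:
    history = get_by_session_id(session_id)
    # Inject a copy of the stored history under the key the agent expects.
    result = agent({**inputs, "chat_history": list(history)})
    # Record this turn so the next call in the same session can refer back to it.
    history.append(("human", inputs["input"]))
    history.append(("ai", result["output"]))
    return result


first = invoke_with_history(fake_agent, {"input": "Which station was slowest?"}, "foo")
second = invoke_with_history(fake_agent, {"input": "When did it have downtime?"}, "foo")
other = invoke_with_history(fake_agent, {"input": "Hello"}, "bar")

print(first["output"])   # seen 0 earlier messages
print(second["output"])  # seen 2 earlier messages
print(other["output"])   # seen 0 earlier messages
```

The second call in session "foo" receives the first exchange, while session "bar" starts empty, which is why reusing the same session_id matters for follow-up questions.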
Conclusion
This tutorial demonstrated how to create a SQL agent using Cohere and LangChain. The agent translates users' natural language queries into SQL and executes them against a database, and with message history added it can also handle follow-up questions within a session. This combination allows for intuitive interaction with databases without requiring direct SQL knowledge.