SQL Agent with Cohere and LangChain (i-5O Case Study)

This notebook was created in collaboration with i-5O.

This tutorial demonstrates how to create a SQL agent using Cohere and LangChain. The agent translates natural language queries from users into SQL and executes them against a database. This powerful combination allows for intuitive interaction with databases without requiring direct SQL knowledge.

Key topics covered:

  1. Setting up the necessary libraries and environment
  2. Connecting to a SQLite database
  3. Configuring the LangChain SQL Toolkit
  4. Creating a custom prompt template with few-shot examples
  5. Building and running the SQL agent

By the end of this tutorial, you’ll have a functional SQL agent that can answer questions about your data using natural language.

This tutorial uses mocked-up data from a manufacturing environment where a product item’s production is tracked across multiple stations, allowing for analysis of production efficiency, station performance, and individual item progress through the manufacturing process. It is modelled after a real customer use case.

The database contains two tables:

  • The product_tracking table records the movement of items through different zones in manufacturing stations, including start and end times, station names, and product IDs.
  • The status table logs the operational status of stations, including timestamps, station names, and whether they are productive or in downtime.

Import the required libraries

First, let’s import the necessary libraries for creating a SQL agent using Cohere and LangChain. These libraries enable natural language interaction with databases and provide tools for building AI-powered agents.

PYTHON
import os

os.environ["COHERE_API_KEY"] = ""
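
If you prefer not to hardcode the key in the notebook, a common alternative is to prompt for it at runtime. This is an optional sketch using Python's getpass module:

PYTHON
# Optional: prompt for the API key instead of hardcoding it above
import getpass
import os

if not os.environ.get("COHERE_API_KEY"):
    os.environ["COHERE_API_KEY"] = getpass.getpass("Cohere API key: ")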
PYTHON
! pip install faiss-gpu -qq  # FAISS vector store; use faiss-cpu if you don't have a GPU
PYTHON
! pip install langchain-core langchain-cohere langchain-community
PYTHON
from langchain.agents import AgentExecutor
from langchain_cohere import create_sql_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_cohere.chat_models import ChatCohere
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.vectorstores import FAISS
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_cohere import CohereEmbeddings
from datetime import datetime, timedelta
import os
import json

Load the database

Next, we load the database for our manufacturing data.

We create an in-memory SQLite database using SQL scripts for the product_tracking and status tables. You can get the SQL tables here.

We then create a SQLDatabase instance, which will be used by our LangChain tools and agents to interact with the data.

PYTHON
import sqlite3
import os

from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool

def get_engine_for_manufacturing_db():
    """Create an in-memory database with the manufacturing data tables."""
    connection = sqlite3.connect(":memory:", check_same_thread=False)

    # Read and execute the SQL scripts
    for sql_file in ['product_tracking.sql', 'status.sql']:
        with open(sql_file, 'r') as file:
            sql_script = file.read()
        connection.executescript(sql_script)

    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

# Create the engine
engine = get_engine_for_manufacturing_db()

# Create the SQLDatabase instance
db = SQLDatabase(engine)

# Now you can use this db instance with your LangChain tools and agents
PYTHON
# Test the connection
db.run("SELECT * FROM status LIMIT 5;")
"[('2024-05-09 19:28:00', 'Canada/Toronto', '2024-05-09', '19', '28', 'stn3', 'downtime'), ('2024-04-21 06:57:00', 'Canada/Toronto', '2024-04-21', '6', '57', 'stn3', 'productive'), ('2024-04-11 23:52:00', 'Canada/Toronto', '2024-04-11', '23', '52', 'stn4', 'productive'), ('2024-04-03 21:52:00', 'Canada/Toronto', '2024-04-03', '21', '52', 'stn2', 'downtime'), ('2024-04-30 05:01:00', 'Canada/Toronto', '2024-04-30', '5', '1', 'stn4', 'productive')]"
PYTHON
# Test the connection
db.run("SELECT * FROM product_tracking LIMIT 5;")
"[('2024-05-27 17:22:00', '2024-05-27 17:57:00', 'Canada/Toronto', '2024-05-27', '17', 'stn2', 'wip', '187', '35'), ('2024-04-26 15:56:00', '2024-04-26 17:56:00', 'Canada/Toronto', '2024-04-26', '15', 'stn4', 'wip', '299', '120'), ('2024-04-12 04:36:00', '2024-04-12 05:12:00', 'Canada/Toronto', '2024-04-12', '4', 'stn3', 'wip', '60', '36'), ('2024-04-19 15:15:00', '2024-04-19 15:22:00', 'Canada/Toronto', '2024-04-19', '15', 'stn4', 'wait', '227', '7'), ('2024-04-24 19:10:00', '2024-04-24 21:07:00', 'Canada/Toronto', '2024-04-24', '19', 'stn4', 'wait', '169', '117')]"

Set up the LangChain SQL Toolkit

Next, we initialize the LangChain SQL Toolkit and set up the language model to use Cohere’s model. This prepares the necessary components for querying the SQL database using natural language.

PYTHON
## Define model to use
import os

MODEL = "command-r-plus-08-2024"
llm = ChatCohere(
    model=MODEL,
    temperature=0.1,
    verbose=True,
    cohere_api_key=os.getenv("COHERE_API_KEY"),
)


toolkit = SQLDatabaseToolkit(db=db, llm=llm)
context = toolkit.get_context()
tools = toolkit.get_tools()

print('**List of pre-defined Langchain Tools**')
print([tool.name for tool in tools])
**List of pre-defined Langchain Tools**
['sql_db_query', 'sql_db_schema', 'sql_db_list_tables', 'sql_db_query_checker']
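
Each of these tools can also be invoked directly, which is useful for debugging. A minimal sketch, assuming the tool names listed above:

PYTHON
# Optional: call a toolkit tool directly to confirm it works
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
print(list_tables_tool.invoke(""))  # expected: product_tracking, status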

Create a prompt template with few-shot examples

Next, we create a prompt template with few-shot examples. Few-shot examples are used to provide the model with context and improve its performance on specific tasks. In this case, we’ll prepare examples of natural language queries and their corresponding SQL queries to help the model generate accurate SQL statements for our database.

In this example, we use SemanticSimilarityExampleSelector to select the top k examples that are most similar to an input query out of all the examples available.

PYTHON
examples = [
    {
        "input": "What was the average processing time for all stations on April 3rd 2024?",
        "query": "SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY station_name;",
    },
    {
        "input": "What was the average processing time for all stations on April 3rd 2024 between 4pm and 6pm?",
        "query": "SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND CAST(hour AS INTEGER) BETWEEN 16 AND 18 AND zone = 'wip' GROUP BY station_name ORDER BY station_name;",
    },
    {
        "input": "What was the average processing time for stn4 on April 3rd 2024?",
        "query": "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND station_name = 'stn4' AND zone = 'wip';",
    },
    {
        "input": "How much downtime did stn2 have on April 3rd 2024?",
        "query": "SELECT COUNT(*) AS downtime_count FROM status WHERE date = '2024-04-03' AND station_name = 'stn2' AND station_status = 'downtime';",
    },
    {
        "input": "What were the productive time and downtime numbers for all stations on April 3rd 2024?",
        "query": "SELECT station_name, station_status, COUNT(*) as total_time FROM status WHERE date = '2024-04-03' GROUP BY station_name, station_status;",
    },
    {
        "input": "What was the bottleneck station on April 3rd 2024?",
        "query": "SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY avg_processing_time DESC LIMIT 1;",
    },
    {
        "input": "Which percentage of the time was stn5 down in the last week of May?",
        "query": "SELECT SUM(CASE WHEN station_status = 'downtime' THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS percentage_downtime FROM status WHERE station_name = 'stn5' AND date >= '2024-05-25' AND date <= '2024-05-31';",
    },
]
PYTHON
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    CohereEmbeddings(
        cohere_api_key=os.getenv("COHERE_API_KEY"),
        model="embed-english-v3.0",
    ),
    FAISS,
    k=5,
    input_keys=["input"],
)
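
To check which few-shot examples the selector retrieves for a given question, you can call it directly. A quick sketch using the selector's select_examples method:

PYTHON
# Optional: preview the k most similar examples for a sample question
selected = example_selector.select_examples(
    {"input": "What was the bottleneck station on April 3rd 2024?"}
)
for example in selected:
    print(example["input"])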
PYTHON
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

system_prefix = """You are an agent designed to interact with a SQL database.
You are an expert at answering questions about manufacturing data.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Always start with checking the schema of the available tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

The current date is {date}.

For questions regarding productive time, downtime, productive or productivity, use minutes as units.

For questions regarding productive time, downtime, productive or productivity use the status table.

For questions regarding processing time and average processing time, use minutes as units.

For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.

If the question does not seem related to the database, just return "I don't know" as the answer.

Here are some examples of user inputs and their corresponding SQL queries:"""

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k", "date"],
    prefix=system_prefix,
    suffix="",
)
PYTHON
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
PYTHON
# Example formatted prompt
prompt_val = full_prompt.invoke(
    {
        "input": "What was the productive time for all stations today?",
        "top_k": 5,
        "dialect": "SQLite",
        "date": datetime.now(),
        "agent_scratchpad": [],
    }
)
print(prompt_val.to_string())
System: You are an agent designed to interact with a SQL database.
You are an expert at answering questions about manufacturing data.
Given an input question, create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Always start with checking the schema of the available tables.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the given tools. Only use the information returned by the tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
The current date is 2024-10-10 09:40:22.052973.
For questions regarding productive time, downtime, productive or productivity, use minutes as units.
For questions regarding productive time, downtime, productive or productivity use the status table.
For questions regarding processing time and average processing time, use minutes as units.
For questions regarding bottlenecks, processing time and average processing time use the product_tracking table.
If the question does not seem related to the database, just return "I don't know" as the answer.
Here are some examples of user inputs and their corresponding SQL queries:
User input: What were the productive time and downtime numbers for all stations on April 3rd 2024?
SQL query: SELECT station_name, station_status, COUNT(*) as total_time FROM status WHERE date = '2024-04-03' GROUP BY station_name, station_status;
User input: What was the average processing time for all stations on April 3rd 2024?
SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY station_name;
User input: What was the average processing time for all stations on April 3rd 2024 between 4pm and 6pm?
SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND CAST(hour AS INTEGER) BETWEEN 16 AND 18 AND zone = 'wip' GROUP BY station_name ORDER BY station_name;
User input: What was the bottleneck station on April 3rd 2024?
SQL query: SELECT station_name, AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND zone = 'wip' GROUP BY station_name ORDER BY avg_processing_time DESC LIMIT 1;
User input: What was the average processing time for stn4 on April 3rd 2024?
SQL query: SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_time FROM product_tracking WHERE date = '2024-04-03' AND station_name = 'stn4' AND zone = 'wip';
Human: What was the productive time for all stations today?

Create the agent

Next, we create an instance of the SQL agent using the LangChain framework, specifically using create_sql_agent.

This agent will be capable of interpreting natural language queries, converting them into SQL queries, and executing them against our database. The agent uses the LLM we defined earlier, along with the SQL toolkit and the custom prompt we created.

PYTHON
agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    prompt=full_prompt,
    verbose=True,
)

Run the agent

Now, we can run the agent and test it with a few different queries.

PYTHON
# %%time
output = agent.invoke({
    "input": "What was the total downtime for each station during the month of May 2024?",
    "date": datetime.now()
})
print(output['output'])

# Answer: stn2: 1 minute(s), stn3: 2 minute(s), stn5: 4 minute(s)
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will write and execute a query to find the total downtime for each station during the month of May 2024.
product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
responded: I have found that the relevant tables are 'product_tracking' and 'status'. I will now check the schema of these tables.

CREATE TABLE product_tracking (
timestamp_start TEXT,
timestamp_end TEXT,
timezone TEXT,
date TEXT,
hour TEXT,
station_name TEXT,
zone TEXT,
product_id TEXT,
duration TEXT
)
/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
CREATE TABLE status (
timestamp_event TEXT,
timezone TEXT,
date TEXT,
hour TEXT,
minute TEXT,
station_name TEXT,
station_status TEXT
)
/*
3 rows from status table:
timestamp_event timezone date hour minute station_name station_status
2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;"}`
responded: I have found that the 'status' table contains the 'station_name' and 'station_status' columns, which are relevant to the user's request. I will now write and execute a query to find the total downtime for each station during the month of May 2024.
```sql
SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;
```
Invoking: `sql_db_query` with `{'query': "SELECT station_name, COUNT(*) as total_downtime FROM status WHERE date LIKE '2024-05-%' AND station_status = 'downtime' GROUP BY station_name;"}`
responded: I have checked the query and it is correct. I will now execute it.
[('stn2', 1), ('stn3', 2), ('stn5', 4)]The total downtime for each station during the month of May 2024 was:
| station_name | total_downtime |
|---|---|
| stn2 | 1 |
| stn3 | 2 |
| stn5 | 4 |
> Finished chain.
The total downtime for each station during the month of May 2024 was:
| station_name | total_downtime |
|---|---|
| stn2 | 1 |
| stn3 | 2 |
| stn5 | 4 |
PYTHON
output = agent.invoke({
    "input": "What is the average processing duration at stn5 in the wip zone?",
    "date": datetime.now()
})
print(output['output'])

# Answer: 39.17 minutes
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then, I will write and execute a query to find the average processing duration at stn5 in the wip zone.
product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking'}`
responded: The available tables are product_tracking and status.
I will now write and execute a query to find the average processing duration at stn5 in the wip zone.

CREATE TABLE product_tracking (
timestamp_start TEXT,
timestamp_end TEXT,
timezone TEXT,
date TEXT,
hour TEXT,
station_name TEXT,
zone TEXT,
product_id TEXT,
duration TEXT
)
/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';"}`
responded: The available tables are product_tracking and status.
I will now write and execute a query to find the average processing duration at stn5 in the wip zone.
SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';
Invoking: `sql_db_query` with `{'query': "SELECT AVG(CAST(duration AS INTEGER)) AS avg_processing_duration FROM product_tracking WHERE station_name = 'stn5' AND zone = 'wip';"}`
responded: The available tables are product_tracking and status.
I will now write and execute a query to find the average processing duration at stn5 in the wip zone.
[(39.166666666666664,)]The average processing duration at stn5 in the wip zone is 39.17 minutes.
> Finished chain.
The average processing duration at stn5 in the wip zone is 39.17 minutes.
PYTHON
output = agent.invoke({
    "input": "Which station had the highest total duration in the wait zone?",
    "date": datetime.now()
})
print(output['output'])

# Answer: stn4 - 251 minutes
> Entering new Cohere SQL Agent Executor chain...

Invoking: `sql_db_list_tables` with `{}`
responded: I will first check the schema of the available tables. Then I will write and execute a query to find the station with the highest total duration in the wait zone.
product_tracking, status
Invoking: `sql_db_schema` with `{'table_names': 'product_tracking, status'}`
responded: I have found that the relevant tables are product_tracking and status. I will now write and execute a query to find the station with the highest total duration in the wait zone.

CREATE TABLE product_tracking (
timestamp_start TEXT,
timestamp_end TEXT,
timezone TEXT,
date TEXT,
hour TEXT,
station_name TEXT,
zone TEXT,
product_id TEXT,
duration TEXT
)
/*
3 rows from product_tracking table:
timestamp_start timestamp_end timezone date hour station_name zone product_id duration
2024-05-27 17:22:00 2024-05-27 17:57:00 Canada/Toronto 2024-05-27 17 stn2 wip 187 35
2024-04-26 15:56:00 2024-04-26 17:56:00 Canada/Toronto 2024-04-26 15 stn4 wip 299 120
2024-04-12 04:36:00 2024-04-12 05:12:00 Canada/Toronto 2024-04-12 4 stn3 wip 60 36
*/
CREATE TABLE status (
timestamp_event TEXT,
timezone TEXT,
date TEXT,
hour TEXT,
minute TEXT,
station_name TEXT,
station_status TEXT
)
/*
3 rows from status table:
timestamp_event timezone date hour minute station_name station_status
2024-05-09 19:28:00 Canada/Toronto 2024-05-09 19 28 stn3 downtime
2024-04-21 06:57:00 Canada/Toronto 2024-04-21 6 57 stn3 productive
2024-04-11 23:52:00 Canada/Toronto 2024-04-11 23 52 stn4 productive
*/
Invoking: `sql_db_query_checker` with `{'query': "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;"}`
responded: I have found that the relevant columns are zone and duration in the product_tracking table. I will now write and execute a query to find the station with the highest total duration in the wait zone.
```sql
SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;
```
Invoking: `sql_db_query` with `{'query': "SELECT station_name, SUM(CAST(duration AS INTEGER)) AS total_duration FROM product_tracking WHERE zone = 'wait' GROUP BY station_name ORDER BY total_duration DESC LIMIT 1;"}`
responded: I have checked the query and it is correct. I will now execute it.
[('stn4', 251)]The station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.
> Finished chain.
The station with the highest total duration in the wait zone is stn4, with a total duration of 251 minutes.
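
If you plan to ask many questions, it can be convenient to wrap the agent call in a small helper so the current date is always supplied. This ask_db helper is a hypothetical convenience wrapper, not part of the notebook's API:

PYTHON
def ask_db(question: str) -> str:
    """Hypothetical helper: run a natural-language question through the SQL agent."""
    result = agent.invoke({"input": question, "date": datetime.now()})
    return result["output"]

print(ask_db("Which station had the highest total duration in the wait zone?"))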

Conclusion

This tutorial demonstrated how to create a SQL agent using Cohere and LangChain. The agent translates natural language queries from users into SQL and executes them against a database, allowing for intuitive interaction with databases without requiring direct SQL knowledge.