Agentic Multi-Stage RAG with Cohere Tools API

Jason Jung

Motivation

Retrieval-augmented generation (RAG) is one of the most common ways enterprises adopt large language models (LLMs). It works well in general, but there are edge cases where it fails. Most commonly, when the retrieved document mentions the topic of the query but defers to another document for the actual answer, the model fails to generate the correct answer.

We propose an agentic RAG system that leverages tool use to keep retrieving documents if the correct ones were not retrieved on the first try. This is ideal for use cases where accuracy is a top priority and latency is not. For example, lawyers trying to find the most accurate answer in their contracts are willing to wait a few more seconds rather than get a wrong answer quickly.

Objective

In this notebook, we explore how to build a simple agentic RAG system using Cohere's native API. We have prepared a fake dataset to demonstrate the use case.
We ask questions that require different depths of retrieval and compare how simple RAG and agentic RAG answer each one.

Disclaimer

One of the challenges in building a RAG system is that it has many moving pieces: vector database, type of embedding model, use of a reranker, number of retrieved documents, chunking strategy, and more. These components can make debugging and evaluating RAG systems difficult. Since this notebook focuses on the concept of agentic RAG, it simplifies the other parts of the RAG system. For example, we retrieve only the top 1 document to demonstrate what happens when the retrieved document does not contain the answer needed.

Result

| Type | Question | Simple RAG | Agentic RAG |
| --- | --- | --- | --- |
| Single-stage retrieval | Is there a state level law for wearing helmets? | There is currently no state law requiring the use of helmets when riding a bicycle. However, some cities and counties do require helmet use. | There is currently no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles. |
| Multi-stage retrieval | I live in orting, do I need to wear a helmet with a bike? | In the state of Washington, there is no law requiring you to wear a helmet when riding a bike. However, some cities and counties do require helmet use, so it is worth checking your local laws. | Yes, you do need to wear a helmet with a bike in Orting if you are under 17. |

As shown in more detail below, multi-stage retrieval is achieved by adding a new function, reference_extractor(), that extracts references to other documents, and by updating the instructions so the agent continues to retrieve more documents.

import os
from pprint import pprint

import cohere
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
# versions
print('cohere version:', cohere.__version__)
cohere version: 5.5.1

Setup

COHERE_API_KEY = os.environ.get("CO_API_KEY")
COHERE_MODEL = 'command-r-plus'
co = cohere.Client(api_key=COHERE_API_KEY)

Data

We used data from the Washington Department of Transportation and modified it to fit the needs of this demo.

documents = [
    {
        "title": "Bicycle law",
        "body": """
        Traffic Infractions and fees - For all information related to bicycle traffic infractions such as not wearing a helmet and fee information, please visit Section 3b for more information.
        Riding on the road - When riding on a roadway, a cyclist has all the rights and responsibilities of a vehicle driver (RCW 46.61.755). Bicyclists who violate traffic laws may be ticketed (RCW 46.61.750).
        Roads closed to bicyclists - Some designated sections of the state's limited access highway system may be closed to bicyclists. See the permanent bike restrictions map for more information. In addition, local governments may adopt ordinances banning cycling on specific roads or on sidewalks within business districts.
        Children bicycling - Parents or guardians may not knowingly permit bicycle traffic violations by their ward (RCW 46.61.700).
        Riding side by side - Bicyclists may ride side by side, but not more than two abreast (RCW 46.61.770).
        Riding at night - For night bicycle riding, a white front light (not a reflector) visible for 500 feet and a red rear reflector are required. A red rear light may be used in addition to the required reflector (RCW 46.61.780).
        Shoulder vs. bike lane - Bicyclists may choose to ride on the path, bike lane, shoulder or travel lane as suits their safety needs (RCW 46.61.770).
        Bicycle helmets - Currently, there is no state law requiring helmet use. However, some cities and counties do require helmets. For specific information along with location for bicycle helmet law please reference to section 21a.
        Bicycle equipment - Bicycles must be equipped with a white front light visible for 500 feet and a red rear reflector (RCW 46.61.780). A red rear light may be used in addition to the required reflector.
""",
    },
    {
        "title": "Bicycle helmet requirement",
        "body": "Currently, there is no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles. Here is a list of those locations and when the laws were enacted. For specific information along with location for bicycle helmet law please reference to section 21a.",
    },
    {
        "title": "Section 21a",
        "body": """helmet rules by location: These are city and county level rules. The following group must wear helmets.
        Location name | Who is affected | Effective date
        Aberdeen | All ages | 2001
        Bainbridge Island | All ages | 2001
        Bellevue | All ages | 2001
        Bremerton | All ages | 2000
        DuPont | All ages | 2008
        Eatonville | All ages | 1996
        Fircrest | All ages | 1995
        Gig Harbor | All ages | 1996
        Kent | All ages | 1999
        Lynnwood | All ages | 2004
        Lakewood | All ages | 1996
        Milton | All ages | 1997
        Orting | Under 17 | 1997

     For fines and rules, you will be charged in according with Section 3b of the law.
     """,
    },
    {
        "title": "Section 3b",
        "body": """Traffic infraction - A person operating a bicycle upon a roadway or highway shall be subject to the provisions of this chapter relating to traffic infractions.
        1. Stop for people in crosswalks. Every intersection is a crosswalk - It’s the law. Drivers must stop for pedestrians at intersections, whether it’s an unmarked or marked crosswalk, and bicyclists in crosswalks are considered pedestrians. Also, it is illegal to pass another vehicle stopped for someone at a crosswalk. In Washington, the leading action motorists take that results in them hitting someone is a failure to yield to pedestrians.
        2. Put the phone down. Hand-held cell phone use and texting is prohibited for all Washington drivers and may result in a $136 fine for first offense, $235 on the second distracted-driving citation.
        3. Helmets are required for all bicyclists according to the state and municipal laws. If you are in a group required to wear a helmet but do not wear it you can be fined $48. # If you are the parent or legal guardian of a child under 17 and knowingly allow them to ride without a helmet, you can be fined $136.
""",
    },
]
db = pd.DataFrame(documents)
# combine title and body
db["combined"] = "Title: " + db["title"] + "\n" + "Body: " + db["body"]
# generate embedding
embeddings = co.embed(
    texts=db.combined.tolist(), model="embed-english-v3.0", input_type="search_document"
)
db["embeddings"] = embeddings.embeddings

db
title body combined embeddings
0 Bicycle law \n Traffic Infractions and fees - For a... Title: Bicycle law\nBody: \n Traffic In... [-0.024673462, -0.034729004, 0.0418396, 0.0121...
1 Bicycle helmet requirement Currently, there is no state law requiring hel... Title: Bicycle helmet requirement\nBody: Curre... [-0.019180298, -0.037384033, 0.0027389526, -0....
2 Section 21a helmet rules by location: These are city and c... Title: Section 21a\nBody: helmet rules by loca... [0.031097412, 0.0007619858, -0.023010254, -0.0...
3 Section 3b Traffic infraction - A person operating a bicy... Title: Section 3b\nBody: Traffic infraction - ... [0.015602112, -0.016143799, 0.032958984, 0.000...

Tools

The following functions and tools are used in the subsequent tasks.

def retrieve_documents(query: str, n=1) -> dict:
    """
    Function to retrieve documents for a given query.

    Steps:
    1. Embed the query
    2. Calculate cosine similarity between the query embedding and the embeddings of the documents
    3. Return the top n documents with the highest similarity scores
    """
    query_emb = co.embed(
        texts=[query], model="embed-english-v3.0", input_type="search_query"
    )

    similarity_scores = cosine_similarity(
        [query_emb.embeddings[0]], db.embeddings.tolist()
    )
    similarity_scores = similarity_scores[0]

    top_indices = similarity_scores.argsort()[::-1][:n]
    top_matches = db.iloc[top_indices]

    return {"top_matched_document": top_matches.combined}


functions_map = {
    "retrieve_documents": retrieve_documents,
}

tools = [
    {
        "name": "retrieve_documents",
        "description": "given a query, retrieve documents from a database to answer user's question",
        "parameter_definitions": {
            "query": {"description": "query", "type": "str", "required": True}
        },
    }
]
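The three steps in the retrieve_documents() docstring (embed, score by cosine similarity, keep the top n) can be illustrated without any API call. The sketch below is a dependency-free stand-in, not the notebook's implementation: the helper names `cosine` and `top_n_indices` and the 3-dimensional toy vectors are made up for illustration, whereas the notebook itself uses Cohere embeddings and scikit-learn.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def top_n_indices(query_vec, doc_vecs, n=1):
    """Return the indices of the n most similar documents, best match first."""
    scores = [cosine(query_vec, d) for d in doc_vecs]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:n]


# Toy "embeddings" (hypothetical values, not real model output).
toy_doc_vecs = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.0],
]
toy_query_vec = [0.9, 0.1, 0.0]

print(top_n_indices(toy_query_vec, toy_doc_vecs, n=2))  # [0, 2]
```

With n=1, as in this notebook, only the single best match is returned, which is why a document that merely points elsewhere can crowd out the one that holds the answer.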

RAG function

def simple_rag(query, db):
    """
    Given user's query, retrieve top documents and generate response using documents parameter.
    """
    top_matched_document = retrieve_documents(query)["top_matched_document"]

    print("top_matched_document", top_matched_document)

    output = co.chat(
        message=query, model=COHERE_MODEL, documents=[top_matched_document]
    )

    return output.text
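Note that simple_rag() passes the retrieved pandas Series straight to the documents parameter. Cohere's documentation describes that parameter as a list of dictionaries with string fields, so it may be clearer to build that shape explicitly. The helper below is a hypothetical sketch: the name `to_chat_documents` and the key names "title" and "snippet" are assumptions chosen for illustration, and it works on plain records rather than on the DataFrame used in this notebook.

```python
def to_chat_documents(records):
    """Convert {"title", "body"} records into a list of string-valued dicts."""
    return [
        {"title": str(r["title"]), "snippet": str(r["body"]).strip()}
        for r in records
    ]


# One record shaped like the entries in the `documents` list above.
sample_records = [
    {
        "title": "Bicycle helmet requirement",
        "body": " Currently, there is no state law requiring helmet use. ",
    },
]

chat_documents = to_chat_documents(sample_records)
print(chat_documents)
```

A list built this way could then be passed as the documents argument in place of the raw Series. This has not been run against the API here, so treat it as a starting point.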

Agentic RAG - cohere_agent()

def cohere_agent(
    message: str,
    preamble: str,
    tools: list[dict],
    force_single_step: bool = False,
    verbose: bool = False,
    temperature: float = 0.3,
) -> str:
    """
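    Run the multi-step tool use loop: call the model, execute any tools it
    requests, and send the results back until no tool calls remain.

    Args:
        message (str): The message to send to the Cohere AI model.
        preamble (str): The preamble or context for the conversation.
        tools (list of dict): List of tools to use in the conversation.
        force_single_step (bool, optional): Passed through to co.chat. Defaults to False.
        verbose (bool, optional): Whether to print verbose output. Defaults to False.
        temperature (float, optional): Sampling temperature passed to co.chat. Defaults to 0.3.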

    Returns:
        str: The final response from the call.
    """

    counter = 1

    response = co.chat(
        model=COHERE_MODEL,
        message=message,
        preamble=preamble,
        tools=tools,
        force_single_step=force_single_step,
        temperature=temperature,
    )

    if verbose:
        print(f"\nrunning 0th step.")
        print(response.text)

    while response.tool_calls:
        tool_results = []

        if verbose:
            print(f"\nrunning {counter}th step.")

        for tool_call in response.tool_calls:
            output = functions_map[tool_call.name](**tool_call.parameters)
            outputs = [output]
            tool_results.append({"call": tool_call, "outputs": outputs})

            if verbose:
                print(
                    f"= running tool {tool_call.name}, with parameters: \n{tool_call.parameters}"
                )
                print(f"== tool results:")
                pprint(output)

        response = co.chat(
            model=COHERE_MODEL,
            message="",
            chat_history=response.chat_history,
            preamble=preamble,
            tools=tools,
            force_single_step=force_single_step,
            tool_results=tool_results,
            temperature=temperature,
        )

        if verbose:
            print(response.text)

        # advance the step counter regardless of verbosity
        counter += 1

    return response.text
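The control flow of cohere_agent() can be exercised offline with a scripted stand-in for the client. Everything in this sketch is hypothetical (`FakeClient`, `FakeResponse`, `FakeToolCall`, `lookup`, `run_agent`); it only mirrors the loop structure above and does not reproduce the Cohere SDK. The `max_steps` guard is an addition that the original function does not have.

```python
from dataclasses import dataclass, field


@dataclass
class FakeToolCall:
    name: str
    parameters: dict


@dataclass
class FakeResponse:
    text: str
    tool_calls: list = field(default_factory=list)


class FakeClient:
    """Scripted stand-in: requests one tool call, then answers."""

    def chat(self, message, tool_results=None):
        if tool_results is None:
            call = FakeToolCall("lookup", {"query": message})
            return FakeResponse("I will search.", [call])
        found = tool_results[0]["outputs"][0]["result"]
        return FakeResponse(f"Answer based on: {found}")


def lookup(query: str) -> dict:
    """Toy tool that pretends to retrieve a document."""
    return {"result": f"document about {query}"}


def run_agent(client, message, functions_map, max_steps=5):
    """Dispatch tool calls until the model stops asking or max_steps is hit."""
    response = client.chat(message)
    steps = 0
    while response.tool_calls and steps < max_steps:
        tool_results = []
        for call in response.tool_calls:
            output = functions_map[call.name](**call.parameters)
            tool_results.append({"call": call, "outputs": [output]})
        response = client.chat("", tool_results=tool_results)
        steps += 1
    return response.text


print(run_agent(FakeClient(), "helmets", {"lookup": lookup}))
# Answer based on: document about helmets
```

Capping the number of steps is one way to bound latency and cost if the model keeps requesting tools; whether that trade-off is acceptable depends on the use case.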

Question 1 - single-stage retrieval

Here we ask a question that can be answered with single-stage retrieval. Both simple and agentic RAG should be able to answer it easily. Below is a comparison of the responses.

| Question | Simple RAG | Agentic RAG |
| --- | --- | --- |
| Is there a state level law for wearing helmets? | There is currently no state law requiring the use of helmets when riding a bicycle. However, some cities and counties do require helmet use. | There is currently no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles. |

question1 = "Is there a state level law for wearing helmets?"

Simple RAG

output = simple_rag(question1, db)
print(output)
top_matched_document 1    Title: Bicycle helmet requirement\nBody: Curre...
Name: combined, dtype: object
There is currently no state law requiring the use of helmets when riding a bicycle. However, some cities and counties do require helmet use.

Agentic RAG

preamble = """
You are an expert assistant that helps users answers question about legal documents and policies.
Use the provided documents to answer questions about an employee's specific situation.
"""

output = cohere_agent(question1, preamble, tools, verbose=True)
running 0th step.
I will search for 'state level law for wearing helmets' in the documents provided and write an answer based on what I find.

running 1th step.
= running tool retrieve_documents, with parameters: 
{'query': 'state level law for wearing helmets'}
== tool results:
{'top_matched_document': 1    Title: Bicycle helmet requirement\nBody: Curre...
Name: combined, dtype: object}
There is currently no state law requiring helmet use. However, some cities and counties do require helmet use with bicycles.

Question 2 - double-stage retrieval

The second question requires double-stage retrieval because the top matched document references another document. As shown below, agentic RAG is initially unable to produce the correct answer, but when given the proper tools and instructions, it finds it.

| Question | Simple RAG | Agentic RAG |
| --- | --- | --- |
| I live in orting, do I need to wear a helmet with a bike? | In the state of Washington, there is no law requiring you to wear a helmet when riding a bike. However, some cities and counties do require helmet use, so it is worth checking your local laws. | Yes, you do need to wear a helmet with a bike in Orting if you are under 17. |

question2 = "I live in orting, do I need to wear a helmet with a bike?"

Simple RAG

output = simple_rag(question2, db)
print(output)
top_matched_document 1    Title: Bicycle helmet requirement\nBody: Curre...
Name: combined, dtype: object
In the state of Washington, there is no law requiring you to wear a helmet when riding a bike. However, some cities and counties do require helmet use, so it is worth checking your local laws.

Agentic RAG

With the original tool and preamble, agentic RAG produces an answer of the same quality as simple RAG.

preamble = """
You are an expert assistant that helps users answers question about legal documents and policies.
Use the provided documents to answer questions about an employee's specific situation.
"""

output = cohere_agent(question2, preamble, tools, verbose=True)
running 0th step.
I will search for 'helmet with a bike' and then write an answer.

running 1th step.
= running tool retrieve_documents, with parameters: 
{'query': 'helmet with a bike'}
== tool results:
{'top_matched_document': 1    Title: Bicycle helmet requirement\nBody: Curre...
Name: combined, dtype: object}
There is no state law requiring helmet use, however, some cities and counties do require helmet use with bicycles. I cannot find any information about Orting specifically, but you should check with your local authority.

Agentic RAG - New Tools

For the model to retrieve the correct documents, we do two things:

  1. We add a new reference_extractor() function. Given a query and the retrieved documents, it finds references to other documents.
  2. We update the instructions to direct the agent to keep retrieving relevant documents.

def reference_extractor(query: str, documents: list[str]) -> str:
    """
    Given a query and the retrieved documents, find references to other documents.
    """
    prompt = f"""
    # instruction
    Does the reference document mention any other documents? If so, list them.
    If not, return empty string.

    # user query
    {query}

    # retrieved documents
    {documents}
    """

    return co.chat(message=prompt, model=COHERE_MODEL, preamble=None).text


def retrieve_documents(query: str, n=1) -> dict:
    """
    Function to retrieve the most relevant documents for a given query.
    It also returns other references mentioned in the top matched documents.
    """
    query_emb = co.embed(
        texts=[query], model="embed-english-v3.0", input_type="search_query"
    )

    similarity_scores = cosine_similarity(
        [query_emb.embeddings[0]], db.embeddings.tolist()
    )
    similarity_scores = similarity_scores[0]

    top_indices = similarity_scores.argsort()[::-1][:n]
    top_matches = db.iloc[top_indices]
    other_references = reference_extractor(query, top_matches.combined.tolist())

    return {
        "top_matched_document": top_matches.combined,
        "other_references_to_query": other_references,
    }


functions_map = {
    "retrieve_documents": retrieve_documents,
}

tools = [
    {
        "name": "retrieve_documents",
        "description": "given a query, retrieve documents from a database to answer user's question. It also finds references to other documents that should be leveraged to retrieve more documents",
        "parameter_definitions": {
            "query": {
                "description": "user's question, or the name of another document section or reference.",
                "type": "str",
                "required": True,
            }
        },
    }
]

preamble2 = """# Instruction
You are an expert assistant that helps users answers question about legal documents and policies.

Please follow these steps:
1. Using user's query, use `retrieve_documents` tool to retrieve the most relevant document from the database.
2. If you see `other_references_to_query` in the tool result, search the mentioned referenced using `retrieve_documents(<other reference>)` tool to retrieve more documents.
3. Keep trying until you find the answer.
4. Answer with yes or no as much as you can to answer the question directly.
"""

output = cohere_agent(question2, preamble2, tools, verbose=True)
running 0th step.
I will search for 'Orting' and 'bike helmet' to find the relevant information.

running 1th step.
= running tool retrieve_documents, with parameters: 
{'query': 'Orting bike helmet'}
== tool results:
{'other_references_to_query': 'Section 21a, Section 3b',
 'top_matched_document': 0    Title: Bicycle law\nBody: \n        Riding on ...
Name: combined, dtype: object}
I have found that there is no state law requiring helmet use, but some cities and counties do require helmets. I will now search for 'Section 21a' to find out if Orting is one of these cities or counties.

running 2th step.
= running tool retrieve_documents, with parameters: 
{'query': 'Section 21a'}
== tool results:
{'other_references_to_query': '- Section 3b',
 'top_matched_document': 2    Title: Section 21a\nBody: helmet rules by loca...
Name: combined, dtype: object}
Yes, you do need to wear a helmet when riding a bike in Orting if you are under 17.
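The trace above shows the agent hopping from the first retrieved document to Section 21a. The same reference-following idea can be sketched offline. This is a simplified illustration under stated assumptions, not the notebook's method: a regular expression stands in for the LLM-based reference_extractor(), documents are looked up by exact title rather than by embedding similarity, and `toy_docs`, `extract_references`, and `follow_references` are hypothetical names with shortened, paraphrased text.

```python
import re

# Shortened, paraphrased stand-ins for the demo documents (hypothetical text).
toy_docs = {
    "Bicycle helmet requirement": "No state law requires helmets. For local rules see Section 21a.",
    "Section 21a": "Orting | Under 17 | 1997. Fines are listed in Section 3b.",
    "Section 3b": "Riders required to wear a helmet who do not can be fined.",
}


def extract_references(text):
    """Find 'Section <id>' mentions; a regex stand-in for the LLM extractor."""
    return re.findall(r"Section \w+", text)


def follow_references(start_title, docs, max_docs=5):
    """Visit the starting document, then every document it references, in order."""
    visited = []
    queue = [start_title]
    while queue and len(visited) < max_docs:
        title = queue.pop(0)
        if title in visited or title not in docs:
            continue  # skip repeats and references we cannot resolve
        visited.append(title)
        queue.extend(extract_references(docs[title]))
    return visited


print(follow_references("Bicycle helmet requirement", toy_docs))
# ['Bicycle helmet requirement', 'Section 21a', 'Section 3b']
```

Unlike this fixed traversal, the agent decides at each step whether another retrieval is needed, which is why it stops after Section 21a in the trace above instead of also fetching Section 3b.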