Rerank on LangChain

Cohere supports various integrations with LangChain, a large language model (LLM) framework that lets you quickly build applications on top of Cohere's models. This doc will guide you through how to leverage Rerank with LangChain.

Prerequisites

Running Cohere Rerank with LangChain doesn't require many prerequisites; consult the top-level document for more information.

Cohere Rerank with LangChain

To use Cohere's rerank functionality with LangChain, start by instantiating a CohereRerank object as follows: cohere_rerank = CohereRerank(cohere_api_key="{API_KEY}").
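
The reranker can also be used on its own to reorder a small set of in-memory documents against a query. The sketch below assumes the compress_documents method that LangChain document compressors expose; the document texts are illustrative:

from langchain.retrievers.document_compressors import CohereRerank
from langchain.schema import Document

cohere_rerank = CohereRerank(cohere_api_key="{API_KEY}")

# Illustrative in-memory documents to rerank against a query
docs = [
    Document(page_content="Cohere was founded in 2019."),
    Document(page_content="LangChain is a framework for building LLM applications."),
]
reranked = cohere_rerank.compress_documents(docs, query="When was Cohere started?")
for doc in reranked:
    # Each reranked document carries its relevance score in the metadata
    print(doc.metadata.get("relevance_score"), doc.page_content)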

You can then use it with LangChain retrievers, embeddings, and RAG. The example below uses the Chroma vector database, which you can install with pip install chromadb. Other vector databases from this list can also be used.

from langchain.retrievers import ContextualCompressionRetriever, CohereRagRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.embeddings import CohereEmbeddings
from langchain_community.chat_models import ChatCohere
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma

user_query = "When was Cohere started?"
# Create Cohere's chat model and embeddings objects
cohere_chat_model = ChatCohere(cohere_api_key="{API_KEY}")
cohere_embeddings = CohereEmbeddings(cohere_api_key="{API_KEY}")
# Load a text file and split it into chunks; you can also use data gathered elsewhere in your application
raw_documents = TextLoader('demofile.txt').load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(raw_documents)
# Create a vector store from the documents
db = Chroma.from_documents(documents, cohere_embeddings)

# Create Cohere's reranker and use the vector store (built with Cohere's embeddings) as the base retriever
cohere_rerank = CohereRerank(cohere_api_key="{API_KEY}")
compression_retriever = ContextualCompressionRetriever(
    base_compressor=cohere_rerank, 
    base_retriever=db.as_retriever()
)
compressed_docs = compression_retriever.get_relevant_documents(user_query)
# Print the documents retrieved via the embeddings and reranked by the reranker
print(compressed_docs)

# Create the Cohere RAG retriever using the chat model
rag = CohereRagRetriever(llm=cohere_chat_model)
docs = rag.get_relevant_documents(
    user_query,
    source_documents=compressed_docs,
)
# Print the source documents (every entry except the last, which holds the generated answer)
for doc in docs[:-1]:
    print(doc.metadata)
    print("\n\n" + doc.page_content)
    print("\n\n" + "-" * 30 + "\n\n")
# Print the final generation 
answer = docs[-1].page_content
print(answer)
# Print the final citations 
citations = docs[-1].metadata['citations']
print(citations)
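
CohereRerank also accepts optional parameters, such as top_n to control how many documents are returned and model to pick a specific rerank model. The values below are illustrative; check the Rerank documentation for the current model names:

# Return the five highest-ranked chunks instead of the default
# (the model name is illustrative; see the Rerank docs for current models)
cohere_rerank = CohereRerank(
    cohere_api_key="{API_KEY}",
    model="rerank-english-v2.0",
    top_n=5,
)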