import os

from langchain.retrievers import ContextualCompressionRetriever
from langchain.text_splitter import CharacterTextSplitter
from langchain_cohere import ChatCohere, CohereEmbeddings, CohereRagRetriever, CohereRerank
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma

user_query = "what is Cohere Toolkit?"

# Define the Cohere LLM; read the API key from the environment
# rather than hard-coding it
llm = ChatCohere(cohere_api_key=os.environ["COHERE_API_KEY"],
                 model="command-r-plus-08-2024")

# Define the Cohere embedding model
embeddings = CohereEmbeddings(cohere_api_key=os.environ["COHERE_API_KEY"],
                              model="embed-english-light-v3.0")

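# Optional sanity check (an addition, not part of the original example):
# embedding a single query returns one vector of floats;
# embed-english-light-v3.0 produces 384-dimensional vectors
print(f"Embedding dimension: {len(embeddings.embed_query('hello'))}")
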
# Load a web page and split it into chunks; you can also use data
# gathered elsewhere in your application
raw_documents = WebBaseLoader("https://docs.cohere.com/docs/cohere-toolkit").load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(raw_documents)

# Create a vector store from the documents
db = Chroma.from_documents(documents, embeddings)

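# For comparison (an addition, not part of the original example): a plain
# similarity search against the vector store returns the k nearest chunks
# before any reranking is applied
similar_docs = db.similarity_search(user_query, k=4)
print(f"Top match before reranking: {similar_docs[0].page_content[:100]}")
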
# Create Cohere's reranker, which will reorder the chunks returned
# by the vector store retriever
reranker = CohereRerank(cohere_api_key=os.environ["COHERE_API_KEY"],
                        model="rerank-english-v3.0")

compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker,
    base_retriever=db.as_retriever()
)
compressed_docs = compression_retriever.invoke(user_query)
# Print the relevant documents found using the embeddings and reranker
print(compressed_docs)

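# LangChain's CohereRerank compressor records each document's relevance
# score in its metadata; a quick way to inspect the ranking (using .get()
# in case the key is absent in your version):
for doc in compressed_docs:
    print(doc.metadata.get("relevance_score"), doc.page_content[:80])
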
# Create the Cohere RAG retriever using the chat model; with
# connectors=[] the model grounds its answer only on the documents
# passed in, rather than on a Cohere connector such as web search
rag = CohereRagRetriever(llm=llm, connectors=[])
docs = rag.invoke(
    user_query,
    documents=compressed_docs,
)
# Print the retrieved source documents (every entry except the last)
print("Documents:")
for doc in docs[:-1]:
    print(doc.metadata)
    print("\n\n" + doc.page_content)
    print("\n\n" + "-" * 30 + "\n\n")
# The last document returned by the RAG retriever holds the final generation
answer = docs[-1].page_content
print("Answer:")
print(answer)
# Print the citations stored in that document's metadata
citations = docs[-1].metadata["citations"]
print("Citations:")
print(citations)
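
# A minimal sketch (an addition, not in the original example) of mapping the
# citations back onto the answer text. This assumes each citation exposes
# integer start/end offsets into the answer, as the Cohere SDK's chat
# citation objects do; adjust the attribute access if your langchain-cohere
# version stores citations as plain dicts instead.
for citation in citations:
    cited_span = answer[citation.start:citation.end]
    print(f"'{cited_span}' (chars {citation.start}-{citation.end})")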