from langchain.retrievers import ContextualCompressionRetriever
from langchain.text_splitter import CharacterTextSplitter
from langchain_cohere import ChatCohere, CohereEmbeddings, CohereRagRetriever, CohereRerank
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma

user_query = "what is Cohere Toolkit?"

# Define the Cohere LLM
llm = ChatCohere(
    cohere_api_key="COHERE_API_KEY", model="command-r-plus-08-2024"
)

# Define the Cohere embedding model
embeddings = CohereEmbeddings(
    cohere_api_key="COHERE_API_KEY", model="embed-english-light-v3.0"
)
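
# Optional sanity check (an illustrative extra, not part of the original flow):
# CohereEmbeddings implements LangChain's standard Embeddings interface, so the
# query can be embedded directly. embed-english-light-v3.0 is expected to yield
# 384-dimensional vectors, but treat the exact size as a model-specific detail.
query_vector = embeddings.embed_query(user_query)
print(f"Query embedded into {len(query_vector)} dimensions")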

# Load a web page and split it into chunks; you can also use data gathered
# elsewhere in your application
raw_documents = WebBaseLoader(
    "https://docs.cohere.com/docs/cohere-toolkit"
).load()
text_splitter = CharacterTextSplitter(
    chunk_size=1000, chunk_overlap=0
)
documents = text_splitter.split_documents(raw_documents)

# Create a vector store from the documents
db = Chroma.from_documents(documents, embeddings)
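
# A minimal sketch of plain vector retrieval, useful for comparison with the
# reranked results below (an illustrative extra, not part of the original
# example): similarity_search embeds the query with the same Cohere embeddings
# and returns the k nearest chunks by vector distance alone.
baseline_docs = db.similarity_search(user_query, k=4)
print(f"Baseline retrieval returned {len(baseline_docs)} chunks")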

# Create Cohere's reranker, then wrap it with the vector store's retriever so
# reranking runs on top of the embedding-based results
reranker = CohereRerank(
    cohere_api_key="COHERE_API_KEY", model="rerank-english-v3.0"
)

compression_retriever = ContextualCompressionRetriever(
    base_compressor=reranker, base_retriever=db.as_retriever()
)
compressed_docs = compression_retriever.get_relevant_documents(
    user_query
)
# Print the documents retrieved with the embeddings and reranked by Cohere's reranker
print(compressed_docs)
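
# A closer look at the reranker output (a sketch that assumes the langchain_cohere
# version in use records a relevance_score in each document's metadata, as
# CohereRerank typically does; .get() guards against versions that do not).
for doc in compressed_docs:
    print(doc.metadata.get("relevance_score"), doc.page_content[:80])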

# Create the Cohere RAG retriever using the chat model
rag = CohereRagRetriever(llm=llm, connectors=[])
docs = rag.get_relevant_documents(
    user_query,
    documents=compressed_docs,
)
# Print the retrieved documents; the last element holds the generated answer,
# so it is skipped here and handled below
print("Documents:")
for doc in docs[:-1]:
    print(doc.metadata)
    print("\n\n" + doc.page_content)
    print("\n\n" + "-" * 30 + "\n\n")
# Print the final generation
answer = docs[-1].page_content
print("Answer:")
print(answer)
# Print the final citations
citations = docs[-1].metadata["citations"]
print("Citations:")
print(citations)
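
# A minimal sketch of rendering citations one per line, assuming each citation
# exposes start, end, and text fields as in Cohere's chat API (attribute names
# can vary across SDK versions, so adapt as needed).
for citation in citations:
    print(f"[{citation.start}:{citation.end}] {citation.text}")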