# LangChain + Cohere integration imports, grouped and sorted conventionally.
from langchain.retrievers import CohereRagRetriever, ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.chat_models import ChatCohere
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import CohereEmbeddings
from langchain_community.vectorstores import Chroma

user_query = "When was Cohere started?"

# Create Cohere's chat model and embeddings objects.
# NOTE(review): "{API_KEY}" is a literal placeholder — replace it with a real
# Cohere API key (ideally read from an environment variable rather than
# hard-coded in source).
cohere_chat_model = ChatCohere(cohere_api_key="{API_KEY}")
cohere_embeddings = CohereEmbeddings(cohere_api_key="{API_KEY}")

# Load a text file and split it into chunks; data gathered elsewhere in your
# application can be substituted here.
raw_documents = TextLoader('demofile.txt').load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
documents = text_splitter.split_documents(raw_documents)

# Build a vector store over the chunks, embedded with Cohere embeddings.
db = Chroma.from_documents(documents, cohere_embeddings)

# Create Cohere's reranker, using the vector DB (Cohere embeddings) as the
# base retriever; the reranker then compresses/reorders the retrieved docs.
cohere_rerank = CohereRerank(cohere_api_key="{API_KEY}")
compression_retriever = ContextualCompressionRetriever(
    base_compressor=cohere_rerank,
    base_retriever=db.as_retriever(),
)
compressed_docs = compression_retriever.get_relevant_documents(user_query)

# Print the relevant documents found via embeddings + reranking.
print(compressed_docs)

# Create the Cohere RAG retriever using the chat model, grounding the
# generation on the reranked documents obtained above.
rag = CohereRagRetriever(llm=cohere_chat_model)
docs = rag.get_relevant_documents(
    user_query,
    source_documents=compressed_docs,
)

# Every entry except the last is a source document; print each one.
for doc in docs[:-1]:
    print(doc.metadata)
    print("\n\n" + doc.page_content)
    print("\n\n" + "-" * 30 + "\n\n")

# The last entry carries the generated answer in its page_content.
answer = docs[-1].page_content
print(answer)

# The citations for the generation live in the last entry's metadata.
citations = docs[-1].metadata['citations']
print(citations)