Grounded Summarization Using Command R

Note: we are in the process of updating the links in this notebook. If a link doesn’t work, please open an issue and we’ll rectify it ASAP. Thanks for your understanding!

Links to add:

  • Cell 1: long-form, grounded summarisation blog post
  • Section 4: to text-rank method (context filtering)

This notebook provides the code to produce the outputs described in this blog post.

1. Setup

PYTHON
%%capture

import cohere
import networkx as nx
import nltk
nltk.download("punkt")
from nltk.tokenize import sent_tokenize
import numpy as np
import spacy

from collections import deque
from getpass import getpass
import re
from typing import List, Tuple

co_api_key = getpass("Enter your Cohere API key: ")
co_model = "command-r"
co = cohere.Client(co_api_key)
PYTHON
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

fpath = "drive/Shareddrives/FDE/Cookbooks/Long-form summarisation/ai_and_future_of_work.txt"
with open(fpath, "r") as f:
    text = f.read()

num_tokens = co.tokenize(text).length
print(f"Loaded IMF report with {num_tokens} tokens")

Aside: define utils

PYTHON
def split_text_into_sentences(text: str) -> List[str]:
    sentences = sent_tokenize(text)
    return sentences


def group_sentences_into_passages(sentence_list: List[str], n_sentences_per_passage: int = 10):
    """
    Group sentences into passages of n_sentences_per_passage sentences each.
    """
    passages = []
    passage = ""
    for i, sentence in enumerate(sentence_list):
        passage += sentence + " "
        if (i + 1) % n_sentences_per_passage == 0:
            passages.append(passage)
            passage = ""
    # Keep any trailing sentences that don't fill a complete passage
    if passage:
        passages.append(passage)
    return passages


def build_simple_chunks(text, n_sentences: int = 10):
    """
    Build chunks of text from the input text.
    """
    sentences = split_text_into_sentences(text)
    chunks = group_sentences_into_passages(sentences, n_sentences_per_passage=n_sentences)
    return chunks


def insert_citations(text: str, citations: List[dict]):
    """
    A helper function to pretty-print citations.
    """
    offset = 0
    # Process citations in the order they were provided
    for citation in citations:
        # Adjust start/end with offset
        start, end = citation["start"] + offset, citation["end"] + offset
        placeholder = "[" + ", ".join(doc[4:] for doc in citation["document_ids"]) + "]"
        # ^ doc[4:] removes the 'doc_' prefix, leaving the document index
        modification = f"{text[start:end]} {placeholder}"
        # Replace the cited span with itself followed by the placeholder
        text = text[:start] + modification + text[end:]
        # Update the offset for subsequent replacements
        offset += len(modification) - (end - start)

    return text


def textrank(text: str, co, max_tokens: int, n_sentences_per_passage: int) -> str:
    """
    Shortens `text` by extracting its most central chunks and concatenating them.
    The output is the concatenation of those key chunks, in their original order. Centrality is a graph-theoretic
    measure of the connectedness of a node: the more strongly a chunk is connected to its neighbours (and the more
    sparsely those neighbours are connected to each other), the higher its centrality.

    Key chunks are identified in a five-step process:
    1. Break up `text` into chunks of `n_sentences_per_passage` sentences
    2. Embed each chunk using Cohere's embedding model and construct a similarity matrix
    3. Compute the centrality of each chunk
    4. Keep the highest-centrality chunks until `max_tokens` is reached
    5. Put together the shortened text by restoring the kept chunks to their original order

    This approach is based on summarise.long_doc_summarization.extraction::extract_single_doc with sorting by
    centrality. Adapted here because installing the `summarise` repo would have added a lot of unused functionality
    and dependencies.
    """

    # 1. Chunk text into units
    chunks = build_simple_chunks(text, n_sentences_per_passage)

    # 2. Embed and construct similarity matrix
    embeddings = np.array(
        co.embed(
            texts=chunks,
            model="embed-english-v3.0",
            input_type="clustering",
        ).embeddings
    )
    similarities = np.dot(embeddings, embeddings.T)

    # 3. Compute centrality and sort chunks by centrality
    # Easiest to use networkx's `degree` function with similarity as weight
    g = nx.from_numpy_array(similarities, edge_attr="weight")
    centralities = g.degree(weight="weight")
    idcs_sorted_by_centrality = [node for node, degree in sorted(centralities, key=lambda item: item[1], reverse=True)]

    # 4. Add chunks back in order of centrality
    selected = _add_chunks_by_priority(co, chunks, idcs_sorted_by_centrality, max_tokens)

    # 5. Put the condensed text back in its original order
    separator = "\n"
    short = separator.join([chunk for index, chunk in sorted(selected, key=lambda item: item[0], reverse=False)])

    return short


def _add_chunks_by_priority(
    co, chunks: List[str], idcs_sorted_by_priority: List[int], max_tokens: int
) -> List[Tuple[int, str]]:
    """
    Given chunks of text and their indices sorted by priority (highest priority first), this function
    fills the model context window with as many of the highest-priority chunks as possible.

    The output is a list of (index, chunk) pairs, ordered by priority. To stitch the chunks back into
    a cohesive text that preserves chronological order, sort the output on its index.
    """

    selected = []
    num_tokens = 0
    idcs_queue = deque(idcs_sorted_by_priority)

    while num_tokens < max_tokens and len(idcs_queue) > 0:
        next_idx = idcs_queue.popleft()
        num_tokens += co.tokenize(chunks[next_idx]).length - 2
        # ^ subtract the BOS and EOS tokens from the count
        selected.append((next_idx, chunks[next_idx]))
        # ^ keep index and chunk, to reorder chronologically
        if num_tokens > max_tokens:
            selected.pop()

    return selected

2. Out-of-the-box summarization with Command-R

First, let’s see Command-R’s out-of-the-box performance. It has a 128k-token context window, so we can pass the full IMF report in a single call. We replicate the exact instructions from the original tweet (correcting for a minor typo) to enable a fair comparison.

PYTHON
prompt_template = """\
## text
{text}

## instructions
Step 1. Read the entire text from the first to the last page.
Step 2. Create a summary of every chapter from the first to the last page.

## summary
"""

prompt = prompt_template.format(text=text)
resp = co.chat(
    message=prompt,
    model=co_model,
    temperature=0.3,
    return_prompt=True,
)

num_tokens_in = co.tokenize(resp.prompt).length
num_tokens_out = resp.meta["billed_units"]["output_tokens"]
print(f"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out")
print()
print("--- Out-of-the-box summary with Command-R ---")
print()
print(resp.text)

3. Introduce citations to the summary for grounding

When summarizing long documents, introducing citations is one simple method for checking the factuality of the summary without needing to read the full document.

We’ve trained Command-R to introduce citations whenever it is prompted with our grounded generation instructions. Triggering this grounded mode is straightforward: starting from the previous snippet, we only need to make two changes:

  1. Pass our text to the documents argument
  2. Pass our instructions to the message argument

For more information on how to enable grounded generation via our co.chat API, please refer to our documentation.

Finally, note that we chunk the IMF report into multiple documents before passing them to co.chat. This isn’t necessary (co.chat annotates citations at the character level), but allows for more human-readable citations.

PYTHON
summarize_preamble = """\
You will receive a series of text fragments from an article that are presented in chronological order. \
As the assistant, you must generate responses to user's requests based on the information given in the fragments. \
Ensure that your responses are accurate and truthful, and that you reference your sources where appropriate to answer \
the queries, regardless of their complexity.\
"""

instructions = """\
## instructions
Step 1. Read the entire text from the first to the last page.
Step 2. Create a summary of every chapter from the first to the last page.
"""

chunked = build_simple_chunks(text, n_sentences=30)
resp = co.chat(
    preamble=summarize_preamble,
    message=instructions,
    documents=[{"text": chunk} for chunk in chunked],
    model=co_model,
    temperature=0.3,
    return_prompt=True,
)

num_tokens_in = co.tokenize(resp.prompt).length
num_tokens_out = resp.meta["billed_units"]["output_tokens"]
print(f"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out")
print()
print("--- Summary with citations using grounded generation in Command-R ---")
print()
print(resp.text)
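
Each entry in resp.citations records the character offsets of a cited span in the answer, together with the ids of the documents that support it (this is the structure the insert_citations helper above relies on). If you’d like to peek at a raw citation object before pretty-printing, a quick optional check (assuming the response contains at least one citation) is:

PYTHON
print(resp.citations[0])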

Let’s display the citations inside our answer:

PYTHON
print(insert_citations(resp.text, resp.citations))

We can now visualise which section of the answer is based on which passage in the main text. Verifying factuality is straightforward: pick a section and verify that the relevant information is contained in the cited chunk.

For instance, let’s verify the statement

Around 40% of employment worldwide is exposed to AI [1, 6]

by checking its chunk:

PYTHON
print(chunked[6])

Seems convincing! By repeating such checks, it’s straightforward to build trust in your summaries.
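
To scale these spot-checks beyond a single statement, you can loop over all citations and print each cited span next to the chunks it points to. The following is a minimal sketch, assuming the auto-assigned doc_<index> ids (the same convention insert_citations relies on above) and reusing resp and chunked from the previous cells:

PYTHON
for citation in resp.citations:
    # The cited span of the generated summary
    cited_span = resp.text[citation["start"]:citation["end"]]
    # Map each document id (e.g. 'doc_6') back to the chunk it refers to
    source_idcs = [int(doc_id[4:]) for doc_id in citation["document_ids"]]
    print(f"Claim: {cited_span}")
    for idx in source_idcs:
        print(f"  Source chunk {idx}: {chunked[idx][:200]}...")
    print()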

4. Reduce the cost of summarization calls

Even though Command-R is an efficient, lightweight model, for some applications we may accept trading off some summarization quality for lower costs. To do this, we must reduce the number of tokens sent to the model. But how do we select the most relevant parts?

We have a whole notebook dedicated to methods for reducing context length. Here, we use our ‘text-rank’ method, which selects the most central chunks in a graph built from chunk-to-chunk similarities. For more detail, please refer to this cookbook.

PYTHON
num_tokens = 8192
shortened = textrank(text, co, num_tokens, n_sentences_per_passage=30)

chunked = build_simple_chunks(shortened)
resp = co.chat(
    message=instructions,
    documents=[{"text": chunk} for chunk in chunked],
    model=co_model,
    temperature=0.3,
    return_prompt=True,
)

num_tokens_in = co.tokenize(resp.prompt).length
num_tokens_out = resp.meta["billed_units"]["output_tokens"]
print(f"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out")
print()
print("--- Summary with citations using text-rank + grounding in Command-R ---")
print()
print(resp.text)

The summary is looking convincing! In practice, the trade-off between cost-efficiency and performance should be considered carefully.
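
As a rough way to explore that trade-off before committing to a budget, you could sweep a few values of max_tokens and see how much of the report survives the text-rank filtering. This is only a sketch: the budget values below are arbitrary examples, and each call re-embeds the chunks, so it costs a few extra embed requests.

PYTHON
total_tokens = co.tokenize(text).length
for budget in (2048, 4096, 8192):
    # Shorten the report to roughly `budget` tokens using the textrank helper defined above
    candidate = textrank(text, co, budget, n_sentences_per_passage=30)
    kept_tokens = co.tokenize(candidate).length
    print(f"max_tokens={budget}: kept {kept_tokens} of {total_tokens} tokens")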