Grounded Summarization Using Command R

Note: we are in the process of updating the links in this notebook. If a link doesn't work, please open an issue and we'll rectify it ASAP. Thanks for your understanding!

This notebook provides the code to produce the outputs described in this blog post.

Table of contents

  1. Setup
  2. Out-of-the-box summarization with Command-R
  3. Introduce citations to the summary for grounding
  4. Reduce the cost of summarization calls


1. Setup

%%capture
# Install the dependencies imported below (assumed package list; the original install cell was truncated)
!pip install cohere networkx nltk numpy spacy

import cohere
import networkx as nx
import nltk
nltk.download("punkt")
from nltk.tokenize import sent_tokenize
import numpy as np
import spacy

from collections import deque
from getpass import getpass
import re
from typing import List, Tuple

co_api_key = getpass("Enter your Cohere API key: ")
co_model = "command-r"
co = cohere.Client(co_api_key)


from google.colab import drive
drive.mount("/content/drive", force_remount=True)

fpath = "drive/Shareddrives/FDE/Cookbooks/Long-form summarisation/ai_and_future_of_work.txt"
with open(fpath, "r") as f:
  text = f.read()

num_tokens = co.tokenize(text).length
print(f"Loaded IMF report with {num_tokens} tokens")


Aside: define utils


def split_text_into_sentences(text: str) -> List[str]:
    sentences = sent_tokenize(text)
    return sentences

def group_sentences_into_passages(sentence_list: List[str], n_sentences_per_passage: int = 10):
    """
    Group sentences into passages of n_sentences_per_passage sentences.
    """
    passages = []
    passage = ""
    for i, sentence in enumerate(sentence_list):
        passage += sentence + " "
        if (i + 1) % n_sentences_per_passage == 0:
            passages.append(passage)
            passage = ""
    # Keep any trailing sentences that don't fill a complete passage
    if passage:
        passages.append(passage)
    return passages

def build_simple_chunks(text, n_sentences: int = 10):
    """
    Build chunks of text from the input text.
    """
    sentences = split_text_into_sentences(text)
    chunks = group_sentences_into_passages(sentences, n_sentences_per_passage=n_sentences)
    return chunks
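
For illustration, here is how these chunking helpers behave on a toy string (a minimal sketch; the sample text is made up):

sample_text = "First sentence. Second sentence. Third sentence. Fourth sentence."
print(build_simple_chunks(sample_text, n_sentences=2))
# Expect two passages of two sentences each:
# ['First sentence. Second sentence. ', 'Third sentence. Fourth sentence. ']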



def insert_citations(text: str, citations: List[dict]):
    """
    A helper function to pretty print citations.
    """
    offset = 0
    # Process citations in the order they were provided
    for citation in citations:
        # Adjust start/end with offset
        start, end = citation['start'] + offset, citation['end'] + offset
        placeholder = "[" + ", ".join(doc[4:] for doc in citation["document_ids"]) + "]"
        # ^ doc[4:] strips the 'doc_' prefix and keeps just the document index
        modification = f'{text[start:end]} {placeholder}'
        # Append the citation placeholder right after the cited span
        text = text[:start] + modification + text[end:]
        # Update the offset for subsequent replacements
        offset += len(modification) - (end - start)

    return text
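
To see what this helper does, here is a toy example (a minimal sketch; the citation dict below is hypothetical, but it mirrors the start/end/document_ids fields returned by co.chat):

toy_text = "AI affects 40% of jobs."
toy_citations = [
    # Hypothetical citation covering the whole sentence, attributed to document chunks 1 and 6
    {"start": 0, "end": 23, "document_ids": ["doc_1", "doc_6"]},
]
print(insert_citations(toy_text, toy_citations))
# -> AI affects 40% of jobs. [1, 6]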



def textrank(text: str, co, max_tokens: int, n_sentences_per_passage: int) -> str:
    """
    Shortens `text` by extracting key units of text based on their centrality and concatenating them.
    The output is the concatenation of those key units, in their original order. Centrality is a graph-theoretic
    measure of the connectedness of a node; the more connected a node is to its neighbours (and the more sparsely
    those neighbours are connected to each other), the higher its centrality.

    Key passages are identified in five steps:
    1. Break up `text` into chunks of `n_sentences_per_passage` sentences each
    2. Embed each chunk using Cohere's embedding model and construct a similarity matrix
    3. Compute the centrality of each chunk
    4. Keep the highest-centrality chunks until `max_tokens` is reached
    5. Put together the shortened text by restoring the kept chunks to their original order

    This approach is based on summarise.long_doc_summarization.extraction::extract_single_doc with sorting by
    centrality. Adapted here because installing the `summarise` repo would have added a lot of unused functionality
    and dependencies.
    """

    # 1. Chunk text into units
    chunks = build_simple_chunks(text, n_sentences_per_passage)

    # 2. Embed and construct similarity matrix
    embeddings = np.array(
        co.embed(
            texts=chunks,
            model="embed-english-v3.0",
            input_type="clustering",
        ).embeddings
    )
    similarities = np.dot(embeddings, embeddings.T)

    # 3. Compute centrality and sort sentences by centrality
    # Easiest to use networkx's `degree` function with similarity as weight
    g = nx.from_numpy_array(similarities, edge_attr="weight")
    centralities = g.degree(weight="weight")
    idcs_sorted_by_centrality = [node for node, degree in sorted(centralities, key=lambda item: item[1], reverse=True)]

    # 4. Add chunks back in order of centrality
    selected = _add_chunks_by_priority(co, chunks, idcs_sorted_by_centrality, max_tokens)

    # 5. Put condensed text back in original order
    separator = "\n"
    short = separator.join([chunk for index, chunk in sorted(selected, key=lambda item: item[0], reverse=False)])

    return short


def _add_chunks_by_priority(
    co, chunks: List[str], idcs_sorted_by_priority: List[int], max_tokens: int
) -> List[Tuple[int, str]]:
    """
    Given chunks of text and their indices sorted by priority (highest priority first), this function
    fills the model context window with as many highest-priority chunks as possible.

    The output is a list of (index, chunk) pairs, ordered by priority. To stitch back the chunks into
    a cohesive text that preserves chronological order, sort the output on its index.
    """

    selected = []
    num_tokens = 0
    idcs_queue = deque(idcs_sorted_by_priority)

    while num_tokens < max_tokens and len(idcs_queue) > 0:
        next_idx = idcs_queue.popleft()
        num_tokens += co.tokenize(chunks[next_idx]).length - 2
        # num_tokens += len(tokenizer.encode(chunks[next_idx]).ids) - 2
        # ^ removing BOS and EOS tokens from count
        selected.append((next_idx, chunks[next_idx]))
        # ^ keep index and chunk, to reorder chronologically
    if num_tokens > max_tokens:
        selected.pop()

    return selected


2. Out-of-the-box summarization with Command-R

First, let's see Command-R's out-of-the-box performance. It's a 128k-context model, so we can pass the full IMF report in a single call. We replicate the exact instructions from the original tweet (correcting a minor typo) to enable a fair comparison.

prompt_template = """\
## text
{text}

## instructions
Step 1. Read the entire text from the first to the last page.
Step 2. Create a summary of every chapter from the first to the last page.

## summary
"""

prompt = prompt_template.format(text=text)
resp = co.chat(
  message=prompt,
  model=co_model,
  temperature=0.3,
  return_prompt=True
)

num_tokens_in = co.tokenize(resp.prompt).length
num_tokens_out = resp.meta["billed_units"]["output_tokens"]
print(f"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out")
print()
print("--- Out-of-the-box summary with Command-R ---")
print()
print(resp.text)


3. Introduce citations to the summary for grounding

When summarizing long documents, citations are a simple way to check the factuality of the summary without having to read the full document.

We've trained Command-R to introduce citations whenever prompted by our grounded generation instructions. Triggering this grounded mode is straightforward. Starting from the previous snippet, we only need to make two changes:

  1. Pass our text to the documents argument
  2. Pass our instructions to the message argument

For more information on how to enable grounded generation via our co.chat API, please refer to our documentation.

Finally, note that we chunk the IMF report into multiple documents before passing them to co.chat. This isn't necessary (co.chat annotates citations at the character level), but allows for more human-readable citations.

summarize_preamble = """\
You will receive a series of text fragments from an article that are presented in chronological order. \
As the assistant, you must generate responses to user's requests based on the information given in the fragments. \
Ensure that your responses are accurate and truthful, and that you reference your sources where appropriate to answer \
the queries, regardless of their complexity.\
"""

instructions = """\
## instructions
Step 1. Read the entire text from the first to the last page.
Step 2. Create a summary of every chapter from the first to the last page.
"""

chunked = build_simple_chunks(text, n_sentences=30)
resp = co.chat(
  preamble=summarize_preamble,
  message=instructions,
  documents=[{"text": chunk} for chunk in chunked],
  model=co_model,
  temperature=0.3,
  return_prompt=True
)

num_tokens_in = co.tokenize(resp.prompt).length
num_tokens_out = resp.meta["billed_units"]["output_tokens"]
print(f"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out")
print()
print("--- Summary with citations using grounded generation in Command-R ---")
print()
print(resp.text)

Let's display the citations inside our answer:

print(insert_citations(resp.text, resp.citations))

We can now visualise which section of the answer is based on which passage in the main text. Verifying factuality is straightforward: pick a section and verify that the relevant information is contained in the cited chunk.

For instance, let's verify the statement

Around 40% of employment worldwide is exposed to AI [1, 6]

by checking its chunk:

print(chunked[6])

Seems convincing!
By repeating such checks, it's straightforward to build trust in your summaries.
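
To make these spot checks more systematic, we can loop over the returned citations and print each cited span next to a preview of its source chunks (a sketch that reuses the resp and chunked variables from above; the 200-character preview length is arbitrary):

for citation in resp.citations:
    cited_span = resp.text[citation["start"]:citation["end"]]
    doc_indices = [int(doc_id[4:]) for doc_id in citation["document_ids"]]  # strip the 'doc_' prefix
    print(f"Claim: {cited_span}")
    for idx in doc_indices:
        print(f"  Source chunk {idx}: {chunked[idx][:200]}...")
    print()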


4. Reduce the cost of summarization calls

Even though Command-R is an efficient, lightweight model, for some applications we may accept trading off some summarization quality for lower costs. To do this, we must reduce the number of tokens sent to the model -- but how do we select the most relevant bits?

We have a whole notebook dedicated to methods for reducing context length. Here, we use our 'text-rank' method to select the most central chunks in a graph built from chunk-to-chunk similarities. For more detail, please refer to this cookbook.

num_tokens = 8192
shortened = textrank(text, co, num_tokens, n_sentences_per_passage=30)

chunked = build_simple_chunks(shortened)
resp = co.chat(
  message=instructions,
  documents=[{"text": chunk} for chunk in chunked],
  model=co_model,
  temperature=0.3,
  return_prompt=True
)

num_tokens_in = co.tokenize(resp.prompt).length
num_tokens_out = resp.meta["billed_units"]["output_tokens"]
print(f"Generated summary with {num_tokens_in} tokens in, {num_tokens_out} tokens out")
print()
print("--- Summary with citations using text-rank + grounding in Command-R ---")
print()
print(resp.text)

The summary is looking convincing! In practice, the trade-off between cost-efficiency and performance should be considered carefully.