Basic Semantic Search

Language models give computers the ability to search by meaning, going beyond matching keywords. This capability is called semantic search.

Searching an archive using sentence embeddings

In this notebook, we’ll build a simple semantic search engine. The applications of semantic search go beyond building a web search engine: it can power a private search engine for internal documents or records, or drive features like Stack Overflow’s “similar questions” suggestions. We’ll work through the following steps:

  1. Get the archive of questions
  2. Embed the archive
  3. Search using an index and nearest neighbor search
  4. Visualize the archive based on the embeddings

If you’re running an older version of the Cohere SDK, you might need to upgrade it like so:

PYTHON
#!pip install --upgrade cohere

Get your Cohere API key by signing up here. Paste it in the cell below.

1. Getting Set Up

PYTHON
#@title Import libraries (Run this cell to execute required code) {display-mode: "form"}

import cohere
import numpy as np
import re
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
import umap
import altair as alt
from sklearn.metrics.pairwise import cosine_similarity
from annoy import AnnoyIndex
import warnings
warnings.filterwarnings('ignore')
pd.set_option('display.max_colwidth', None)

You’ll need your API key for this next cell. Sign up for Cohere and get one if you haven’t yet.

PYTHON
model_name = "embed-english-v3.0"  # Cohere embedding model to use
api_key = ""  # Paste your API key here

# Archive texts are embedded with the "search_document" input type;
# queries will use "search_query" later on
input_type_embed = "search_document"

# Set up the Cohere client
co = cohere.Client(api_key)

2. Get The Archive of Questions

We’ll use the trec dataset, which is made up of questions and their categories.

PYTHON
# Load the TREC question-classification dataset
dataset = load_dataset("trec", split="train")

# Keep the first 1,000 questions in a dataframe
df = pd.DataFrame(dataset)[:1000]

# Preview the first 10 rows
df.head(10)
   label-coarse  label-fine  text
0  0             0           How did serfdom develop in and then leave Russia ?
1  1             1           What films featured the character Popeye Doyle ?
2  0             0           How can I find a list of celebrities ’ real names ?
3  1             2           What fowl grabs the spotlight after the Chinese Year of the Monkey ?
4  2             3           What is the full form of .com ?
5  3             4           What contemptible scoundrel stole the cork from my lunch ?
6  3             5           What team did baseball ‘s St. Louis Browns become ?
7  3             6           What is the oldest profession ?
8  0             7           What are liver enzymes ?
9  3             4           Name the scar-faced bounty hunter of The Old West .

3. Embed the archive

The next step is to embed the text of the questions.

embedding archive texts

Embedding a thousand questions of this length should take about fifteen seconds.

PYTHON
embeds = co.embed(texts=list(df['text']),
                  model=model_name,
                  input_type=input_type_embed).embeddings
PYTHON
embeds = np.array(embeds)
embeds.shape
(1000, 1024)
4. Building the search index from the embeddings

Let’s now use Annoy to build an index that stores the embeddings in a way that is optimized for fast search. This approach scales well to a large number of texts (other options include Faiss, ScaNN, and PyNNDescent).

After building the index, we can use it to retrieve the nearest neighbors either of existing questions (section 4.1), or of new questions that we embed (section 4.2).

PYTHON
# Create the search index, passing the size of the embeddings
search_index = AnnoyIndex(embeds.shape[1], 'angular')

# Add all the vectors to the search index
for i in range(len(embeds)):
    search_index.add_item(i, embeds[i])

search_index.build(10)  # 10 trees
search_index.save('test.ann')
True
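Because the index was saved to disk, it can be reloaded in a later session without re-embedding or rebuilding anything. A minimal sketch, assuming the same dimensionality and metric as above:

PYTHON
# Reload the saved index; dimensionality and metric must match how it was built
loaded_index = AnnoyIndex(embeds.shape[1], 'angular')
loaded_index.load('test.ann')  # memory-maps the file from disk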

4.1. Find the neighbors of an example from the dataset

If we’re only interested in measuring the distance between the questions in the dataset (no outside queries), a simple way is to calculate the distance between every pair of embeddings we have.
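As a side note, the cosine_similarity import from the setup cell would let us do exactly that. Here is a minimal sketch of the pairwise approach; the cell below uses the Annoy index instead, which scales better:

PYTHON
# Pairwise cosine similarities between all questions (a 1000 x 1000 matrix)
similarity_matrix = cosine_similarity(embeds)

# Indices of the ten questions most similar to question 92 (the top hit is itself)
most_similar_ids = np.argsort(-similarity_matrix[92])[:10]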

PYTHON
example_id = 92

# Retrieve the ten nearest neighbors of question 92 (the list includes the question itself)
similar_item_ids = search_index.get_nns_by_item(example_id, 10,
                                                include_distances=True)
results = pd.DataFrame(data={'texts': df.iloc[similar_item_ids[0]]['text'],
                             'distance': similar_item_ids[1]}).drop(example_id)

print(f"Question:'{df.iloc[example_id]['text']}'\nNearest neighbors:")
results
Question:'What are bear and bull markets ?'
Nearest neighbors:
     texts                                                   distance
614  What animals do you find in the stock market ?         0.904278
137  What are equity securities ?                            0.992819
513  What do economists do ?                                 1.066583
307  What does NASDAQ stand for ?                            1.080738
363  What does it mean “ Rupee Depreciates ” ?               1.086724
932  Why did the world enter a global depression in 1929 ?  1.099370
547  Where can stocks be traded on-line ?                    1.105368
922  What is the difference between a median and a mean ?   1.141870
601  What is “ the bear of beers ” ?                         1.154140

4.2. Find the neighbors of a user query

We’re not limited to searching using existing items. If we get a query, we can embed it and find its nearest neighbors from the dataset.

PYTHON
query = "What is the tallest mountain in the world?"

# Queries are embedded with the "search_query" input type
input_type_query = "search_query"

# Embed the query
query_embed = co.embed(texts=[query],
                       model=model_name,
                       input_type=input_type_query).embeddings

# Retrieve the ten nearest neighbors of the query embedding
similar_item_ids = search_index.get_nns_by_vector(query_embed[0], 10,
                                                  include_distances=True)
query_results = pd.DataFrame(data={'texts': df.iloc[similar_item_ids[0]]['text'],
                                   'distance': similar_item_ids[1]})

print(f"Query:'{query}'\nNearest neighbors:")
print(query_results)  # NOTE: Your results might look slightly different to ours.
Query:'What is the tallest mountain in the world?'
Nearest neighbors:
     texts                                                                          distance
236  What is the name of the tallest mountain in the world ?                       0.447309
670  What is the highest mountain in the world ?                                   0.552254
412  What was the highest mountain on earth before Mount Everest was discovered ?  0.801252
907  What mountain range is traversed by the highest railroad in the world ?       0.929516
435  What is the highest peak in Africa ?                                          0.930806
109  Where is the highest point in Japan ?                                         0.977315
901  What ‘s the longest river in the world ?                                      1.064209
114  What is the largest snake in the world ?                                      1.076390
962  What ‘s the second-largest island in the world ?                              1.088034
27   What is the highest waterfall in the United States ?                          1.091145
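A note on the distance values: Annoy’s 'angular' metric reports the Euclidean distance between length-normalized vectors, which equals sqrt(2 − 2·cos(u, v)). If you prefer to think in cosine similarities, a minimal sketch of the conversion:

PYTHON
# Annoy's angular distance d satisfies d = sqrt(2 - 2*cos), so cos = 1 - d**2 / 2
distances = np.array(similar_item_ids[1])
cosine_similarities = 1 - distances**2 / 2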

5. Visualizing the archive

Finally, let’s plot all the questions on a 2D chart so you can visualize the semantic similarities in this dataset!

PYTHON
#@title Plot the archive {display-mode: "form"}

# Reduce the embeddings to two dimensions with UMAP
reducer = umap.UMAP(n_neighbors=20)
umap_embeds = reducer.fit_transform(embeds)

# Prepare the dataframe for plotting
df_explore = pd.DataFrame(data={'text': df['text']})
df_explore['x'] = umap_embeds[:, 0]
df_explore['y'] = umap_embeds[:, 1]

# Plot the interactive chart
chart = alt.Chart(df_explore).mark_circle(size=60).encode(
    x=alt.X('x', scale=alt.Scale(zero=False)),
    y=alt.Y('y', scale=alt.Scale(zero=False)),
    tooltip=['text']
).properties(
    width=700,
    height=400
)
chart.interactive()

Hover over the points to read the text. Do you see patterns in the clustered points, such as similar questions, or questions about similar topics, landing near each other?
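If you’d like to quantify the groupings rather than just eyeball them, one option (not part of the original walkthrough) is to cluster the embeddings, for example with scikit-learn’s KMeans, and color the chart by cluster. A minimal sketch, with the cluster count of 8 chosen arbitrarily:

PYTHON
from sklearn.cluster import KMeans

# Cluster the original high-dimensional embeddings (8 clusters is an arbitrary choice)
kmeans = KMeans(n_clusters=8, random_state=0, n_init=10)
df_explore['cluster'] = kmeans.fit_predict(embeds).astype(str)

# Rebuild the chart, this time coloring the points by cluster
chart_clusters = alt.Chart(df_explore).mark_circle(size=60).encode(
    x=alt.X('x', scale=alt.Scale(zero=False)),
    y=alt.Y('y', scale=alt.Scale(zero=False)),
    color='cluster',
    tooltip=['text']
).properties(width=700, height=400)
chart_clusters.interactive()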

This concludes this introductory guide to semantic search using sentence embeddings. As you continue building a search product, additional considerations arise (like dealing with long texts, or fine-tuning the embeddings for a specific use case).
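For example, one common way of dealing with long texts is to split them into overlapping chunks and embed each chunk separately. A minimal, hypothetical sketch (the window sizes are arbitrary):

PYTHON
def chunk_text(text, max_words=100, overlap=20):
    """Split a long text into overlapping word windows before embedding."""
    words = text.split()
    step = max_words - overlap
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), step)]

# Each chunk can then be embedded with co.embed, just like the questions above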

We can’t wait to see what you start building! Share your projects or find support on Discord.