Basic Semantic Search

Language models give computers the ability to search by meaning, going beyond matching keywords. This capability is called semantic search.

[Figure: Searching an archive using sentence embeddings]

In this notebook, we'll build a simple semantic search engine. The applications of semantic search go beyond building a web search engine: it can power a private search engine for internal documents or records, or features like Stack Overflow's "similar questions". We'll work through four steps:

  1. Get the archive of questions
  2. Embed the archive
  3. Search using an index and nearest neighbor search
  4. Visualize the archive based on the embeddings

If you don't have the Cohere SDK installed, or you're running an older version, you can install or upgrade it like so:

#!pip install --upgrade cohere

Get your Cohere API key by signing up here. Paste it in the cell below.

1. Getting Set Up

#@title Import libraries (Run this cell to execute required code) {display-mode: "form"}

import cohere
import numpy as np
import re
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
import umap
import altair as alt
from sklearn.metrics.pairwise import cosine_similarity
from annoy import AnnoyIndex
import warnings
warnings.filterwarnings('ignore')
pd.set_option('display.max_colwidth', None)

You'll need your API key for this next cell. Sign up to Cohere and get one if you haven't yet.

model_name = "embed-english-v3.0"      # the embedding model to use
api_key = ""                           # paste your Cohere API key here
input_type_embed = "search_document"   # archive texts are embedded as documents

co = cohere.Client(api_key)

2. Get The Archive of Questions

We'll use the TREC dataset, which is made up of questions and their categories.

dataset = load_dataset("trec", split="train")

# Keep only the first 1000 questions for this demo
df = pd.DataFrame(dataset)[:1000]

df.head(10)
   label-coarse  label-fine  text
0  0             0           How did serfdom develop in and then leave Russia ?
1  1             1           What films featured the character Popeye Doyle ?
2  0             0           How can I find a list of celebrities ' real names ?
3  1             2           What fowl grabs the spotlight after the Chinese Year of the Monkey ?
4  2             3           What is the full form of .com ?
5  3             4           What contemptible scoundrel stole the cork from my lunch ?
6  3             5           What team did baseball 's St. Louis Browns become ?
7  3             6           What is the oldest profession ?
8  0             7           What are liver enzymes ?
9  3             4           Name the scar-faced bounty hunter of The Old West .

3. Embed the Archive

The next step is to embed the text of the questions.

[Figure: Embedding the texts in the archive]

Getting a thousand embeddings of this length should take about fifteen seconds.

embeds = co.embed(texts=list(df['text']),
                  model=model_name,
                  input_type=input_type_embed).embeddings
embeds = np.array(embeds)
embeds.shape
(1000, 1024)

4. Search using an index and nearest neighbor search

Building the search index from the embeddings

Let's now use Annoy to build an index that stores the embeddings in a way that is optimized for fast search. This approach scales well to a large number of texts (other options include Faiss, ScaNN, and PyNNDescent).

After building the index, we can use it to retrieve the nearest neighbors either of existing questions (section 4.1), or of new questions that we embed (section 4.2).

search_index = AnnoyIndex(embeds.shape[1], 'angular')
for i in range(len(embeds)):
    search_index.add_item(i, embeds[i])

search_index.build(10) # 10 trees
search_index.save('test.ann')
True
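
Because we saved the index to disk, it can be reloaded later without re-embedding the archive. A minimal sketch, assuming the same dimensionality and metric used above:

loaded_index = AnnoyIndex(embeds.shape[1], 'angular')  # must match the index we built
loaded_index.load('test.ann')  # Annoy memory-maps the file, so loading is fast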

4.1. Find the neighbors of an example from the dataset

If we're only interested in measuring the distances between the questions in the dataset (no outside queries), one option is to calculate the similarity between every pair of embeddings we have (see the sketch below). Since we've already built an index, though, we can simply ask it for the nearest neighbors of any question it contains.
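
A minimal sketch of the pairwise approach, using the cosine_similarity function imported earlier:

# Compute the full 1000x1000 matrix of cosine similarities between all embeddings
similarity_matrix = cosine_similarity(embeds)

# The most similar questions to question 92 are the highest values in its row,
# excluding the 1.0 on the diagonal (each question's similarity to itself)
most_similar_ids = np.argsort(similarity_matrix[92])[::-1][1:10]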

example_id = 92

similar_item_ids = search_index.get_nns_by_item(example_id, 10,
                                                include_distances=True)
# Drop the question itself from its own list of neighbors
results = pd.DataFrame(data={'texts': df.iloc[similar_item_ids[0]]['text'],
                             'distance': similar_item_ids[1]}).drop(example_id)

print(f"Question:'{df.iloc[example_id]['text']}'\nNearest neighbors:")
results
Question:'What are bear and bull markets ?'
Nearest neighbors:
     texts                                                  distance
614  What animals do you find in the stock market ?         0.904278
137  What are equity securities ?                            0.992819
513  What do economists do ?                                 1.066583
307  What does NASDAQ stand for ?                            1.080738
363  What does it mean `` Rupee Depreciates '' ?             1.086724
932  Why did the world enter a global depression in 1929 ?  1.099370
547  Where can stocks be traded on-line ?                    1.105368
922  What is the difference between a median and a mean ?   1.141870
601  What is `` the bear of beers '' ?                       1.154140
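
Note that Annoy's 'angular' metric reports a distance, not a similarity, so smaller values mean closer matches. Annoy defines it as sqrt(2 * (1 - cosine similarity)), so if you'd rather see cosine similarities, a quick conversion (our own helper calculation, not part of Annoy):

# Convert Annoy's angular distances to cosine similarities: cos = 1 - d**2 / 2
distances = np.array(similar_item_ids[1])
cosine_similarities = 1 - distances**2 / 2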

4.2. Find the neighbors of a user query

We're not limited to searching using existing items. If we get a query, we can embed it and find its nearest neighbors from the dataset.

query = "What is the tallest mountain in the world?"
# Queries are embedded with the "search_query" input type,
# while the archive was embedded with "search_document"
input_type_query = "search_query"

query_embed = co.embed(texts=[query],
                       model=model_name,
                       input_type=input_type_query).embeddings

similar_item_ids = search_index.get_nns_by_vector(query_embed[0], 10,
                                                  include_distances=True)
query_results = pd.DataFrame(data={'texts': df.iloc[similar_item_ids[0]]['text'],
                                   'distance': similar_item_ids[1]})


print(f"Query:'{query}'\nNearest neighbors:")
print(query_results) # NOTE: Your results might look slightly different to ours.
Query:'What is the tallest mountain in the world?'
Nearest neighbors:
     texts                                                                          distance
236  What is the name of the tallest mountain in the world ?                       0.447309
670  What is the highest mountain in the world ?                                   0.552254
412  What was the highest mountain on earth before Mount Everest was discovered ?  0.801252
907  What mountain range is traversed by the highest railroad in the world ?       0.929516
435  What is the highest peak in Africa ?                                          0.930806
109  Where is the highest point in Japan ?                                         0.977315
901  What 's the longest river in the world ?                                      1.064209
114  What is the largest snake in the world ?                                      1.076390
962  What 's the second-largest island in the world ?                              1.088034
 27  What is the highest waterfall in the United States ?                          1.091145
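
The steps above can be wrapped into a small reusable helper. A minimal sketch (the search function is our own, not part of the Cohere SDK):

def search(query, n_results=10):
    # Embed the query, then look up its nearest neighbors in the Annoy index
    query_embed = co.embed(texts=[query],
                           model=model_name,
                           input_type="search_query").embeddings
    ids, distances = search_index.get_nns_by_vector(query_embed[0], n_results,
                                                    include_distances=True)
    return pd.DataFrame(data={'texts': df.iloc[ids]['text'],
                              'distance': distances})

search("What is the tallest mountain in the world?")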

5. Visualizing the archive

Finally, let's plot all the questions on a 2D chart so we can visualize the semantic similarities in this dataset. We'll use UMAP to reduce the embeddings from 1024 dimensions down to two.

#@title Plot the archive {display-mode: "form"}

# Reduce the embeddings to 2 dimensions with UMAP
reducer = umap.UMAP(n_neighbors=20)
umap_embeds = reducer.fit_transform(embeds)

# Gather the texts and 2D coordinates into a dataframe for plotting
df_explore = pd.DataFrame(data={'text': df['text']})
df_explore['x'] = umap_embeds[:,0]
df_explore['y'] = umap_embeds[:,1]

chart = alt.Chart(df_explore).mark_circle(size=60).encode(
    x=alt.X('x', scale=alt.Scale(zero=False)),
    y=alt.Y('y', scale=alt.Scale(zero=False)),
    tooltip=['text']
).properties(
    width=700,
    height=400
)
chart.interactive()

Hover over the points to read the text. Do you see any patterns in the clusters? Nearby points tend to be similar questions, or questions about similar topics.

This concludes this introductory guide to semantic search using sentence embeddings. As you continue building a search product, additional considerations arise (like dealing with long texts, or fine-tuning the embeddings for a specific use case).

We can’t wait to see what you start building! Share your projects or find support on Discord.