Text Classification with Cohere’s Classify Endpoint

Among the most popular use cases for language embeddings is ‘text classification,’ in which different pieces of text — blog posts, lyrics, poems, headlines, etc. — are grouped based on their similarity, their sentiment, or some other property.

Here, we’ll discuss how to perform simple text classification tasks with Cohere’s classify endpoint, and provide links to more information on how to fine-tune this endpoint for more specialized work.

Few-Shot Classification with Cohere’s classify Endpoint

Generally, training a text classifier requires a large amount of labeled data. But with large language models, it’s now possible to build so-called ‘few-shot’ classifiers that perform well after seeing far fewer labeled examples.

In the next few sections, we’ll create a sentiment analysis classifier to sort text into “positive,” “negative,” and “neutral” categories.

Setting up the SDK

First, let’s import the required tools and set up a Cohere client.

PYTHON
import cohere
from cohere import ClassifyExample
PYTHON
co = cohere.ClientV2(api_key="COHERE_API_KEY")  # Replace with your Cohere API key

Preparing the Data and Inputs

With the classify endpoint, you can create a text classifier with as few as two examples per class. Each example must contain the text itself and its corresponding label (i.e., class). So if you have two classes, you need a minimum of four examples; if you have three classes, you need a minimum of six; and so on.

Here are 15 examples, five per class, created as ClassifyExample objects:

PYTHON
examples = [
    ClassifyExample(text="I’m so proud of you", label="positive"),
    ClassifyExample(
        text="What a great time to be alive", label="positive"
    ),
    ClassifyExample(text="That’s awesome work", label="positive"),
    ClassifyExample(text="The service was amazing", label="positive"),
    ClassifyExample(text="I love my family", label="positive"),
    ClassifyExample(
        text="They don't care about me", label="negative"
    ),
    ClassifyExample(text="I hate this place", label="negative"),
    ClassifyExample(
        text="The most ridiculous thing I've ever heard",
        label="negative",
    ),
    ClassifyExample(text="I am really frustrated", label="negative"),
    ClassifyExample(text="This is so unfair", label="negative"),
    ClassifyExample(text="This made me think", label="neutral"),
    ClassifyExample(text="The good old days", label="neutral"),
    ClassifyExample(text="What's the difference", label="neutral"),
    ClassifyExample(text="You can't ignore this", label="neutral"),
    ClassifyExample(text="That's how I see it", label="neutral"),
]
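Before calling the endpoint, it can help to confirm that every class meets the two-example minimum. Here is a minimal sketch; check_min_examples is a hypothetical helper (not part of the Cohere SDK), and it works on a plain list of labels so it runs without an API key:

```python
from collections import Counter


def check_min_examples(labels, min_per_class=2):
    """Hypothetical helper: raise ValueError if any label has too few examples.

    Returns a dict mapping each label to its example count.
    """
    counts = Counter(labels)
    too_few = {label: n for label, n in counts.items() if n < min_per_class}
    if too_few:
        raise ValueError(f"Labels with too few examples: {too_few}")
    return dict(counts)


# Labels matching the example list above: five per class.
labels = ["positive"] * 5 + ["negative"] * 5 + ["neutral"] * 5
print(check_min_examples(labels))  # {'positive': 5, 'negative': 5, 'neutral': 5}
```

With the SDK objects above, you would pass [ex.label for ex in examples] instead of the hard-coded list.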

Besides the examples, you’ll also need the ‘inputs,’ which are the strings of text you want the classifier to sort. Here are the ones we’ll be using:

PYTHON
inputs = [
    "Hello, world! What a beautiful day",
    "It was a great time with great people",
    "Great place to work",
    "That was a wonderful evening",
    "Maybe this is why",
    "Let's start again",
    "That's how I see it",
    "These are all facts",
    "This is the worst thing",
    "I cannot stand this any longer",
    "This is really annoying",
    "I am just plain fed up",
]
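Twelve inputs fit easily in one request, but the Classify API reference caps how many inputs a single call accepts. For larger jobs, a minimal sketch of splitting the inputs into batches follows; chunk_inputs is a hypothetical helper, and the default batch size of 96 is an assumption to verify against the current API reference:

```python
def chunk_inputs(items, batch_size=96):
    """Hypothetical helper: yield consecutive slices of at most batch_size items."""
    # ASSUMPTION: 96 is the per-request input cap; check the current
    # Classify API reference before relying on it.
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Example: 200 inputs split into batches of 96, 96, and 8.
texts = [f"text {i}" for i in range(200)]
print([len(batch) for batch in chunk_inputs(texts)])  # [96, 96, 8]
```

Each batch could then be passed to classify_text (defined in the next section) and the returned lists concatenated.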

Generate Predictions

Setting up the model is straightforward with the classify endpoint. We’ll use Cohere’s embed-english-v3.0 model; here’s what that looks like:

PYTHON
def classify_text(inputs, examples, model="embed-english-v3.0"):
    """
    Classifies a list of input texts given the examples
    Arguments:
        inputs (list[str]): a list of input texts to be classified
        examples (list[ClassifyExample]): a list of example texts and class labels
        model (str): identifier of the model
    Returns:
        classifications (list): each result contains the input text, predicted label, and confidence values
    """

    # Classify text by calling the Classify endpoint
    response = co.classify(
        model=model, inputs=inputs, examples=examples
    )

    classifications = response.classifications

    return classifications


# Classify the inputs
predictions = classify_text(inputs, examples)

print(predictions)

Here’s a sample of the output (truncated for readability; if you run the code yourself, you’ll get one result for each of the 12 inputs):

[ClassifyResponseClassificationsItem(id='9df6628d-57b2-414c-837e-c8a22f00d3db',
input='hello, world! what a beautiful day',
prediction='positive',
predictions=['positive'],
confidence=0.40137812,
confidences=[0.40137812],
labels={'negative': ClassifyResponseClassificationsItemLabelsValue(confidence=0.23582731),
'neutral': ClassifyResponseClassificationsItemLabelsValue(confidence=0.36279458),
'positive': ClassifyResponseClassificationsItemLabelsValue(confidence=0.40137812)},
classification_type='single-label'),
ClassifyResponseClassificationsItem(id='ce2c3b0b-ce98-4905-9ef5-fc83c6848fc5',
input='it was a great time with great people',
prediction='positive',
predictions=['positive'],
confidence=0.49054274,
confidences=[0.49054274],
labels={'negative': ClassifyResponseClassificationsItemLabelsValue(confidence=0.19989403),
'neutral': ClassifyResponseClassificationsItemLabelsValue(confidence=0.30956325),
'positive': ClassifyResponseClassificationsItemLabelsValue(confidence=0.49054274)},
classification_type='single-label')
....]

Most of this is pretty easy to understand, but there are a few things worth drawing attention to.

Besides returning the predicted class in the prediction field, the endpoint also returns the prediction’s confidence value in the confidence field, which ranges from 0 (unconfident) to 1 (completely confident).
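For a more compact view, you can pull the input, prediction, and confidence fields out of each result. A minimal sketch: summarize is a hypothetical helper, and the SimpleNamespace objects are stand-ins copying two items from the sample output so the code runs without an API call. With real results, you would pass the predictions list returned by classify_text instead:

```python
from types import SimpleNamespace


def summarize(classifications):
    """Hypothetical helper: return (input, prediction, confidence) tuples."""
    return [(c.input, c.prediction, c.confidence) for c in classifications]


# Stand-ins mimicking two items from the sample output above.
sample = [
    SimpleNamespace(input="hello, world! what a beautiful day",
                    prediction="positive", confidence=0.40137812),
    SimpleNamespace(input="it was a great time with great people",
                    prediction="positive", confidence=0.49054274),
]

# Print one aligned line per result: label, confidence, input text.
for text, label, conf in summarize(sample):
    print(f"{label:<9} {conf:.2f}  {text}")
```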

Also, these confidence values are split among the classes (they appear under the labels field): since we’re using three classes, the confidence values for “positive,” “negative,” and “neutral” add up to 1.
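You can check this against the sample output by adding up the per-class confidences for one input. The values below are copied from the first result above; a small tolerance absorbs floating-point rounding in the returned numbers:

```python
import math

# Per-class confidences copied from the first item in the sample output.
confidences = {
    "negative": 0.23582731,
    "neutral": 0.36279458,
    "positive": 0.40137812,
}

total = sum(confidences.values())
print(round(total, 6))  # 1.0

# Allow a little slack for rounding in the returned values.
assert math.isclose(total, 1.0, abs_tol=1e-6)
```

With a real result item, the equivalent dict would be {name: value.confidence for name, value in item.labels.items()}.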

Under the hood, the classifier selects the class with the highest confidence value as the “predicted class.” A high confidence value for the predicted class therefore indicates that the model is very confident of its prediction, and vice versa.
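The selection rule can be reproduced from the per-class confidences. This sketch picks the highest-confidence class and, as an illustrative extra that the endpoint does not provide, flags predictions below a hypothetical threshold of 0.5. With three classes whose confidences sum to 1, the top value can never fall below 1/3, so both sample predictions (0.40 and 0.49) sit closer to that floor than to full confidence:

```python
def pick_label(confidences, threshold=0.5):
    """Return (label, confidence, is_low) for the highest-confidence class.

    threshold is a hypothetical cutoff for flagging uncertain predictions;
    it is not a parameter of the classify endpoint.
    """
    label = max(confidences, key=confidences.get)
    conf = confidences[label]
    return label, conf, conf < threshold


# Per-class confidences from the two sample results above.
first = {"negative": 0.23582731, "neutral": 0.36279458, "positive": 0.40137812}
second = {"negative": 0.19989403, "neutral": 0.30956325, "positive": 0.49054274}

print(pick_label(first))   # ('positive', 0.40137812, True)
print(pick_label(second))  # ('positive', 0.49054274, True)
```

Where to set such a threshold depends on how costly a wrong label is for your application.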

What If I Need to Fine-Tune the classify endpoint?

Cohere has dedicated documentation on fine-tuning the classify endpoint for bespoke tasks. You can also read this blog post, which works out a detailed example.
