Classification Using Embeddings

In a previous chapter, you learned how to classify text using the Classify endpoint. However, there are more ways to classify text, and one of them is using embeddings! In this chapter you’ll learn how.

Colab Notebook

This chapter uses the same notebook as the previous chapter.

For the setup, please refer to the Setting Up chapter at the beginning of this module.

Using Embeddings to Classify Text

In the previous chapter, we looked at clustering, which is a task of grouping documents when the groups are not defined beforehand. Now, what if we already know the kinds of groups, or classes, that we want to group our dataset into?

Clustering is an “unsupervised learning” task: we don’t know in advance how many classes there are or what they are. Classification is a “supervised learning” task: the classes are known, and we have labeled examples to learn from.
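
The difference can be sketched on toy data. This is a minimal illustration, not part of the chapter's notebook: the 2-D points stand in for embeddings, the labels are made up, and `KMeans` is used only as one example of a clustering algorithm.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.svm import SVC

# Toy 2-D "embeddings": two well-separated groups (illustrative values only)
points = np.array([
    [0.0, 0.1], [0.1, 0.0], [0.2, 0.1],   # group near the origin
    [5.0, 5.1], [5.1, 5.0], [5.2, 5.1],   # group far from the origin
])

# Unsupervised: KMeans sees only the points and assigns its own group ids.
# The number of clusters is a choice we make, not something given by labels.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0)
cluster_ids = kmeans.fit_predict(points)

# Supervised: SVC is also given the known class of each point (made-up labels)
labels = ["Airfare"] * 3 + ["Airline"] * 3
classifier = SVC(kernel="linear").fit(points, labels)
predicted = classifier.predict([[0.1, 0.1], [5.1, 5.1]])

print("cluster ids:", cluster_ids)        # arbitrary integer ids
print("predicted classes:", predicted)    # named classes learned from labels
```

Note that the clustering call takes only the points, while the classifier's `fit` takes both the points and their labels. The cluster ids carry no meaning of their own, whereas the classifier returns the class names it was trained on.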

Text classification enables many applications. One example is helping content moderators automatically flag toxic content on their platforms. Rather than manually going through every post and comment, they can have a system take in text, turn it into embeddings, and classify it by level of toxicity.

Another example is intent classification for customer support, where we build a system that takes in a customer inquiry and classifies the right intent so the inquiry can be routed to the right places.

And this happens to be what the dataset we’re using is all about. It contains a class feature called intent, and to demonstrate a classification task, we’ll bring it in this time. We’ll keep to the same 9 data points, but our task now is to predict the class of each data point out of three options: Airfare, Airline, and Ground Service. (Note: the original dataset has more classes, but here we use just three for simplicity.)

We’ll use sklearn to train a Support Vector Machine (SVM) classifier on a set of training data. If you’d like to learn more about this type of model, please check out this SVM video.

PYTHON
# Train the classifier with Support Vector Machine (SVM) algorithm

# import SVM classifier code
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


# Initialize the classifier
svm_classifier = make_pipeline(StandardScaler(), SVC())

# Prepare the training features and label
features = df_train["query_embeds"].tolist()
label = df_train["intent"]

# Fit the support vector machine
svm_classifier.fit(features, label)
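
To make clearer what `make_pipeline(StandardScaler(), SVC())` does, here is a small sketch that runs the same two steps by hand and compares the results. It is only an illustration: the vectors are random stand-ins for real embeddings, and the dimensions and sample counts are assumptions chosen to keep the example short.

```python
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

rng = np.random.default_rng(0)

# Stand-in "embeddings": 3 classes, 10 training vectors each, 8 dimensions
class_names = ["Airfare", "Airline", "Ground Service"]
centers = rng.normal(size=(3, 8)) * 5
train_features = np.vstack([c + rng.normal(size=(10, 8)) for c in centers])
train_labels = np.repeat(class_names, 10)
test_features = np.vstack([c + rng.normal(size=(3, 8)) for c in centers])

# One-step version: the pipeline scales the features, then fits the SVM
pipeline = make_pipeline(StandardScaler(), SVC())
pipeline.fit(train_features, train_labels)
pipeline_preds = pipeline.predict(test_features)

# Two-step version: the same operations written out by hand
scaler = StandardScaler().fit(train_features)
svm = SVC().fit(scaler.transform(train_features), train_labels)
manual_preds = svm.predict(scaler.transform(test_features))

# Both routes should yield identical predictions
print(np.array_equal(pipeline_preds, manual_preds))
```

The pipeline's main convenience is that the scaling statistics learned from the training data are reused automatically at prediction time, so the test inputs never need to be scaled separately.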

Once that is done, we’ll take the embeddings of the 9 data points, put them through the trained model, and get the class predictions on the other side. And with this small test dataset, we get all predictions correct.

PYTHON
# Predict with test data

# Prepare the test inputs
df_test = df_test.copy()
inputs = df_test["query_embeds"].tolist()

# Predict the labels
df_test["intent_pred"] = svm_classifier.predict(inputs)

# Compute the score
score = svm_classifier.score(inputs, df_test["intent"])
print(f"Prediction accuracy is {100*score}%")

Output: Prediction accuracy is 100.0%
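
For a classifier, sklearn's `score` method returns the mean accuracy, that is, the fraction of predictions that match the actual labels. The sketch below spells out that calculation in plain Python. The `accuracy` helper and the label lists are hypothetical and are not taken from the notebook.

```python
def accuracy(actual, predicted):
    """Return the fraction of positions where the two label lists agree."""
    if len(actual) != len(predicted):
        raise ValueError("actual and predicted must have the same length")
    if not actual:
        raise ValueError("at least one label is required")
    matches = sum(a == p for a, p in zip(actual, predicted))
    return matches / len(actual)


# Hypothetical labels for 9 data points (3 per class)
actual = ["Airfare", "Airline", "Ground Service"] * 3

all_correct = list(actual)              # every prediction matches
one_wrong = ["Airline"] + actual[1:]    # first prediction is wrong

print(f"All correct: {100 * accuracy(actual, all_correct)}%")
print(f"One wrong:   {100 * accuracy(actual, one_wrong):.1f}%")
```

With only 9 test points, a single wrong prediction moves the accuracy by about 11 percentage points, which is worth keeping in mind when reading a score from a test set this small.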

Here we can see that all predictions match the actual classes:

[Figure: two plots showing that all predictions match the actual classes]

Conclusion

As you can see, there are different ways to use the endpoints to do classification! As you learned before, you can use the Classify endpoint, but you can also use the Embed endpoint and train a simple classifier on the resulting data. Since embeddings capture context so well, the job of the classifier is vastly simplified, and this is why we could train a simple SVM.

Original Source

This material comes from the post Text Embeddings Visually Explained
