Finetuning on Cohere’s Platform
Overview
Cohere chat models (Command R and Command R+) are fantastic, generally capable models out of the box. To further adapt our models to specific tasks, there are several strategies such as prompt engineering, RAG, tool use, and finetuning. In this cookbook, we will focus on finetuning. While the other strategies involve careful and intelligent orchestration of our models out of the box, finetuning modifies the model's weights to specialize it for the task at hand. This requires a careful investment of time and resources, from data collection to model training, and is typically employed when all other strategies fall short.
Our finetuning service allows customization of our latest Command R model (command-r-08-2024) with LoRA-based finetuning, which gives users the ability to control model flexibility depending on their task. Additionally, we extended the training context length to 16,384 tokens, giving users the ability to use longer training data points, which is typical for RAG, agents, and tool use. In this cookbook, we will showcase model customization via our Finetuning API and also show you how to monitor loss curves for your finetuning jobs using the Weights & Biases integration. Please note that you can do the same via the UI; you can find a detailed guide for that here.
We will finetune our Command R model on the task of conversational financial question answering. Specifically, we finetune our model on the ConvFinQA dataset. In this task, the output expected from the model is a domain-specific language (DSL) that we will potentially feed into a downstream application. LLMs are known to be bad at arithmetic. Hence, instead of computing the answer, the task here is to extract the right numbers from the context, apply the right sequence of predicates, and strictly follow the DSL, to ensure minimal error rates in the downstream application that may consume the DSL output from our model. Prompt engineering proves rather brittle for such tasks, as it is hard to make sure the model follows the exact syntax of the DSL. Finetuning the model gives that guarantee.
Setup
Dependencies
If you don't already have the Cohere Python SDK, you can install it as follows.
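Below is a minimal sketch of the install and client setup, assuming your API key is available via the `CO_API_KEY` environment variable (otherwise, pass it explicitly):

```python
# Install the SDK (in a notebook cell; drop the '%' in a shell)
# %pip install cohere

import cohere

# cohere.Client() picks up the CO_API_KEY environment variable by default;
# alternatively, pass api_key="<YOUR_API_KEY>" explicitly.
co = cohere.Client()
```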
Dataset
The ConvFinQA dataset is a conversational dataset comprising multi-turn numerical questions and answers based on a given financial report, which includes text and tables. We process the original dataset to do a few things:
- We preprocess the financial reports to combine various fields in the original dataset to create a single text blurb from which the questions are to be answered. This involves concatenating various pieces of text and converting the tables to simple text with heuristic regex mappings, among other cosmetic changes.
- For finetuning Command R models, the dataset needs to be a `jsonl` file, where each `json` object is a conversation. Each conversation has a list of messages, and each message has two properties: role and content. The role identifies the sender (Chatbot, System, or User), while the content contains the text content. You can find a more detailed guide on preparing the dataset, including the data validations we run, the train/eval splits we recommend, etc., here. We format the conversations in the original dataset to conform to these requirements.
ConvFinQA data example
The following is an example datapoint from the finetuning data.
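Since the raw datapoints are long, here is a shortened, hypothetical datapoint illustrating the expected structure; the report text, questions, and DSL answers below are invented for this sketch, not taken from the actual dataset:

```json
{
  "messages": [
    {
      "role": "System",
      "content": "In 2017, the company reported net sales of $189 million. In 2018, net sales increased by $68 million, driven by ..."
    },
    { "role": "User", "content": "what was the amount of net sales in 2017?" },
    { "role": "Chatbot", "content": "189" },
    { "role": "User", "content": "and what was the total in 2018?" },
    { "role": "Chatbot", "content": "add(189, 68)" }
  ]
}
```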
As you can see, the financial report based on which we answer the questions is passed in with the System role. This acts as the ‘system prompt’, which is part of the prompt used as context/instructions for the entire conversation. Since the information in the report is required and relevant to every user question in the conversation, we want to put it in as the overall context of the conversation.
A few things to note in the above example:
- Models trained via our finetuning API do not have any additional/default preamble other than the system prompt provided in the finetuning dataset.
- Each datapoint has multiple turns alternating between `User` and `Chatbot`; during finetuning, we consume messages from all roles, but only the `Chatbot` messages contribute to the model updates.
- We want the model to learn to strictly follow the domain-specific language as represented by the desired Chatbot responses in this example.
Upload the dataset
We use the Datasets API to upload the dataset required for finetuning. Note that we upload both training and evaluation files. The data in the evaluation file is used for validation and early stopping, as we will elaborate later.
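A sketch of the upload call, assuming the processed files are saved locally as `train.jsonl` and `eval.jsonl` (hypothetical file names):

```python
# Upload the training and evaluation files as a chat finetuning dataset
chat_dataset = co.datasets.create(
    name="convfinqa-finetune-dataset",
    data=open("train.jsonl", "rb"),
    eval_data=open("eval.jsonl", "rb"),
    type="chat-finetune-input",
)
print(chat_dataset.id)
```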
Whenever a dataset is created, the data is validated asynchronously. This validation is kicked off automatically on the backend, and must be completed before we can use this dataset for finetuning. You can find more info on interpreting the errors, if you get any, here.
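For example, you can block until validation finishes and then inspect the outcome (a sketch; field names per the SDK at the time of writing):

```python
# Poll until the asynchronous validation completes
validated = co.wait(chat_dataset)
print(validated.dataset.validation_status)  # expect "validated" on success
```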
Start finetuning
Once the dataset is validated, we can start a finetuning job with the Finetuning API.
Hyperparameters
There are several hyperparameters that you can modify to get the most out of your finetuning, including LoRA-specific params. You can find a detailed explanation here.
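As a sketch, hyperparameters are passed via the SDK's `Hyperparameters` object; the values below are illustrative placeholders, not tuned recommendations:

```python
from cohere.finetuning import Hyperparameters

hp = Hyperparameters(
    early_stopping_patience=6,
    early_stopping_threshold=0.01,
    train_batch_size=16,
    train_epochs=1,
    learning_rate=0.01,
    # LoRA-specific params can also be set here; see the linked guide
    # for the exact field names and supported ranges.
)
```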
WandB integration
For chat finetuning, we support WandB integration which allows you to monitor the loss curves of finetuning jobs in real-time without having to wait for the job to finish. You can find more info here.
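A sketch of the WandB configuration; replace the entity, project, and API key placeholders with your own:

```python
from cohere.finetuning import WandbConfig

wandb_config = WandbConfig(
    project="<your-project>",
    api_key="<your-wandb-api-key>",
    entity="<your-entity>",
)
```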
Create the finetuning job
With the dataset, hyperparameters, and WandB configuration ready, we can create a finetuning job as follows. You can find the details of all params in the Finetuning API documentation.
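Putting it together, a sketch of the create call, reusing `chat_dataset`, `hp`, and `wandb_config` from the earlier steps (the model name is illustrative):

```python
from cohere.finetuning import BaseModel, FinetunedModel, Settings

finetuned_model = co.finetuning.create_finetuned_model(
    request=FinetunedModel(
        name="convfinqa-command-r-ft",
        settings=Settings(
            base_model=BaseModel(base_type="BASE_TYPE_CHAT"),
            dataset_id=chat_dataset.id,
            hyperparameters=hp,
            wandb=wandb_config,
        ),
    ),
)
```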
Check finetuning status
Once the finetuning job finishes and the finetuned model is ready to use, you will get notified via email. Alternatively, you can check the status of your finetuning job as follows.
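For example, assuming the `finetuned_model` response from the create call:

```python
response = co.finetuning.get_finetuned_model(
    finetuned_model.finetuned_model.id
)
print(response.finetuned_model.status)  # "STATUS_READY" once it can serve requests
```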
You may view the fine-tuning job loss curves via the Weights and Biases dashboard. It will be available via the following URL once the training starts: `https://wandb.ai/<your-entity>/<your-project>/runs/<finetuned-model-id>`. We log the following to WandB:
- training loss at every training step
- validation loss and accuracy (as described here) at every validation step
For this particular finetuning job, the training loss, validation loss, and validation accuracy should look as follows.
Once the training job finishes, you can also check the validation metrics as follows.
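One way to pull these programmatically is the training-step metrics endpoint (a sketch; the response shape may differ across SDK versions):

```python
train_step_metrics = co.finetuning.list_training_step_metrics(
    finetuned_model_id=finetuned_model.finetuned_model.id,
)

# Each entry holds the metrics logged at one training/validation step
for step in train_step_metrics.step_metrics:
    print(step.metrics)
```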
Run inference with the finetuned model
Once your model completes training, you can call it via `co.chat()` and pass your custom model ID. Please note, the model ID is the ID returned by the fine-tuned model object with a `-ft` suffix. `co.chat()` uses no preamble by default for fine-tuned models. You can specify a preamble using the `preamble` parameter if you like. In this case, we won't specify any preamble and follow the convention set in the training data.
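A sketch of the inference call, using the same hypothetical report and follow-up question from the data example above:

```python
chat_history = [
    {
        "role": "SYSTEM",
        "message": "In 2017, the company reported net sales of $189 million. "
        "In 2018, net sales increased by $68 million, driven by ...",
    },
    {"role": "USER", "message": "what was the amount of net sales in 2017?"},
    {"role": "CHATBOT", "message": "189"},
]

response = co.chat(
    model=finetuned_model.finetuned_model.id + "-ft",
    chat_history=chat_history,
    message="and what was the total in 2018?",
)
print(response.text)  # expected DSL output, e.g. add(189, 68)
```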
The response object is described in detail here. As you can see, the finetuned model responds in the DSL as expected and matches the ground truth. This DSL response can now be consumed by any downstream application or engine that can compute the final answer. For comparison, we show the base model's response to the same inputs.
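For the comparison, the same inputs can be sent to the base model by swapping in its model ID (a sketch, reusing `chat_history` from the previous cell):

```python
base_response = co.chat(
    model="command-r-08-2024",
    chat_history=chat_history,
    message="and what was the total in 2018?",
)
print(base_response.text)
```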
As you can see, the base model is pretty good in itself. The final answer is correct in this particular instance (189 + 68 = 257). However, this model response needs further processing to extract the final answer, and that post-processing can be a noisy process. Also, please note that LLMs' ability for complex numerical reasoning is not very reliable. For these reasons, finetuning the model to output DSL is a much more reliable and interpretable way to arrive at the final answer.