Monday, September 2, 2024

How To Fine-Tune AI Model with Online Direct Preference Optimization Locally on Own Dataset

This video is a step-by-step tutorial on using Online DPO to fine-tune a model locally on a custom dataset. Online DPO is a new alignment method from DeepMind that boosts the performance of LLMs.



Code:

conda create -n dpo python=3.11 -y && conda activate dpo

pip install torch
pip install datasets dataclasses
pip install git+https://github.com/huggingface/transformers
pip install git+https://github.com/huggingface/accelerate

git clone https://github.com/huggingface/trl.git && cd trl

git checkout d57e4b726561e5ae58fdc335f34029052944a4a3

pip install -e .

conda install jupyter -y
pip uninstall charset_normalizer -y
pip install charset_normalizer
jupyter notebook
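
Optional sanity check (not part of the original steps): before running the notebook cells, you can confirm that the PyTorch build actually sees your GPU.

import torch

print(torch.__version__)            # installed PyTorch version
print(torch.cuda.is_available())    # True if a CUDA GPU is visible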

from datasets import Dataset
from trl import OnlineDPOConfig, OnlineDPOTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
# Add a padding token so prompts and completions can be batched
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
# The reference model to calculate the KL divergence against
ref_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
# The model that scores completions. In practice, use a trained reward model;
# here one is initialised from the same base checkpoint for demonstration.
reward_model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", num_labels=1)

# Dummy prompt-only datasets; swap these for your own prompts
train_dataset = Dataset.from_dict(
    {"prompt": ["Q: Hi how are you? A:"] * NUM_DUMMY_SAMPLES})
eval_dataset = Dataset.from_dict(
    {"prompt": ["Q: What do you like to eat? A:"] * NUM_DUMMY_SAMPLES})

args = OnlineDPOConfig(output_dir="online-dpo-model")
trainer = OnlineDPOTrainer(
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()
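
The snippet above trains on dummy prompts. To use your own dataset, all the trainer needs is a column called "prompt". Below is a minimal sketch, assuming your prompts sit in a local JSONL file named my_prompts.jsonl (a placeholder name) with one {"prompt": "..."} object per line; everything else stays the same as above.

from datasets import load_dataset

# my_prompts.jsonl is a placeholder; each line holds {"prompt": "..."}
dataset = load_dataset("json", data_files="my_prompts.jsonl", split="train")
splits = dataset.train_test_split(test_size=0.1)

trainer = OnlineDPOTrainer(
    model=model,
    ref_model=ref_model,
    reward_model=reward_model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],
)
trainer.train()
trainer.save_model("online-dpo-model")  # save the fine-tuned weights

After training, the checkpoint in online-dpo-model can be loaded back with AutoModelForCausalLM.from_pretrained for inference.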
