GitHub Repository: huggingface/notebooks
Path: blob/main/course/en/chapter13/grpo_finetune.ipynb
Kernel: Python 3

Finetune LLMs with GRPO

This notebook shows how to finetune an LLM with GRPO, using the trl library.

It's by Ben Burtenshaw and Maxime Labonne.

This is a minimal example. For a complete example, refer to the GRPO chapter in the course.

Install dependencies

!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
!pip install -qqq flash-attn --no-build-isolation --progress-bar off

Load Dataset

import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# Log to Weights & Biases
wandb.login()

# Load dataset
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)
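GRPOTrainer reads prompts from a "prompt" column in the dataset. As a quick sanity check, you can look at one training example (a minimal sketch; the exact fields depend on the mlabonne/smoltldr dataset):

# Inspect one example (assumes a "prompt" column, which GRPOTrainer expects)
example = dataset["train"][0]
print(example.keys())
print(example["prompt"])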

Load Model

# Load model
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)

# print_trainable_parameters() prints its summary directly
model.print_trainable_parameters()
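The attn_implementation="flash_attention_2" setting requires a compatible GPU and the flash-attn package installed above. If FlashAttention-2 is not available in your environment, a minimal fallback (not part of the original recipe) is to load the model with the default attention implementation:

# Fallback if flash-attn is unavailable: omit attn_implementation
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     torch_dtype="auto",
#     device_map="auto",
# )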

Define Reward Function

# Reward function
def reward_len(completions, **kwargs):
    return [-abs(50 - len(completion)) for completion in completions]
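The reward peaks at zero when a completion's length is exactly 50 and falls off linearly on either side, so training nudges the model toward short answers. A quick sketch with made-up string completions:

# Hypothetical completions of length 10, 50, and 90 score -40, 0, and -40
demo_completions = ["x" * 10, "x" * 50, "x" * 90]
print(reward_len(demo_completions))  # [-40, 0, -40]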

Define Training Arguments

# Training arguments
training_args = GRPOConfig(
    output_dir="GRPO",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    max_prompt_length=512,
    max_completion_length=96,
    num_generations=8,
    optim="adamw_8bit",
    num_train_epochs=1,
    bf16=True,
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=1,
)

# Trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[reward_len],
    args=training_args,
    train_dataset=dataset["train"],
)

# Train model
wandb.init(project="GRPO")
trainer.train()
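GRPO samples num_generations = 8 completions per prompt and scores them against each other, so the batch must split evenly into these groups. A rough sanity check (a sketch that assumes single-GPU training; trl performs its own validation):

# The generation batch must split evenly into GRPO groups (assumes 1 GPU)
num_processes = 1
global_batch = training_args.per_device_train_batch_size * num_processes
assert global_batch % training_args.num_generations == 0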

Push Model to Hub

# Merge the LoRA adapter into the base model and push the result to the Hub
merged_model = trainer.model.merge_and_unload()
merged_model.push_to_hub("<your-model-id>", private=False)
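Pushing requires write access to the Hugging Face Hub, and "<your-model-id>" is a placeholder for a repository under your own namespace. If the push fails with an authentication error, log in first (a standard huggingface_hub helper) and re-run the cell above:

# Authenticate with the Hugging Face Hub (skip if already logged in)
from huggingface_hub import notebook_login
notebook_login()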

Generate Text

prompt = """ # A long document about the Cat The cat (Felis catus), also referred to as the domestic cat or house cat, is a small domesticated carnivorous mammal. It is the only domesticated species of the family Felidae. Advances in archaeology and genetics have shown that the domestication of the cat occurred in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges freely as a feral cat avoiding human contact. It is valued by humans for companionship and its ability to kill vermin. Its retractable claws are adapted to killing small prey species such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth, and its night vision and sense of smell are well developed. It is a social species, but a solitary hunter and a crepuscular predator. Cat communication includes vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as well as body language. It can hear sounds too faint or too high in frequency for human ears, such as those made by small mammals. It secretes and perceives pheromones. """ messages = [ {"role": "user", "content": prompt}, ]
# Generate text
from transformers import pipeline

generator = pipeline("text-generation", model="<your-model-id>")

## Or use the model and tokenizer we defined earlier
# generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}

# Pass the sampling parameters as keyword arguments to the pipeline call
generated_text = generator(messages, **generate_kwargs)

print(generated_text)
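Because the reward pushed the model toward short completions, it is worth checking how long the reply actually is. A minimal sketch, assuming the pipeline returns the chat history under the "generated_text" key with the assistant reply as the last message:

# Extract the assistant reply and measure its length (output structure assumed as noted above)
reply = generated_text[0]["generated_text"][-1]["content"]
print(len(reply))
print(reply)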