GitHub Repository: huggingface/notebooks
Path: blob/main/examples/idefics/finetune_image_captioning_peft.ipynb
Kernel: Python 3 (ipykernel)

IDEFICS: A Flamingo-based model, trained at scale for the community

Finetuning Demo Notebook:

[Image: IDEFICS illustration. Credit: Flamingo blog]

This Google Colab notebook shows how to run predictions with the 4-bit quantized 🤗 Idefics-9B model and finetune it on a specific dataset.

IDEFICS is a multimodal model based on the Flamingo architecture. It takes images and text as input and returns text outputs, but it does not support image generation.

IDEFICS is built on top of two unimodal, open-access, pre-trained models and connects the two modalities. Newly initialized parameters in the form of Transformer blocks bridge the gap between the vision encoder and the language model. The model is trained on a mixture of image/text pairs and unstructured multimodal web documents.

The finetuned versions of IDEFICS behave like LLM chatbots while also understanding visual input. You can play with the demo here

The code for this notebook was contributed by Léo Tronchon, Younes Belkada, and Stas Bekman. The IDEFICS model was contributed by Lucile Saulnier, Léo Tronchon, Hugo Laurençon, Stas Bekman, Amanpreet Singh, Siddharth Karamcheti, and Victor Sanh.

Install and import necessary libraries

```python
!pip install -q datasets
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/peft.git
```
```python
import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    IdeficsForVisionText2Text,
    Trainer,
    TrainingArguments,
)
```

Load quantized model

First, load a 4-bit quantized version of the model. This lets us run the 9B version of IDEFICS on a single 16 GB GPU.
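As a rough sanity check on the 16 GB figure, here is the back-of-the-envelope arithmetic for the weights alone. This sketch assumes a nominal 9 billion parameters and ignores activations, optimizer state, quantization constants, and the modules kept in higher precision (such as `lm_head` and `embed_tokens`):

```python
def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate memory needed to store the weights, in GiB."""
    bytes_total = num_params * bits_per_param / 8
    return bytes_total / 1024**3


NUM_PARAMS = 9e9  # assumption: nominal parameter count of Idefics-9B

for label, bits in [("fp32", 32), ("fp16/bf16", 16), ("4-bit (nf4)", 4)]:
    print(f"{label:>12}: {weight_memory_gib(NUM_PARAMS, bits):5.1f} GiB")
```

In fp16 the weights alone (about 16.8 GiB) already exceed a 16 GB card. The 4-bit weights take roughly a quarter of that, which leaves headroom for activations and LoRA training.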

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

# checkpoint = "HuggingFaceM4/tiny-random-idefics"
checkpoint = "HuggingFaceM4/idefics-9b"

# Here we skip some special modules that can't be quantized properly
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_skip_modules=["lm_head", "embed_tokens"],
)

# `token=True` replaces the deprecated `use_auth_token=True` and uses the token saved by `huggingface-cli login`
processor = AutoProcessor.from_pretrained(checkpoint, token=True)

# Simply remove the quantization_config argument if you want to load the original (unquantized) model
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, quantization_config=bnb_config, device_map="auto")
```

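If you print the model, you will see that the nn.Linear layers have been replaced by bnb.nn.Linear4bit layers, except in the modules excluded through llm_int8_skip_modules (for example, lm_head is left unquantized).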

```python
print(model)
```
```
IdeficsForVisionText2Text(
  (model): IdeficsModel(
    (embed_tokens): IdeficsDecoupledEmbedding(
      num_embeddings=32000, num_additional_embeddings=2, embedding_dim=4096, partially_freeze=False
      (additional_embedding): Embedding(2, 4096)
    )
    (vision_model): IdeficsVisionTransformer(
      (embeddings): IdeficsVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(257, 1280)
      )
      (pre_layrnorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (encoder): IdeficsVisionEncoder(
        (layers): ModuleList(
          (0-31): 32 x IdeficsVisionEncoderLayer(
            (self_attn): IdeficsVisionAttention(
              (k_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (v_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (q_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
              (out_proj): Linear4bit(in_features=1280, out_features=1280, bias=True)
            )
            (layer_norm1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (mlp): IdeficsVisionMLP(
              (activation_fn): QuickGELUActivation()
              (fc1): Linear4bit(in_features=1280, out_features=5120, bias=True)
              (fc2): Linear4bit(in_features=5120, out_features=1280, bias=True)
            )
            (layer_norm2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          )
        )
      )
      (post_layernorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (perceiver_resampler): IdeficsPerceiverResampler(
      (blocks): ModuleList(
        (0-5): 6 x ModuleList(
          (0): IdeficsPerceiverAttention(
            (context_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (latents_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (q_layer_norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (k_layer_norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (q_proj): Linear4bit(in_features=1280, out_features=1536, bias=False)
            (k_proj): Linear4bit(in_features=1280, out_features=1536, bias=False)
            (v_proj): Linear4bit(in_features=1280, out_features=1536, bias=False)
            (output_proj): Linear4bit(in_features=1536, out_features=1280, bias=False)
          )
          (1): IdeficsMLP(
            (ln): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
            (fc): Linear4bit(in_features=1280, out_features=5120, bias=False)
            (act): ReLU()
            (c_proj): Linear4bit(in_features=5120, out_features=1280, bias=False)
          )
        )
      )
      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
    (layers): ModuleList(
      (0-31): 32 x IdeficsDecoderLayer(
        (self_attn): IdeficsAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): IdeficsEmbedding()
        )
        (mlp): IdeficsMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): IdeficsRMSNorm()
        (post_attention_layernorm): IdeficsRMSNorm()
      )
    )
    (gated_cross_attn_layers): ModuleList(
      (0-7): 8 x IdeficsGatedCrossAttentionLayer(
        (cross_attn): IdeficsAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=1280, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=1280, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): IdeficsEmbedding()
          (q_layer_norm): IdeficsRMSNorm()
          (k_layer_norm): IdeficsRMSNorm()
        )
        (mlp): IdeficsMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): IdeficsRMSNorm()
        (post_attention_layernorm): IdeficsRMSNorm()
        (act_cross_attn): Tanh()
        (act_dense): Tanh()
      )
    )
    (norm): IdeficsRMSNorm()
  )
  (lm_head): IdeficsDecoupledLinear(
    in_features=4096, out_features=32000, out_additional_features=2, bias=False, partially_freeze=False
    (additional_fc): Linear(in_features=4096, out_features=2, bias=False)
  )
)
```

Inference

Let's write a simple helper function to test the model's inference.

```python
def check_inference(model, processor, prompts, max_new_tokens=50):
    tokenizer = processor.tokenizer

    # Prevent the model from generating the special image tokens
    bad_words = ["<image>", "<fake_token_around_image>"]
    bad_words_ids = tokenizer(bad_words, add_special_tokens=False).input_ids

    # Stop generating at the end-of-sequence token
    eos_token = "</s>"
    eos_token_id = tokenizer.convert_tokens_to_ids(eos_token)

    inputs = processor(prompts, return_tensors="pt").to(device)
    generated_ids = model.generate(
        **inputs,
        eos_token_id=[eos_token_id],
        bad_words_ids=bad_words_ids,
        max_new_tokens=max_new_tokens,
        early_stopping=True,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)
```
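The `bad_words_ids` argument keeps the model from emitting the image placeholder tokens. Conceptually, at each decoding step the scores of banned tokens are set to negative infinity before the next token is picked. The pure-Python sketch below illustrates that idea for greedy decoding. It is a simplified illustration, not the Transformers implementation, and the helper names, toy logits, and token ids are made up:

```python
import math


def mask_bad_words(logits, bad_words_ids, generated):
    """Return a copy of `logits` with banned next tokens set to -inf.

    Each entry of `bad_words_ids` is a token sequence. Its last token is banned
    when the already generated tokens end with the rest of the sequence, so a
    single-token sequence is always banned.
    """
    masked = list(logits)
    for banned_seq in bad_words_ids:
        prefix, last_token = banned_seq[:-1], banned_seq[-1]
        if not prefix or generated[-len(prefix):] == prefix:
            masked[last_token] = -math.inf
    return masked


def greedy_next_token(logits, bad_words_ids, generated):
    """Pick the highest-scoring token that is not banned."""
    masked = mask_bad_words(logits, bad_words_ids, generated)
    return max(range(len(masked)), key=lambda token_id: masked[token_id])


# Toy vocabulary of 4 tokens, where token 1 plays the role of "<image>"
logits = [0.1, 2.0, 0.5, 1.5]
print(greedy_next_token(logits, [[1]], generated=[0]))  # 3
```

Without the ban, greedy decoding would pick token 1 (the highest score); with it, the next-best token wins.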

Let's run a prediction with the quantized model on the image below, which shows two kittens.

```python
url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
prompts = [
    # "Instruction: provide an answer to the question. Use the image to answer.\n",
    url,
    "Question: What's on the picture? Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=5)
```
Question: What's on the picture? Answer: Two kittens.

Now let's see how the model fares on Pokémon knowledge before we finetune it further.

```python
# Check generation before finetuning
url = "https://images.pokemontcg.io/pop6/2_hires.png"
prompts = [
    url,
    "Question: What's on the picture? Answer:",
]
check_inference(model, processor, prompts, max_new_tokens=100)
# It looks like the model is already aware of Pokémon, but it could be more specific and less repetitive
```
Question: What's on the picture? Answer: Lucario Lucario is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pokémon that is a combination of a bear and a lion. It is a Pok

Finetuning dataset

Prepare the dataset that will be used for finetuning

```python
def convert_to_rgb(image):
    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
    # for transparent images. The call to `alpha_composite` handles this case
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def ds_transforms(example_batch):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std

    image_transform = transforms.Compose([
        convert_to_rgb,
        transforms.RandomResizedCrop((image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=image_mean, std=image_std),
    ])

    prompts = []
    for i in range(len(example_batch["caption"])):
        # We split the captions to avoid very long examples, which would require more GPU RAM during training
        caption = example_batch["caption"][i].split(".")[0]
        prompts.append(
            [
                example_batch["image_url"][i],
                f"Question: What's on the picture? Answer: This is {example_batch['name'][i]}. {caption}</s>",
            ],
        )

    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)

    # Causal language modeling: the labels are the input ids (the model shifts them internally)
    inputs["labels"] = inputs["input_ids"]

    return inputs


# Load and prepare the dataset
ds = load_dataset("TheFusion21/PokemonCards")
ds = ds["train"].train_test_split(test_size=0.002)
train_ds = ds["train"]
eval_ds = ds["test"]
train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)
```
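To make the training target concrete, here is a pure-Python sketch of the text that `ds_transforms` builds for one example: only the first sentence of the caption is kept, the card name is prepended, and `</s>` marks where generation should stop. The helper `build_training_prompt` and the sample record below are hypothetical, for illustration only:

```python
def build_training_prompt(image_url: str, name: str, caption: str) -> list:
    """Mirror the prompt format used in ds_transforms for one example (hypothetical helper)."""
    # Keep only the text before the first "." to limit sequence length
    first_sentence = caption.split(".")[0]
    return [
        image_url,
        f"Question: What's on the picture? Answer: This is {name}. {first_sentence}</s>",
    ]


# Hypothetical record with the same fields the notebook reads: "image_url", "name", "caption"
example = {
    "image_url": "https://example.com/card.png",
    "name": "Pikachu",
    "caption": "A Basic Pokemon Card of type Lightning. It has 60 HP.",
}
prompt = build_training_prompt(example["image_url"], example["name"], example["caption"])
print(prompt[1])
```

Note that `split(".")` cuts at the first period anywhere in the caption, so the final period is dropped and any abbreviation containing a period would truncate the caption early.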

LoRA

After specifying the low-rank adapter (LoRA) configuration, we wrap the model into a PeftModel using the get_peft_model utility function.
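For intuition, LoRA keeps the original weight W frozen and learns a low-rank update. For an input x, the adapted layer computes W x + (lora_alpha / r) · B A x, where A has shape (r, in_features) and B has shape (out_features, r). B starts at zero, so the wrapped model initially behaves exactly like the base model. With r=16 and lora_alpha=32, as below, the scaling factor is 2. The pure-Python sketch uses tiny made-up matrices, omits LoRA dropout, and is not PEFT's implementation:

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, lora_alpha, r):
    """Frozen base projection plus the scaled low-rank update (no dropout)."""
    base = matvec(W, x)               # W x, length out_features
    update = matvec(B, matvec(A, x))  # B (A x), length out_features
    scaling = lora_alpha / r
    return [b + scaling * u for b, u in zip(base, update)]


# Toy shapes: in_features=3, out_features=2, r=1
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
A = [[1.0, 1.0, 1.0]]    # shape (r, in_features), trainable
B_init = [[0.0], [0.0]]  # shape (out_features, r), zero-initialized
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B_init, x, lora_alpha=2, r=1))     # [7.0, 2.0], same as W x
B_trained = [[0.5], [-1.0]]  # hypothetical values after some training
print(lora_forward(W, A, B_trained, x, lora_alpha=2, r=1))  # [13.0, -10.0]
```

Only A and B are trained, which is why the number of trainable parameters stays small.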

```python
model_name = checkpoint.split("/")[1]
config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, config)
```

```python
model.print_trainable_parameters()
```
trainable params: 19,750,912 || all params: 8,949,430,544 || trainable%: 0.2206946230030432
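The trainable-parameter count can be reproduced by hand. LoRA adds r × (in_features + out_features) parameters (matrices A and B, with no bias since bias="none") to every layer named q_proj, k_proj, or v_proj. The shapes below are read off the print(model) output, and the sketch assumes r=16 as in the config:

```python
R = 16  # LoRA rank from LoraConfig

# (number of layers, [(in_features, out_features) for q_proj, k_proj, v_proj])
targets = {
    "vision encoder": (32, [(1280, 1280)] * 3),
    "perceiver resampler": (6, [(1280, 1536)] * 3),
    "decoder self-attention": (32, [(4096, 4096)] * 3),
    "gated cross-attention": (8, [(4096, 4096), (1280, 4096), (1280, 4096)]),
}


def lora_params(in_features: int, out_features: int, r: int = R) -> int:
    """Parameters added by one LoRA adapter: A is (r, in), B is (out, r)."""
    return r * (in_features + out_features)


total = 0
for name, (num_layers, shapes) in targets.items():
    count = num_layers * sum(lora_params(i, o) for i, o in shapes)
    total += count
    print(f"{name:>24}: {count:>12,}")

print(f"{'total':>24}: {total:>12,}")  # 19,750,912, matching print_trainable_parameters()
```

The decoder self-attention layers account for most of the adapters (12,582,912 parameters), since they are both the widest and the most numerous targets.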

Training

Finally, using the Hugging Face Trainer, we can finetune the model!

For the sake of the demo, we have set max_steps to 40. That is about 0.05 epoch on this dataset, so feel free to train for longer!
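The step-to-epoch conversion follows from the batch settings in TrainingArguments. A minimal sketch, assuming a single GPU (with more devices, the world size multiplies the effective batch size):

```python
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
num_devices = 1  # assumption: single GPU, as in the Colab setup
max_steps = 40

# Each optimizer step sees per-device batch x accumulation steps x devices examples
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
samples_seen = effective_batch_size * max_steps
print(effective_batch_size, samples_seen)  # 16 640

# The Trainer reports epoch=0.05 (rounded), which implies a training split of roughly this many examples
reported_epoch = 0.05
print(round(samples_seen / reported_epoch))
```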

It has been reported that fine-tuning in fp16 mixed precision can lead to overflows, so we recommend bf16 mixed precision when your GPU supports it. The configuration below uses fp16=True; on bf16-capable hardware, set bf16=True instead.
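The overflow risk comes from fp16's narrow range: its largest finite value is 65504, while bf16 keeps the 8-bit exponent of fp32 and reaches about 3.4e38, at the cost of fewer mantissa bits. Python's `struct` module can pack IEEE half-precision floats with the `"e"` format, which makes the fp16 limit easy to observe. The standard library has no bf16 type, so the sketch computes the bf16 bound from its bit layout instead:

```python
import struct


def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-style binary floating-point format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_exponent = (2 ** exponent_bits - 2) - bias  # the all-ones exponent is reserved for inf/NaN
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exponent


fp16_max = max_finite(exponent_bits=5, mantissa_bits=10)
bf16_max = max_finite(exponent_bits=8, mantissa_bits=7)
print(fp16_max)  # 65504.0
print(f"{bf16_max:.3e}")

# Round-trip through real fp16 using the struct "e" format
print(struct.unpack("<e", struct.pack("<e", 65504.0))[0])  # 65504.0
try:
    struct.pack("<e", 70000.0)
except OverflowError as err:
    print("fp16 overflow:", err)
```

In PyTorch, `torch.cuda.is_bf16_supported()` reports whether the current GPU supports bf16.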

```python
training_args = TrainingArguments(
    output_dir=f"{model_name}-pokemon",
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,  # ds_transforms already moves the tensors to the GPU
    save_total_limit=3,
    eval_strategy="steps",  # named `evaluation_strategy` in older Transformers releases
    save_strategy="steps",
    save_steps=40,
    eval_steps=20,
    logging_steps=20,
    max_steps=40,
    remove_unused_columns=False,  # keep the raw columns that ds_transforms reads on the fly
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    report_to="none",  # the string "none" disables logging integrations
    optim="paged_adamw_8bit",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

trainer.train()
```
TrainOutput(global_step=40, training_loss=1.0759869813919067, metrics={'train_runtime': 403.1999, 'train_samples_per_second': 1.587, 'train_steps_per_second': 0.099, 'total_flos': 1445219210656320.0, 'train_loss': 1.0759869813919067, 'epoch': 0.05})
```python
# Check generation again after finetuning
check_inference(model, processor, prompts, max_new_tokens=100)
```
Question: What's on the picture? Answer: This is Lucario. A Stage 2 Pokemon Card of type Fighting with the title Lucario and 90 HP of rarity Rare evolved from Pikachu from the set Neo Destiny and the flavor text: It can use its tail as a whip

Push your new model to the hub!

```python
# Insert your "write" token. You can find it in the settings of your HF profile
!huggingface-cli login
```
```python
model.push_to_hub(f"{model_name}-pokemon", private=False)
```
CommitInfo(commit_url='https://huggingface.co/Leyo/idefics-9b-pokemon/commit/6e08354af8529c0c286d16a42a674ca28ce7f3ed', commit_message='Upload model', commit_description='', oid='6e08354af8529c0c286d16a42a674ca28ce7f3ed', pr_url=None, pr_revision=None, pr_num=None)