"""
On one node, launch with `deepspeed --num_gpus N idefics_zero3_finetuning.py`,
replacing N with the number of GPUs on the node.

For several nodes, using Slurm, a template script is provided at
`examples/idefics/idefics_zero3_finetuning/slurm_script_idefics_zero3_finetuning_multinode.slurm`

For more information, follow the tutorial on using DeepSpeed with Transformers at
https://huggingface.co/docs/transformers/main_classes/deepspeed
"""

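# Example single-node launch (assuming a node with 8 GPUs):
#   deepspeed --num_gpus 8 idefics_zero3_finetuning.py
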
import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, IdeficsForVisionText2Text, Trainer, TrainingArguments


device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint = "HuggingFaceM4/idefics-9b"

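# The processor bundles the IDEFICS image processor and the tokenizer; it turns
# interleaved (image, text) prompts into model-ready tensors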
processor = AutoProcessor.from_pretrained(checkpoint, use_auth_token=True)


def convert_to_rgb(image):
    # A bare `image.convert("RGB")` only works for images without transparency (e.g. .jpg);
    # for transparent images it produces a wrong background. Compositing onto a white
    # background via `alpha_composite` handles that case
    if image.mode == "RGB":
        return image
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


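# Batch transform applied lazily via `Dataset.set_transform` below: builds the image
# augmentation pipeline, assembles (image URL, prompt) pairs, and tokenizes them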
def ds_transforms(example_batch):
    image_size = processor.image_processor.image_size
    image_mean = processor.image_processor.image_mean
    image_std = processor.image_processor.image_std
    image_transform = transforms.Compose(
        [
            convert_to_rgb,
            transforms.RandomResizedCrop(
                (image_size, image_size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=image_mean, std=image_std),
        ]
    )
    prompts = []
    for i in range(len(example_batch["caption"])):
        # Keep only the first sentence of the caption to avoid very long examples,
        # which would require more GPU RAM during training
        caption = example_batch["caption"][i].split(".")[0]
        try:
            # A handful of images are not hosted anymore. Fetch each image up front
            # and skip the example if the fetch fails
            processor.image_processor.fetch_images(example_batch["image_url"][i])
        except Exception:
            print(
                "Warning: at least one image couldn't be retrieved from the internet in an example. Skipping the"
                " example."
            )
            continue
        prompts.append(
            [
                example_batch["image_url"][i],
                f"Question: What's on the picture? Answer: This is {example_batch['name'][i]}. {caption}</s>",
            ],
        )
    inputs = processor(prompts, transform=image_transform, return_tensors="pt").to(device)
    inputs["labels"] = inputs["input_ids"]
    return inputs


# Load and prepare the dataset
ds = load_dataset("TheFusion21/PokemonCards")
ds = ds["train"].train_test_split(test_size=0.002)
train_ds = ds["train"]
eval_ds = ds["test"]
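# `set_transform` applies `ds_transforms` on the fly whenever examples are accessed,
# so images are fetched and augmented lazily instead of preprocessed up front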
train_ds.set_transform(ds_transforms)
eval_ds.set_transform(ds_transforms)


# Important: define training_args before the model. With ZeRO stage 3, the DeepSpeed
# config must exist before `from_pretrained` so the weights can be partitioned at load time
ds_config = {
    "communication_data_type": "fp32",
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": False,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": True,
        "stage3_gather_16bit_weights_on_model_save": False,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 2e9,
        "stage3_max_reuse_distance": 2e9,
        "offload_optimizer": {"device": "none"},
        "offload_param": {"device": "none"},
    },
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "steps_per_print": 2000000,
}
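# The "auto" values above are resolved by the Transformers DeepSpeed integration from
# the matching TrainingArguments (batch sizes, gradient clipping, bucket sizes)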
training_args = TrainingArguments(
    output_dir="idefics-pokemon",
    learning_rate=2e-4,
    bf16=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    # gradient_checkpointing=True,  # Uncomment if OOM
    dataloader_pin_memory=False,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=40,
    eval_steps=20,
    logging_steps=20,
    max_steps=40,
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    report_to="none",
    optim="adamw_torch",
    deepspeed=ds_config,
)

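# Because training_args (and thus the ZeRO-3 config) already exists, `from_pretrained`
# can load the bf16 weights directly into their partitioned form across GPUs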
model = IdeficsForVisionText2Text.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
)

result = trainer.train()
print(result)  # Printed once per process - mostly here as a sanity check
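
# Note: with stage3_gather_16bit_weights_on_model_save=False, checkpoints hold sharded
# ZeRO-3 weights; DeepSpeed's zero_to_fp32.py script (written next to each checkpoint)
# can consolidate them into a single fp32 state dict after training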