CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutSign UpSign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/09_image_classification_vision_transformer/scripts/train.py
Views: 2555
1
from transformers import ViTForImageClassification, Trainer, TrainingArguments,default_data_collator,ViTFeatureExtractor
2
from datasets import load_from_disk,load_metric
3
import random
4
import logging
5
import sys
6
import argparse
7
import os
8
import numpy as np
9
import subprocess
10
11
# Configure a git identity inside the SageMaker training container so that
# Trainer.push_to_hub can create commits later in the run.
for _setting, _value in (
    ("user.email", "[email protected]"),
    ("user.name", "sagemaker"),
):
    subprocess.run(
        ["git", "config", "--global", _setting, _value],
        check=True,
    )
25
26
27
if __name__ == "__main__":
    # Entry point for a SageMaker training job: parse hyperparameters,
    # fine-tune a ViT image classifier, evaluate, save, and optionally
    # push the model to the Hugging Face Hub.

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the SageMaker client are passed as
    # command-line arguments to the script.
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--output_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--extra_model_name", type=str, default="sagemaker")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--task", type=str, default="image-classification")
    parser.add_argument("--use_auth_token", type=str, default="")

    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--per_device_train_batch_size", type=int, default=32)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    # FIX: was `type=str` with a float default. argparse only applies `type`
    # to values supplied on the command line, so args.learning_rate could be
    # either str or float depending on how the job was launched. Parsing as
    # float makes the attribute type consistent.
    parser.add_argument("--learning_rate", type=float, default=2e-5)

    # Input data channels mounted by SageMaker.
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # Set up logging to stdout so the output is captured by CloudWatch.
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load datasets previously written with Dataset.save_to_disk().
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)
    num_classes = train_dataset.features["label"].num_classes

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    metric_name = "accuracy"
    # Compute-metrics function for classification accuracy.
    metric = load_metric(metric_name)

    def compute_metrics(eval_pred):
        """Return accuracy from a Trainer (logits, labels) eval batch."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return metric.compute(predictions=predictions, references=labels)

    # Download model from the model hub, with the classification head
    # resized to the dataset's number of classes.
    model = ViTForImageClassification.from_pretrained(args.model_name, num_labels=num_classes)

    # Replace the generic LABEL_i names in the config with the dataset's
    # class names, preserving the original ids/positions.
    class_names = train_dataset.features["label"].names
    id2label = {key: class_names[index] for index, key in enumerate(model.config.id2label.keys())}
    label2id = {class_names[index]: value for index, value in enumerate(model.config.label2id.values())}
    model.config.id2label = id2label
    model.config.label2id = label2id

    # Define training args.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        warmup_steps=args.warmup_steps,
        weight_decay=args.weight_decay,
        evaluation_strategy="epoch",
        logging_dir=f"{args.output_dir}/logs",
        # float() kept for backward compatibility with callers that still
        # pass the learning rate as a string hyperparameter.
        learning_rate=float(args.learning_rate),
        load_best_model_at_end=True,
        metric_for_best_model=metric_name,
    )

    # Create Trainer instance.
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=default_data_collator,
    )

    # Train model.
    trainer.train()

    # Evaluate model.
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write eval results to a file which can be accessed later in the S3 output.
    with open(os.path.join(args.output_dir, "eval_results.txt"), "w") as writer:
        # FIX: was an f-string with no placeholders.
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # Saves the model to the output dir, which SageMaker uploads to S3.
    trainer.save_model(args.output_dir)

    # Push to the Hugging Face Hub when an auth token was supplied.
    if args.use_auth_token != "":
        # FIX: split("/")[1] raised IndexError for model ids without a
        # namespace (e.g. "my-vit"); [-1] handles both "org/name" and "name".
        base_model_name = args.model_name.split("/")[-1]
        kwargs = {
            "finetuned_from": base_model_name,
            "tags": "image-classification",
            "dataset": args.dataset,
        }
        repo_name = (
            f"{base_model_name}-{args.task}"
            if args.extra_model_name == ""
            else f"{base_model_name}-{args.task}-{args.extra_model_name}"
        )

        trainer.push_to_hub(
            repo_name=repo_name,
            use_auth_token=args.use_auth_token,
            **kwargs,
        )
147
148