CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/01_getting_started_pytorch/scripts/train.py
Views: 2555
1
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
2
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
3
from datasets import load_from_disk
4
import random
5
import logging
6
import sys
7
import argparse
8
import os
9
import torch
10
11
logger = logging.getLogger(__name__)

# Directory arguments that must be resolvable (from a flag or an SM_* env var).
# --n_gpus is intentionally not listed: it is parsed but never used below.
REQUIRED_DIR_ARGS = ("output_data_dir", "model_dir", "training_dir", "test_dir")


def parse_args(argv=None):
    """Parse training hyperparameters and SageMaker directory arguments.

    Hyperparameters sent by the client arrive as command-line flags. The
    directory arguments default to the ``SM_*`` environment variables; the
    defaults are read when this function is called. Unrecognised flags are
    ignored via ``parse_known_args``.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed values.

    Raises:
        SystemExit: (via ``parser.error``) if ``--model_name`` is absent or a
            required directory is given neither as a flag nor through its
            ``SM_*`` environment variable.
    """
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    # Required: without it from_pretrained(None) fails later with an unclear error.
    parser.add_argument("--model_name", type=str, required=True)
    # Parsed straight to float instead of type=str plus a later float() cast.
    parser.add_argument("--learning_rate", type=float, default=5e-5)

    # Data, model, and output directories. os.environ.get (not os.environ[...])
    # so that an unset variable does not crash parsing when the value is
    # supplied on the command line instead.
    parser.add_argument("--output_data_dir", type=str, default=os.environ.get("SM_OUTPUT_DATA_DIR"))
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--n_gpus", type=str, default=os.environ.get("SM_NUM_GPUS"))
    parser.add_argument("--training_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test_dir", type=str, default=os.environ.get("SM_CHANNEL_TEST"))

    args, _ = parser.parse_known_args(argv)

    missing = [name for name in REQUIRED_DIR_ARGS if getattr(args, name) is None]
    if missing:
        parser.error(
            "missing directories (pass the flag or set the SM_* environment variable): "
            + ", ".join(f"--{name}" for name in missing)
        )
    return args


def compute_metrics(pred):
    """Compute accuracy and binary precision/recall/F1 for Trainer evaluation.

    Args:
        pred: Evaluation prediction object from the Trainer. ``label_ids``
            holds the gold labels. ``predictions`` holds per-class scores,
            reduced to predicted class ids by argmax over the last axis.

    Returns:
        dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.

    Note:
        ``average="binary"`` assumes exactly two classes; scikit-learn rejects
        multiclass targets with this setting.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}


def main(argv=None):
    """Fine-tune ``--model_name`` for sequence classification, evaluate, and save.

    Steps:
        1. Loads the train/test datasets from disk.
        2. Trains with ``Trainer``.
        3. Writes the evaluation metrics to ``<output_data_dir>/eval_results.txt``.
        4. Saves the model to ``model_dir``.

    Args:
        argv: Argument list forwarded to ``parse_args``. ``None`` means
            ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f" loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f" loaded test_dataset length is: {len(test_dataset)}")

    # download model from model hub
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.model_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; confirm against the pinned transformers version.
        evaluation_strategy="epoch",
        logging_dir=os.path.join(args.output_data_dir, "logs"),
        learning_rate=args.learning_rate,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write eval results to a file so they can be retrieved later from the
    # job's output in S3. Also log them so they appear in the job log, which
    # the old placeholder-less print() never did.
    os.makedirs(args.output_data_dir, exist_ok=True)
    eval_path = os.path.join(args.output_data_dir, "eval_results.txt")
    with open(eval_path, "w", encoding="utf-8") as writer:
        logger.info("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            logger.info(f"{key} = {value}")
            writer.write(f"{key} = {value}\n")

    # Save the model to args.model_dir. On SageMaker this is SM_MODEL_DIR,
    # whose contents are uploaded to S3 when the job finishes.
    trainer.save_model(args.model_dir)


if __name__ == "__main__":
    main()