CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Sign Up | Sign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/06_sagemaker_metrics/scripts/train.py
Views: 2554
1
import argparse
2
import logging
3
import os
4
import random
5
import sys
6
7
from datasets import load_from_disk
8
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
9
import torch
10
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
11
12
13
logger = logging.getLogger(__name__)


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the training script's command-line arguments.

    Hyperparameters sent by the SageMaker client arrive as ``--key value``
    flags; container paths default to the ``SM_*`` environment variables.
    Unrecognised flags are ignored (``parse_known_args``).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed argument namespace.

    Exits (via ``parser.error``) if ``--model_name`` is missing or a required
    path was neither passed nor available from the environment.
    """
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    # Fail fast: without a model id, from_pretrained(None) would fail much later with an opaque error.
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--learning_rate", type=float, default=5e-5)

    # Data, model, and output directories.
    # Defaults use os.environ.get(): indexing os.environ[...] raised KeyError while
    # building the parser whenever the SM_* variables were absent, even if the
    # corresponding flag was passed explicitly.
    parser.add_argument("--checkpoints", type=str, default="/opt/ml/checkpoints/")
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    # argparse runs `type` on string defaults, so an SM_NUM_GPUS value like "4" becomes 4.
    parser.add_argument("--n_gpus", type=int, default=os.environ.get("SM_NUM_GPUS", 0))
    parser.add_argument("--training_dir", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test_dir", type=str, default=os.environ.get("SM_CHANNEL_TEST"))

    args, _ = parser.parse_known_args(argv)

    missing = [name for name in ("model_dir", "training_dir", "test_dir") if getattr(args, name) is None]
    if missing:
        parser.error(
            "missing required path(s): "
            + ", ".join("--" + name for name in missing)
            + " (pass the flag or set the matching SM_* environment variable)"
        )
    return args


def compute_metrics(pred) -> dict:
    """Compute accuracy, F1, precision and recall for binary classification.

    Used as the ``Trainer``'s ``compute_metrics`` callback.

    Args:
        pred: Object exposing ``label_ids`` and ``predictions``; the predicted
            class is the argmax over the last axis of ``predictions``
            (presumably per-class scores — confirm against the model head).

    Returns:
        dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # average="binary" scores only the positive class and is only valid for two-class labels.
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}


def main(argv=None) -> None:
    """Fine-tune ``--model_name`` for sequence classification, evaluate, and save it.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(" loaded train_dataset length is: %s", len(train_dataset))
    logger.info(" loaded test_dataset length is: %s", len(test_dataset))

    # download model from model hub
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.checkpoints,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        # NOTE(review): newer transformers releases rename this to `eval_strategy`;
        # keep it while the pinned version still accepts this spelling — confirm.
        evaluation_strategy="epoch",
        # os.path.join avoids the "//logs" produced by the trailing slash in the default path.
        logging_dir=os.path.join(args.checkpoints, "logs"),
        learning_rate=args.learning_rate,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # Write the eval results next to the checkpoints. NOTE(review): this directory
    # presumably reaches the S3 output only when the estimator configures
    # checkpoint syncing (checkpoint_s3_uri) — confirm.
    # Create the directory in case no checkpoint has been written to it yet.
    os.makedirs(args.checkpoints, exist_ok=True)
    logger.info("***** Eval results *****")
    with open(os.path.join(args.checkpoints, "eval_results.txt"), "w", encoding="utf-8") as writer:
        for key, value in sorted(eval_result.items()):
            logger.info("  %s = %s", key, value)
            writer.write(f"{key} = {value}\n")

    # Saves the model locally. In SageMaker, writing in /opt/ml/model sends it to S3
    trainer.save_model(args.model_dir)


if __name__ == "__main__":
    main()