Path: blob/main/sagemaker/06_sagemaker_metrics/scripts/train.py
import argparse
import logging
import os
import random
import sys

from datasets import load_from_disk
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=64)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--learning_rate", type=float, default=5e-5)

    # data, model, and output directories provided by SageMaker via environment variables
    parser.add_argument("--checkpoints", type=str, default="/opt/ml/checkpoints/")
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
    parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])

    args, _ = parser.parse_known_args()

    # set up logging to stdout so SageMaker captures the log stream
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # load datasets from the SageMaker input channels
    train_dataset = load_from_disk(args.training_dir)
    test_dataset = load_from_disk(args.test_dir)

    logger.info(f"loaded train_dataset length is: {len(train_dataset)}")
    logger.info(f"loaded test_dataset length is: {len(test_dataset)}")

    # compute-metrics function for binary classification
    def compute_metrics(pred):
        labels = pred.label_ids
        preds = pred.predictions.argmax(-1)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")
        acc = accuracy_score(labels, preds)
        return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

    # download model and tokenizer from the Hugging Face model hub
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # define training args
    training_args = TrainingArguments(
        output_dir=args.checkpoints,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        warmup_steps=args.warmup_steps,
        evaluation_strategy="epoch",
        logging_dir=f"{args.checkpoints}/logs",
        learning_rate=args.learning_rate,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
    )

    # train model
    trainer.train()

    # evaluate model
    eval_result = trainer.evaluate(eval_dataset=test_dataset)

    # write eval results to a file that can be accessed later in the S3 output
    with open(os.path.join(args.checkpoints, "eval_results.txt"), "w") as writer:
        print("***** Eval results *****")
        for key, value in sorted(eval_result.items()):
            writer.write(f"{key} = {value}\n")

    # save the model locally; in SageMaker, writing to /opt/ml/model uploads it to S3
    trainer.save_model(args.model_dir)
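Since this script reads its hyperparameters from the command line and its data from the SM_CHANNEL_* directories, it is meant to be launched through a SageMaker estimator rather than run directly. Below is a minimal launcher sketch using the sagemaker SDK's HuggingFace estimator; the role ARN, S3 URIs, instance type, framework versions, and metric regexes are illustrative placeholders, not values taken from this repository.

# Minimal launcher sketch. Assumptions: role ARN, S3 paths, instance type,
# and framework versions below are placeholders; adjust them to your account
# and installed sagemaker SDK version.
from sagemaker.huggingface import HuggingFace

# Regexes that SageMaker applies to the training job's log stream to extract
# the metric values the Trainer logs (hence the stdout logging in train.py).
metric_definitions = [
    {"Name": "eval_loss", "Regex": r"'eval_loss': ([0-9]+\.[0-9]+)"},
    {"Name": "eval_accuracy", "Regex": r"'eval_accuracy': ([0-9]+\.[0-9]+)"},
    {"Name": "eval_f1", "Regex": r"'eval_f1': ([0-9]+\.[0-9]+)"},
]

huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    instance_type="ml.p3.2xlarge",  # placeholder instance type
    instance_count=1,
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    transformers_version="4.6",  # placeholder framework versions
    pytorch_version="1.7",
    py_version="py36",
    hyperparameters={
        "epochs": 3,
        "train_batch_size": 32,
        "model_name": "distilbert-base-uncased",  # placeholder model
    },
    metric_definitions=metric_definitions,
)

# The channel names "train" and "test" become SM_CHANNEL_TRAIN and
# SM_CHANNEL_TEST inside the container, matching the defaults in train.py.
huggingface_estimator.fit(
    {
        "train": "s3://my-bucket/datasets/train",  # placeholder S3 URIs
        "test": "s3://my-bucket/datasets/test",
    }
)

With metric_definitions in place, the eval_* values the Trainer logs at the end of each epoch are surfaced as CloudWatch training metrics for the job, which is what makes evaluation_strategy="epoch" and logging to sys.stdout in the script above work together.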