Path: blob/main/sagemaker/02_getting_started_tensorflow/scripts/train.py
import argparse
import logging
import os
import sys

import tensorflow as tf
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification, DataCollatorWithPadding, create_optimizer


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=8)
    parser.add_argument("--model_id", type=str)
    parser.add_argument("--learning_rate", type=float, default=3e-5)

    # Data, model, and output directories provided by SageMaker via environment variables
    parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])

    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # Load DatasetDict
    dataset = load_dataset("imdb")

    # Preprocess train dataset
    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    encoded_dataset = dataset.map(preprocess_function, batched=True)

    # define tokenizer_columns
    # tokenizer_columns is the list of keys from the dataset that get passed to the TensorFlow model
    tokenizer_columns = ["attention_mask", "input_ids"]

    # convert to TF datasets; rename "label" to "labels" to match the model's expected input name
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
    encoded_dataset["train"] = encoded_dataset["train"].rename_column("label", "labels")
    tf_train_dataset = encoded_dataset["train"].to_tf_dataset(
        columns=tokenizer_columns,
        label_cols=["labels"],
        shuffle=True,
        batch_size=args.train_batch_size,
        collate_fn=data_collator,
    )
    encoded_dataset["test"] = encoded_dataset["test"].rename_column("label", "labels")
    tf_validation_dataset = encoded_dataset["test"].to_tf_dataset(
        columns=tokenizer_columns,
        label_cols=["labels"],
        shuffle=False,
        batch_size=args.eval_batch_size,
        collate_fn=data_collator,
    )

    # Prepare model labels - useful in the inference API
    labels = encoded_dataset["train"].features["labels"].names
    num_labels = len(labels)
    label2id, id2label = dict(), dict()
    for i, label in enumerate(labels):
        label2id[label] = str(i)
        id2label[str(i)] = label

    # Download pretrained model from the Hugging Face Hub
    model = TFAutoModelForSequenceClassification.from_pretrained(
        args.model_id, num_labels=num_labels, label2id=label2id, id2label=id2label
    )

    # Create Adam optimizer with learning rate scheduling
    batches_per_epoch = len(encoded_dataset["train"]) // args.train_batch_size
    total_train_steps = int(batches_per_epoch * args.epochs)

    optimizer, _ = create_optimizer(init_lr=args.learning_rate, num_warmup_steps=0, num_train_steps=total_train_steps)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Define metric and compile model
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # Training
    logger.info("*** Train ***")
    train_results = model.fit(
        tf_train_dataset,
        epochs=args.epochs,
        validation_data=tf_validation_dataset,
    )

    output_eval_file = os.path.join(args.output_data_dir, "train_results.txt")

    with open(output_eval_file, "w") as writer:
        logger.info("***** Train results *****")
        logger.info(train_results)
        for key, value in train_results.history.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))

    # Save model and tokenizer to SM_MODEL_DIR so SageMaker uploads them to S3 after the job
    model.save_pretrained(args.model_dir)
    tokenizer.save_pretrained(args.model_dir)
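
This script is designed to run inside a SageMaker training job: the hyperparameters you pass to the estimator arrive as command-line arguments (hence the argparse block), and the SM_* environment variables point at the job's output and model directories. A minimal launch sketch follows, assuming the script lives under ./scripts; the instance type and container versions are illustrative assumptions, so pick a Hugging Face Deep Learning Container combination supported in your region.

# Sketch: launching scripts/train.py as a SageMaker training job.
# Instance type and DLC versions below are assumptions, not prescribed values.
import sagemaker
from sagemaker.huggingface import HuggingFace

role = sagemaker.get_execution_role()  # IAM role with SageMaker permissions

# Each key/value pair is passed to train.py as "--key value" on the command line.
hyperparameters = {
    "epochs": 1,
    "train_batch_size": 16,
    "eval_batch_size": 8,
    "learning_rate": 3e-5,
    "model_id": "distilbert-base-uncased",
}

huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts",
    instance_type="ml.p3.2xlarge",   # assumed GPU instance
    instance_count=1,
    role=role,
    transformers_version="4.26",     # assumed DLC versions
    tensorflow_version="2.11",
    py_version="py39",
    hyperparameters=hyperparameters,
)

huggingface_estimator.fit()

For a quick smoke test outside SageMaker, you can export SM_OUTPUT_DATA_DIR, SM_MODEL_DIR, and SM_NUM_GPUS to local paths yourself and run the script directly, e.g. python train.py --epochs 1 --model_id distilbert-base-uncased.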