Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/sagemaker/26_document_ai_donut/scripts/train.py
Views: 2555
import argparse
import logging
import os
import shutil
import sys

import torch
from datasets import load_from_disk

# NOTE(review): AutoModelForCausalLM / AutoTokenizer / default_data_collator and
# json were imported but never used (leftovers from a causal-LM script) — removed.
from transformers import (
    DonutProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
    set_seed,
)


def _str2bool(value):
    """Parse a command-line boolean flag.

    argparse's ``type=bool`` treats ANY non-empty string — including the
    literal string ``"False"`` — as True, so boolean flags could never be
    turned off from the command line. This helper parses the usual
    spellings correctly.

    Args:
        value: The raw flag value (``str`` from argparse, or ``bool``).

    Returns:
        bool: The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def parse_arge():
    """Parse the command-line arguments.

    Returns:
        tuple: ``(args, unknown)`` as produced by ``parser.parse_known_args()``,
        where ``args`` is an ``argparse.Namespace`` with the fields below.
    """
    parser = argparse.ArgumentParser()
    # Model id and dataset path
    parser.add_argument(
        "--model_id",
        type=str,
        default="naver-clova-ix/donut-base",
        help="Model id to use for training.",
    )
    parser.add_argument(
        "--special_tokens",
        type=str,
        default=None,
        help="Comma-separated list of special tokens to add to the tokenizer.",
    )
    parser.add_argument("--dataset_path", type=str, default="lm_dataset", help="Path to dataset.")
    # Training hyperparameters: epochs, batch size, learning rate, and seed
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for.")
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument(
        "--gradient_checkpointing",
        type=_str2bool,  # type=bool would parse the string "False" as True
        default=False,
        help="Whether to enable gradient checkpointing.",  # was: wrong copy-pasted help text
    )
    parser.add_argument(
        "--bf16",
        type=_str2bool,
        # bf16 needs an Ampere GPU (compute capability 8.x). Guard with
        # is_available() so parsing does not crash on CPU-only machines,
        # where get_device_capability() raises.
        default=bool(torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 8),
        help="Whether to use bf16.",
    )
    args = parser.parse_known_args()
    return args


def training_function(args):
    """Fine-tune a Donut ``VisionEncoderDecoderModel`` and save the artifacts.

    Loads the pre-processed dataset from ``args.dataset_path``, adapts the
    processor/model to the dataset's image size and label length, trains with
    ``Seq2SeqTrainer``, then writes model, processor, and the inference script
    to the SageMaker model directory ``/opt/ml/model/``.

    Args:
        args: Namespace returned by :func:`parse_arge`.
    """
    # Make the run reproducible
    set_seed(args.seed)

    # Set up logging to stdout (captured by SageMaker/CloudWatch)
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load dataset; infer the image size from the first sample's pixel values
    train_dataset = load_from_disk(args.dataset_path)
    image_size = list(torch.tensor(train_dataset[0]["pixel_values"][0]).shape)  # height, width
    logger.info(f"loaded train_dataset length is: {len(train_dataset)}")

    # Load processor and register task-specific special tokens
    processor = DonutProcessor.from_pretrained(args.model_id)
    # Guard: --special_tokens defaults to None, and None.split(",") would
    # raise AttributeError. Only add tokens when the flag was supplied.
    if args.special_tokens:
        special_tokens = args.special_tokens.split(",")
        processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
    processor.feature_extractor.size = image_size[::-1]  # feature extractor expects (width, height)
    processor.feature_extractor.do_align_long_axis = False

    # Load model; the generation cache must be off when gradient checkpointing is on
    config = VisionEncoderDecoderConfig.from_pretrained(
        args.model_id, use_cache=not args.gradient_checkpointing
    )
    model = VisionEncoderDecoderModel.from_pretrained(args.model_id, config=config)

    # Resize embedding layer to the (possibly extended) vocabulary and adjust
    # image size / maximum output sequence length to the dataset
    model.decoder.resize_token_embeddings(len(processor.tokenizer))
    model.config.encoder.image_size = image_size
    model.config.decoder.max_length = len(max(train_dataset["labels"], key=len))
    # Token the decoder starts generation from
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(["<s>"])[0]

    # Arguments for training. bf16/tf32 now honor the --bf16 flag instead of
    # being hard-coded to True (both are Ampere-only features, so hard-coding
    # them crashed on older GPUs and silently ignored --bf16 False).
    output_dir = "/tmp"
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        per_device_train_batch_size=args.per_device_train_batch_size,
        bf16=args.bf16,
        tf32=args.bf16,
        gradient_checkpointing=args.gradient_checkpointing,
        logging_steps=10,
        save_total_limit=1,
        evaluation_strategy="no",
        save_strategy="epoch",
    )

    # Create Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )

    # Start training
    trainer.train()

    # Save model and processor to the SageMaker model directory
    trainer.model.save_pretrained("/opt/ml/model/")
    processor.save_pretrained("/opt/ml/model/")

    # Copy the inference script so SageMaker bundles it with the model artifact
    os.makedirs("/opt/ml/model/code", exist_ok=True)
    shutil.copyfile(
        os.path.join(os.path.dirname(__file__), "inference.py"),
        "/opt/ml/model/code/inference.py",
    )


def main():
    """Entry point: parse arguments and run training."""
    args, _ = parse_arge()
    training_function(args)


if __name__ == "__main__":
    main()