Path: blob/main/sagemaker/32_train_deploy_embedding_models/scripts/run_mnr.py
from dataclasses import dataclass, field
import os
from sentence_transformers import (
    SentenceTransformerModelCardData,
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from transformers import set_seed, HfArgumentParser

from sentence_transformers.evaluation import (
    InformationRetrievalEvaluator,
    SequentialEvaluator,
)
from sentence_transformers.util import cos_sim
from datasets import load_dataset, concatenate_datasets


@dataclass
class ScriptArguments:
    train_dataset_path: str = field(
        default="/opt/ml/input/data/train/",
        metadata={"help": "Path to the dataset, e.g. /opt/ml/input/data/train/"},
    )
    test_dataset_path: str = field(
        default="/opt/ml/input/data/test/",
        metadata={"help": "Path to the dataset, e.g. /opt/ml/input/data/test/"},
    )
    model_id: str = field(
        default=None, metadata={"help": "Model ID to use for Embedding training"}
    )
    num_train_epochs: int = field(
        default=1, metadata={"help": "Number of training epochs"}
    )
    per_device_train_batch_size: int = field(
        default=32, metadata={"help": "Training batch size"}
    )
    per_device_eval_batch_size: int = field(
        default=16, metadata={"help": "Evaluation batch size"}
    )
    gradient_accumulation_steps: int = field(
        default=16, metadata={"help": "Gradient accumulation steps"}
    )
    learning_rate: float = field(
        default=2e-5, metadata={"help": "Learning rate for the optimizer"}
    )


def create_evaluator(
    train_dataset, test_dataset, matryoshka_dimensions=[768, 512, 256, 128, 64]
):
    corpus_dataset = concatenate_datasets([train_dataset, test_dataset])

    # Convert the datasets to dictionaries
    corpus = dict(
        zip(corpus_dataset["id"], corpus_dataset["positive"])
    )  # Our corpus (cid => document)
    queries = dict(
        zip(test_dataset["id"], test_dataset["anchor"])
    )  # Our queries (qid => question)

    # Create a mapping of relevant documents (1 in our case) for each query
    relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids]))
    for q_id in queries:
        relevant_docs[q_id] = [q_id]

    matryoshka_evaluators = []
    # Iterate over the different dimensions
    for dim in matryoshka_dimensions:
        ir_evaluator = InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name=f"dim_{dim}",
            truncate_dim=dim,  # Truncate the embeddings to a certain dimension
            score_functions={"cosine": cos_sim},
        )
        matryoshka_evaluators.append(ir_evaluator)

    # Create a sequential evaluator
    return SequentialEvaluator(matryoshka_evaluators)


def training_function(script_args):
    ################
    # Dataset
    ################

    train_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.train_dataset_path, "dataset.json"),
        split="train",
    )
    test_dataset = load_dataset(
        "json",
        data_files=os.path.join(script_args.test_dataset_path, "dataset.json"),
        split="train",
    )

    ###################
    # Model & Evaluator
    ###################

    matryoshka_dimensions = [768, 512, 256, 128, 64]  # Important: large to small

    model = SentenceTransformer(
        script_args.model_id,
        device="cuda",
        model_kwargs={"attn_implementation": "sdpa"},  # needs Ampere GPU or newer
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="BGE base Financial Matryoshka",
        ),
    )
    evaluator = create_evaluator(
        train_dataset, test_dataset, matryoshka_dimensions=matryoshka_dimensions
    )

    ###################
    # Loss Function
    ###################

    # create Matryoshka loss function with MultipleNegativesRankingLoss
    inner_train_loss = MultipleNegativesRankingLoss(model)
    train_loss = MatryoshkaLoss(
        model, inner_train_loss, matryoshka_dims=matryoshka_dimensions
    )

    ################
    # Training
    ################
    training_args = SentenceTransformerTrainingArguments(
        output_dir="/opt/ml/model",  # output directory for sagemaker to upload to s3
        num_train_epochs=script_args.num_train_epochs,  # number of epochs
        per_device_train_batch_size=script_args.per_device_train_batch_size,  # training batch size
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,  # evaluation batch size
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,  # gradient accumulation steps
        warmup_ratio=0.1,  # warmup ratio
        learning_rate=script_args.learning_rate,  # learning rate
        lr_scheduler_type="cosine",  # use cosine learning rate scheduler
        optim="adamw_torch_fused",  # use fused adamw optimizer
        tf32=True,  # use tf32 precision
        bf16=True,  # use bf16 precision
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        eval_strategy="epoch",  # evaluate after each epoch
        save_strategy="epoch",  # save after each epoch
        logging_steps=10,  # log every 10 steps
        save_total_limit=3,  # save only the last 3 models
        load_best_model_at_end=True,  # load the best model when training ends
        metric_for_best_model="eval_dim_128_cosine_ndcg@10",  # optimize for the best ndcg@10 score at the 128 dimension
    )

    trainer = SentenceTransformerTrainer(
        model=model,  # e.g. bge-base-en-v1.5
        args=training_args,  # training arguments
        train_dataset=train_dataset.select_columns(
            ["positive", "anchor"]
        ),  # training dataset
        loss=train_loss,
        evaluator=evaluator,
    )

    ##########################
    # Train model
    ##########################
    # start training; the best checkpoint is reloaded at the end of training
    trainer.train()

    # save the best model to the output directory
    trainer.save_model()


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # set seed
    set_seed(42)

    # launch training
    training_function(script_args)
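

# ---------------------------------------------------------------------------
# Usage sketch: this script is written as a SageMaker training entry point,
# reading data from the /opt/ml/input/data/{train,test} channels and writing
# the model to /opt/ml/model. A minimal launch example with the sagemaker
# HuggingFace estimator follows; the instance type, container versions, role,
# and S3 URIs below are assumptions to adapt, not values from this repo.
#
#   from sagemaker.huggingface import HuggingFace
#
#   huggingface_estimator = HuggingFace(
#       entry_point="run_mnr.py",      # this script
#       source_dir="scripts",          # directory containing it
#       instance_type="ml.g5.xlarge",  # assumed: Ampere GPU for sdpa/bf16/tf32
#       instance_count=1,
#       role=role,                     # your SageMaker execution role
#       transformers_version="4.36",   # assumed container versions
#       pytorch_version="2.1",
#       py_version="py310",
#       hyperparameters={
#           "model_id": "BAAI/bge-base-en-v1.5",
#           "num_train_epochs": 3,
#       },
#   )
#   # "train" and "test" map to /opt/ml/input/data/train and .../test
#   huggingface_estimator.fit(
#       {"train": "s3://<bucket>/train", "test": "s3://<bucket>/test"}
#   )
#
# Each channel is expected to hold a dataset.json whose records carry the
# "id", "anchor", and "positive" fields used above, e.g. (contents invented):
#   {"id": 0, "anchor": "What was Q4 revenue?", "positive": "Q4 revenue was ..."}
# ---------------------------------------------------------------------------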