Path: blob/main/sagemaker/17_custom_inference_script/code/inference.py
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F


# Helper: Mean Pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
    # Load model and tokenizer from the model directory (the unpacked model archive on the endpoint)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModel.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    # Unpack model and tokenizer
    model, tokenizer = model_and_tokenizer

    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform mean pooling over the token embeddings
    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])

    # Normalize embeddings to unit length
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # Return a dictionary of plain lists so the response is JSON serializable
    return {"vectors": sentence_embeddings.tolist()}
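
# --- Local smoke test (illustrative addition, not part of the original script) ---
# A minimal sketch of how the SageMaker Hugging Face Inference Toolkit exercises
# model_fn and predict_fn. The model id below is an assumption used only so the
# script can be run locally; on a real endpoint, model_fn receives the directory
# of the unpacked model.tar.gz instead of a Hub id.
if __name__ == "__main__":
    model_and_tokenizer = model_fn("sentence-transformers/all-MiniLM-L6-v2")  # assumed example model
    payload = {"inputs": ["This is a sentence.", "This is another sentence."]}
    result = predict_fn(payload, model_and_tokenizer)
    # Expect one embedding per input sentence; print count and embedding dimension
    print(len(result["vectors"]), len(result["vectors"][0]))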