CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Sign Up · Sign In
huggingface

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: huggingface/notebooks
Path: blob/main/sagemaker/17_custom_inference_script/code/inference.py
Views: 2555
1
2
from transformers import AutoTokenizer, AutoModel
3
import torch
4
import torch.nn.functional as F
5
6
# Helper: Mean Pooling - Take attention mask into account for correct averaging
7
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence, ignoring masked-out positions.

    Args:
        model_output: Model forward output whose first element holds the
            token embeddings (assumed shape ``(batch, seq_len, hidden)``).
        attention_mask: Per-token weights, typically 1 for real tokens and
            0 for padding (assumed shape ``(batch, seq_len)``).

    Returns:
        The mask-weighted mean of the token embeddings over dimension 1.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the hidden dimension and cast it to float so
    # it can weight the embeddings element-wise.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_total = (hidden_states * weights).sum(dim=1)
    # Floor the token count so a fully masked row divides by 1e-9, not zero.
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_total / token_counts
11
12
13
def model_fn(model_dir):
    """Load the model and its tokenizer from ``model_dir``.

    Args:
        model_dir: Location passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # The tokenizer is loaded first, then the model.
    tok = AutoTokenizer.from_pretrained(model_dir)
    net = AutoModel.from_pretrained(model_dir)
    pair = (net, tok)
    return pair
18
19
def predict_fn(data, model_and_tokenizer):
    """Embed the input sentences and return L2-normalized mean-pooled vectors.

    Args:
        data: The deserialized request payload. Either a dict carrying the
            sentences under the ``"inputs"`` key (the key is popped from the
            dict), or the sentences themselves (e.g. a string or a list of
            strings), which are passed to the tokenizer as-is.
        model_and_tokenizer: The ``(model, tokenizer)`` pair returned by
            ``model_fn``.

    Returns:
        A dict ``{"vectors": embeddings}`` where ``embeddings`` is a
        ``torch.Tensor`` of mean-pooled embeddings, L2-normalized along dim 1.
    """
    # Unpack the pair produced by model_fn.
    model, tokenizer = model_and_tokenizer

    # Accept both {"inputs": ...} payloads and bare payloads; calling .pop()
    # on a list or str payload would raise instead of tokenizing it.
    if isinstance(data, dict):
        # Falls back to passing the whole dict when "inputs" is absent.
        sentences = data.pop("inputs", data)
    else:
        sentences = data

    # Tokenize, padding to the longest sentence and truncating to the model max.
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings; no gradients are needed for inference.
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Average token embeddings, ignoring padding via the attention mask.
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize each embedding to unit L2 length.
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # NOTE(review): "vectors" holds a torch.Tensor, which the standard-library
    # json module cannot serialize; this relies on the serving stack's output
    # encoder handling tensors — confirm, or convert with .tolist().
    return {"vectors": sentence_embeddings}
39
40