Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/utils/inference/core.py
Views: 792
from typing import List, Tuple, Callable, Any

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from .faceshifter_run import faceshifter_batch
from .image_processing import crop_face, normalize_and_torch, normalize_and_torch_batch
from .video_processing import read_video, crop_frames_and_get_transforms, resize_frames


def transform_target_to_torch(resized_frs: np.ndarray, half: bool = True) -> torch.Tensor:
    """
    Transform target crops so they can be used by the model.

    The array is copied, moved to the CUDA device, has the order of its last
    axis reversed (indices ``[2, 1, 0]``), is divided by 255, shifted/scaled
    with ``(x - 0.5) / 0.5`` and finally permuted from ``(0, 1, 2, 3)`` to
    ``(0, 3, 1, 2)``.

    Args:
        resized_frs: 4-D array whose last axis has at least three entries
            (presumably ``(N, H, W, 3)`` uint8 images -- confirm with caller).
        half: when true, the tensor is converted to half precision before
            the final normalization.

    Returns:
        The transformed CUDA tensor.
    """
    # .copy() gives torch a fresh buffer to wrap before the upload
    target_batch_rs = torch.from_numpy(resized_frs.copy()).cuda()
    # reverse the channel axis and scale by 1/255
    target_batch_rs = target_batch_rs[:, :, :, [2, 1, 0]] / 255.

    if half:
        target_batch_rs = target_batch_rs.half()

    target_batch_rs = (target_batch_rs - 0.5) / 0.5  # normalize
    target_batch_rs = target_batch_rs.permute(0, 3, 1, 2)

    return target_batch_rs


def model_inference(full_frames: List[np.ndarray],
                    source: List,
                    target: List,
                    netArc: Callable,
                    G: Callable,
                    app: Callable,
                    set_target: bool,
                    similarity_th: float = 0.15,
                    crop_size: int = 224,
                    BS: int = 60,
                    half: bool = True) -> Tuple[List[List], Any, List[np.ndarray], Any]:
    """
    Using original frames, get face-swapped frames and transformations.

    Args:
        full_frames: original frames; returned unchanged.
        source: source images, one per identity to insert. Each is passed
            through ``normalize_and_torch`` and embedded with ``netArc``.
        target: target images, batched through ``normalize_and_torch_batch``
            and embedded with ``netArc``.
        netArc: embedding network, called on inputs downscaled by 0.5.
        G: generator handed to ``faceshifter_batch``.
        app: forwarded to ``crop_frames_and_get_transforms`` (presumably the
            face detector -- confirm against that helper).
        set_target: forwarded to ``crop_frames_and_get_transforms``.
        similarity_th: forwarded to ``crop_frames_and_get_transforms``.
        crop_size: forwarded to ``crop_frames_and_get_transforms``.
        BS: number of crops sent to ``faceshifter_batch`` per call.
        half: run the generator inputs in half precision.

    Returns:
        ``(final_frames_list, crop_frames_list, full_frames, tfm_array_list)``.
        ``final_frames_list`` holds one list per processed source; each has
        one entry per element of ``present``: the generator output for that
        frame, or ``[]`` where ``present`` marks no face.
    """
    # Get Arcface embeddings of target image
    target_norm = normalize_and_torch_batch(np.array(target))
    target_embeds = netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))

    # Get the cropped faces from original frames and transformations to get those crops
    crop_frames_list, tfm_array_list = crop_frames_and_get_transforms(
        full_frames, target_embeds, app, netArc, crop_size, set_target, similarity_th=similarity_th)

    # Normalize source images and transform to torch and get Arcface embeddings
    source_embeds = []
    for source_curr in source:
        source_curr = normalize_and_torch(source_curr)
        source_embeds.append(
            netArc(F.interpolate(source_curr, scale_factor=0.5, mode='bilinear', align_corners=True)))

    final_frames_list = []
    # tfm_array_list stays in the zip so iteration length is still bounded by
    # the shortest of the three sequences, even though it is unused here.
    for crop_frames, _tfm_array, source_embed in zip(crop_frames_list, tfm_array_list, source_embeds):
        # Resize cropped frames and get vector which shows on which frames there were faces
        resized_frs, present = resize_frames(crop_frames)
        resized_frs = np.array(resized_frs)

        if resized_frs.size == 0:
            # No crop for this source: the 4-D indexing in
            # transform_target_to_torch and np.concatenate below would both
            # raise, so emit the "no face" placeholder for every frame.
            final_frames_list.append([[] for _ in present])
            continue

        # transform embeds of Xs and target frames to use by model
        target_batch_rs = transform_target_to_torch(resized_frs, half=half)

        if half:
            source_embed = source_embed.half()

        # run model
        size = target_batch_rs.shape[0]
        model_output = []

        for i in tqdm(range(0, size, BS)):
            Y_st = faceshifter_batch(source_embed, target_batch_rs[i:i + BS], G)
            model_output.append(Y_st)
            torch.cuda.empty_cache()
        model_output = np.concatenate(model_output)

        # create list of final frames with transformed faces
        final_frames = []
        idx_fs = 0  # position of the next unused generator output

        for pres in tqdm(present):
            if pres == 1:
                final_frames.append(model_output[idx_fs])
                idx_fs += 1
            else:
                final_frames.append([])
        final_frames_list.append(final_frames)

    return final_frames_list, crop_frames_list, full_frames, tfm_array_list