CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/inference/core.py
Views: 792
1
from typing import List, Tuple, Callable, Any
2
3
import torch
4
import torch.nn.functional as F
5
import numpy as np
6
from tqdm import tqdm
7
8
from .faceshifter_run import faceshifter_batch
9
from .image_processing import crop_face, normalize_and_torch, normalize_and_torch_batch
10
from .video_processing import read_video, crop_frames_and_get_transforms, resize_frames
11
12
13
def transform_target_to_torch(resized_frs: np.ndarray, half=True, device="cuda") -> torch.Tensor:
    """
    Convert a batch of target frames into a normalized tensor the generator can consume.

    Args:
        resized_frs: Frame batch indexed as (N, H, W, C), with at least 3 channels on the
            last axis. Values are divided by 255, so 8-bit pixel values are assumed
            (channel order presumably BGR as produced by OpenCV -- TODO confirm).
        half: If True the result is float16, otherwise it keeps the dtype produced by
            the division (float32 for integer input).
        device: Device the result is placed on. Defaults to "cuda", matching the former
            hard-coded ``.cuda()`` call.

    Returns:
        Tensor of shape (N, 3, H, W) holding input channels [2, 1, 0], scaled to [-1, 1].
    """
    # torch.from_numpy needs a contiguous, positively-strided buffer; ascontiguousarray
    # copies only when that is not already the case (the old unconditional .copy()
    # duplicated the whole frame batch).
    batch = torch.from_numpy(np.ascontiguousarray(resized_frs)).to(device)
    # Transfer while still in the input dtype (smallest payload), then reorder channels
    # and scale to [0, 1]. Advanced indexing creates a new tensor, so the caller's
    # array is never aliased by the result.
    batch = batch[:, :, :, [2, 1, 0]] / 255.
    # Normalize to [-1, 1] before the optional fp16 cast so there is a single rounding
    # step into half precision.
    batch = (batch - 0.5) / 0.5
    if half:
        batch = batch.half()
    # NHWC -> NCHW
    return batch.permute(0, 3, 1, 2)
27
28
29
def _arcface_embed(netArc: Callable, faces: torch.Tensor) -> Any:
    """Downsample normalized face crops by 2x (bilinear) and embed them with ``netArc``."""
    return netArc(F.interpolate(faces, scale_factor=0.5, mode='bilinear', align_corners=True))


def _run_generator(source_embed: Any, target_batch: torch.Tensor, G: Callable, BS: int) -> np.ndarray:
    """Run ``faceshifter_batch`` over ``target_batch`` in chunks of ``BS`` and concatenate the outputs."""
    outputs = []
    for start in tqdm(range(0, target_batch.shape[0], BS)):
        outputs.append(faceshifter_batch(source_embed, target_batch[start:start + BS], G))
        # Free cached allocator blocks between chunks to keep peak GPU memory down.
        torch.cuda.empty_cache()
    return np.concatenate(outputs)


def _scatter_to_frames(model_output: Any, present: Any) -> List:
    """Put generator outputs back at the frames flagged in ``present``; other frames get ``[]``."""
    outputs = iter(model_output)
    return [next(outputs) if pres == 1 else [] for pres in tqdm(present)]


def model_inference(full_frames: List[np.ndarray],
                    source: List,
                    target: List,
                    netArc: Callable,
                    G: Callable,
                    app: Callable,
                    set_target: bool,
                    similarity_th=0.15,
                    crop_size=224,
                    BS=60,
                    half=True) -> Tuple[List, List, List[np.ndarray], List]:
    """
    Swap faces in the original frames and return the swapped crops and crop transformations.

    Args:
        full_frames: Original frames. They are passed to ``crop_frames_and_get_transforms``
            and returned unchanged.
        source: Source face images. Each one is normalized with ``normalize_and_torch`` and
            embedded with ``netArc``.
        target: Target face images. They are stacked with ``np.array`` and embedded, and the
            embeddings are used by ``crop_frames_and_get_transforms`` to pick which faces to crop.
        netArc: Embedding network, presumably ArcFace. It is applied to inputs downsampled 2x.
        G: Generator, forwarded to ``faceshifter_batch``.
        app: Face detector, forwarded to ``crop_frames_and_get_transforms``.
        set_target, similarity_th, crop_size: Forwarded to ``crop_frames_and_get_transforms``.
        BS: Number of crops per ``faceshifter_batch`` call.
        half: Run the target batch and source embedding in float16.

    Returns:
        A tuple ``(final_frames_list, crop_frames_list, full_frames, tfm_array_list)``.
        ``final_frames_list[i][j]`` is the generator output for frame ``j`` of the i-th
        (crop list, source) pair, or ``[]`` when no face was present in that frame.
        If a face was never found in any frame, every entry is ``[]`` for that pair.
    """
    # Pure inference: without no_grad every netArc / G call records an autograd graph
    # and keeps its activations alive.
    with torch.no_grad():
        target_embeds = _arcface_embed(netArc, normalize_and_torch_batch(np.array(target)))

        # Cropped faces from the original frames plus the transformations that produced them
        crop_frames_list, tfm_array_list = crop_frames_and_get_transforms(
            full_frames, target_embeds, app, netArc, crop_size, set_target,
            similarity_th=similarity_th)

        source_embeds = [_arcface_embed(netArc, normalize_and_torch(src)) for src in source]

        final_frames_list = []
        # NOTE(review): zip stops at the shortest input, so crop lists without a matching
        # source embedding are silently dropped -- confirm this pairing is intended.
        for crop_frames, _tfm_array, source_embed in zip(crop_frames_list, tfm_array_list, source_embeds):
            # Resize the crops and get a per-frame flag telling whether a face was present
            resized_frs, present = resize_frames(crop_frames)

            if len(resized_frs) == 0:
                # No face found in any frame: there is nothing to run the generator on.
                # Previously this crashed while building an empty target batch.
                final_frames_list.append(_scatter_to_frames([], present))
                continue

            target_batch_rs = transform_target_to_torch(np.array(resized_frs), half=half)
            if half:
                source_embed = source_embed.half()

            model_output = _run_generator(source_embed, target_batch_rs, G, BS)
            final_frames_list.append(_scatter_to_frames(model_output, present))

    return final_frames_list, crop_frames_list, full_frames, tfm_array_list
91