CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/inference/image_processing.py
Views: 792
1
import base64
from io import BytesIO
from typing import Callable, List, Union

import numpy as np
import torch
import cv2
from matplotlib import pyplot as plt
from insightface.utils import face_align

from .masks import face_mask_static
11
12
13
def crop_face(image_full: np.ndarray, app: Callable, crop_size: int) -> List[np.ndarray]:
    """
    Crop a face from an image and warp it to a square of ``crop_size`` pixels.

    Args:
        image_full: the full input image.
        app: detector object; ``app.get(image_full, crop_size)`` is indexed with
            ``[0]`` (presumably the keypoints of the first detected face —
            confirm against the detector implementation).
        crop_size: side length in pixels of the square output crop.

    Returns:
        A one-element list holding the aligned ``crop_size x crop_size`` crop.
    """
    kps = app.get(image_full, crop_size)
    # NOTE(review): when no face is found, ``kps[0]`` raises (TypeError for
    # None, IndexError for an empty sequence). Callers may rely on that, so
    # no extra check is added here.
    M, _ = face_align.estimate_norm(kps[0], crop_size, mode='None')
    # Pixels that map from outside the source image are filled with 0.
    align_img = cv2.warpAffine(image_full, M, (crop_size, crop_size), borderValue=0.0)
    return [align_img]
21
22
23
def normalize_and_torch(image: np.ndarray,
                        device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """
    Convert one (height, width, channels) image into a normalized 4-D tensor.

    Values are first brought to [0, 1]: they are divided by 255 when the input
    has an integer dtype or when any value exceeds 1. They are then mapped
    linearly to [-1, 1].

    Args:
        image: image array with three axes (height, width, channels).
        device: device that receives the tensor (default 'cuda').

    Returns:
        float32 tensor of shape (1, channels, height, width) on ``device``.
    """
    # Integer images (e.g. uint8) are always on a 0..255 scale. They must be
    # rescaled even if every pixel is <= 1, as in a near-black frame. The
    # max() test only matters for float images given on a 0..255 scale.
    is_integer = np.issubdtype(image.dtype, np.integer)
    # copy() hands torch a fresh contiguous array, e.g. after a [:, :, ::-1] flip.
    tensor = torch.tensor(image.copy(), dtype=torch.float32, device=device)
    if is_integer or tensor.max() > 1.:
        tensor = tensor / 255.

    # HWC -> CHW, then add a leading batch axis.
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    return (tensor - 0.5) / 0.5
35
36
37
def normalize_and_torch_batch(frames: np.ndarray,
                              device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """
    Convert a batch of (height, width, channels) images into a normalized tensor.

    Values are first brought to [0, 1]: they are divided by 255 when the input
    has an integer dtype or when any value exceeds 1. They are then mapped
    linearly to [-1, 1]. The result is always float32, matching
    ``normalize_and_torch``.

    Args:
        frames: array with four axes (batch, height, width, channels).
        device: device that receives the tensor (default 'cuda').

    Returns:
        float32 tensor of shape (batch, channels, height, width) on ``device``.
    """
    # Integer frames (e.g. uint8) are always on a 0..255 scale. They must be
    # rescaled even if every pixel is <= 1, as in a near-black frame.
    is_integer = np.issubdtype(frames.dtype, np.integer)
    # copy() gives torch.from_numpy a contiguous array, e.g. after a channel
    # flip with a negative stride.
    batch_frames = torch.from_numpy(frames.copy()).to(device=device, dtype=torch.float32)
    if is_integer or batch_frames.max() > 1.:
        batch_frames = batch_frames / 255.

    # BHWC -> BCHW
    batch_frames = batch_frames.permute(0, 3, 1, 2)
    return (batch_frames - 0.5) / 0.5
49
50
51
def get_final_image(final_frames: List[np.ndarray],
                    crop_frames: List[np.ndarray],
                    full_frame: np.ndarray,
                    tfm_arrays: List[np.ndarray],
                    handler) -> np.ndarray:
    """
    Blend swapped face crops back into one full frame.

    For each entry, the swapped crop is resized to 224x224, and a blending mask
    is built by ``face_mask_static`` from landmarks of the swapped crop and of
    the original crop. The crop and the mask are then warped back into
    ``full_frame`` coordinates with the inverse of the entry's affine transform
    and alpha-blended over the running result.

    Args:
        final_frames: per-face swap results; element ``[0]`` of each entry is
            the crop that gets pasted.
        crop_frames: per-face original crops; element ``[0]`` of each is used.
        full_frame: frame to paste into; it is copied, never modified.
        tfm_arrays: per-face 2x3 affine matrices (element ``[0]`` of each),
            inverted here to map crops back into ``full_frame``.
        handler: landmark model exposing
            ``get_without_detection_without_transform``.

    Returns:
        The blended frame as a ``uint8`` array (values truncated, not rounded).
    """
    final = full_frame.copy()
    height, width = full_frame.shape[:2]
    dsize = (width, height)  # OpenCV expects (width, height)

    # Index the parallel lists so a length mismatch still raises IndexError.
    for i, swapped in enumerate(final_frames):
        frame = cv2.resize(swapped[0], (224, 224))
        crop = crop_frames[i][0]

        landmarks = handler.get_without_detection_without_transform(frame)
        landmarks_tgt = handler.get_without_detection_without_transform(crop)

        # No custom mask parameters are used, so None is passed.
        mask, _ = face_mask_static(crop, landmarks, landmarks_tgt, None)
        mat_rev = cv2.invertAffineTransform(tfm_arrays[i][0])

        swap_t = cv2.warpAffine(frame, mat_rev, dsize, borderMode=cv2.BORDER_REPLICATE)
        mask_t = cv2.warpAffine(mask, mat_rev, dsize)
        # Add a trailing axis so the mask broadcasts over the colour channels.
        mask_t = np.expand_dims(mask_t, 2)

        final = mask_t * swap_t + (1 - mask_t) * final

    return np.array(final, dtype='uint8')
78
79
80
def show_images(images: List[np.ndarray],
                titles=None,
                figsize=(20, 5),
                fontsize=15):
    """
    Display images side by side in a single matplotlib row.

    Each image is reversed along its last axis before display
    (``image[:, :, ::-1]``), so 3-channel BGR input (OpenCV order) is shown as RGB.

    Args:
        images: images to display; each must support ``[:, :, ::-1]``.
        titles: optional per-image titles; when given, the list must be the
            same length as ``images``.
        figsize: figure size passed to ``plt.subplots``.
        fontsize: font size for the titles.

    Raises:
        ValueError: if ``titles`` is given and its length differs from ``images``.
    """
    if titles and len(titles) != len(images):
        raise ValueError("Amount of images should be the same as the amount of titles")

    # squeeze=False always returns a 2-D array of Axes. With the default, a
    # single image yields a bare Axes object that cannot be zipped.
    _, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for idx, (ax, image) in enumerate(zip(axes[0], images)):
        ax.imshow(image[:, :, ::-1])
        if titles:
            ax.set_title(titles[idx], fontsize=fontsize)
        ax.axis("off")
93
94