# GitHub Repository: ai-forever/sber-swap
# Path: blob/main/utils/inference/video_processing.py
import os
from typing import List, Tuple, Callable, Any

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from insightface.utils import face_align
from scipy.spatial import distance

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import kornia

from .masks import face_mask_static
from .image_processing import normalize_and_torch, normalize_and_torch_batch, crop_face


def add_audio_from_another_video(video_with_sound: str,
                                 video_without_sound: str,
                                 audio_name: str,
                                 fast_cpu=True,
                                 gpu=False) -> None:

    if not os.path.exists('./examples/audio/'):
        os.makedirs('./examples/audio/')
    fast_cmd = "-c:v libx264 -preset ultrafast -crf 18" if fast_cpu else ""
    gpu_cmd = "-c:v h264_nvenc" if gpu else ""
    # extract the audio track (-vn drops video, so an audio codec is needed here)
    os.system(f"ffmpeg -v -8 -i {video_with_sound} -vn -c:a aac ./examples/audio/{audio_name}.m4a")
    # mux the extracted audio back into the silent video
    os.system(f"ffmpeg -v -8 -i {video_without_sound} -i ./examples/audio/{audio_name}.m4a {fast_cmd} {gpu_cmd} {video_without_sound[:-4]}_audio.mp4 -y")
    # clean up the temporary audio file and replace the original video
    os.system(f"rm -rf ./examples/audio/{audio_name}.m4a")
    os.system(f"mv {video_without_sound[:-4]}_audio.mp4 {video_without_sound}")


def read_video(path_to_video: str) -> Tuple[List[np.ndarray], float]:
    """
    Read a video frame by frame from its path
    """

    cap = cv2.VideoCapture(path_to_video)

    fps, frames = cap.get(cv2.CAP_PROP_FPS), cap.get(cv2.CAP_PROP_FRAME_COUNT)

    full_frames = []
    i = 0  # current frame index

    while cap.isOpened():
        if i == frames:
            break

        ret, frame = cap.read()
        i += 1
        if ret:
            full_frames.append(frame)
        else:
            break

    cap.release()

    return full_frames, fps
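
# Usage sketch (assumes an "input.mp4" exists; frames come back as BGR uint8
# arrays in OpenCV convention):
#
#   full_frames, fps = read_video('input.mp4')
#   print(f'{len(full_frames)} frames at {fps:.2f} fps')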


def get_target(full_frames: List[np.ndarray],
               app: Callable,
               crop_size: int):
    """
    Crop the target face from the first frame in which one is detected
    """
    i = 0
    target = None
    while target is None:
        if i < len(full_frames):
            try:
                target = [crop_face(full_frames[i], app, crop_size)[0]]
            except TypeError:
                # no face in this frame; try the next one
                i += 1
        else:
            print("Video doesn't contain face!")
            break
    return target
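
# Usage sketch (hypothetical: `app` is the insightface-style detector that
# crop_face expects, and 224 a typical crop size):
#
#   target = get_target(full_frames, app, crop_size=224)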


def smooth_landmarks(kps_arr, n=2):
    """
    Split each keypoint track into runs of nearby detections and smooth
    every run with a centered moving average of half-window n
    """
    kps_arr_smooth_final = []
    for ka in kps_arr:
        # group consecutive frames whose keypoints moved less than 5 px
        kps_arr_s = [[ka[0]]]
        for i in range(1, len(ka)):
            if (len(ka[i]) == 0) or (len(ka[i - 1]) == 0):
                kps_arr_s.append([ka[i]])
            elif (distance.euclidean(ka[i][0], ka[i - 1][0]) > 5) or (distance.euclidean(ka[i][2], ka[i - 1][2]) > 5):
                kps_arr_s.append([ka[i]])
            else:
                kps_arr_s[-1].append(ka[i])

        kps_arr_smooth = []

        for a in kps_arr_s:
            a_smooth = []
            for i in range(len(a)):
                # shrink the averaging window near the ends of the run
                q = min(i, len(a) - i - 1, n)
                a_smooth.append(np.mean(np.array(a[i - q:i + 1 + q]), axis=0))

            kps_arr_smooth += a_smooth
        kps_arr_smooth_final.append(kps_arr_smooth)
    return kps_arr_smooth_final
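
# Behaviour sketch on synthetic data: a slowly drifting track stays within
# the 5 px threshold, so it forms a single run and is averaged with a
# half-window of n:
#
#   kps = [[np.zeros((5, 2)) + t for t in range(10)]]  # one face, 10 frames
#   smoothed = smooth_landmarks(kps, n=2)              # same layout, smoothed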


def crop_frames_and_get_transforms(full_frames: List[np.ndarray],
                                   target_embeds: List,
                                   app: Callable,
                                   netArc: Callable,
                                   crop_size: int,
                                   set_target: bool,
                                   similarity_th: float) -> Tuple[List[Any], List[Any]]:
    """
    Crop faces from frames and get the respective transforms
    """

    crop_frames = [[] for _ in range(target_embeds.shape[0])]
    tfm_array = [[] for _ in range(target_embeds.shape[0])]
    kps_array = [[] for _ in range(target_embeds.shape[0])]

    target_embeds = F.normalize(target_embeds)
    for frame in tqdm(full_frames):
        try:
            kps = app.get(frame, crop_size)
            if len(kps) > 1 or set_target:
                faces = []
                for p in kps:
                    M, _ = face_align.estimate_norm(p, crop_size, mode='None')
                    align_img = cv2.warpAffine(frame, M, (crop_size, crop_size), borderValue=0.0)
                    faces.append(align_img)

                face_norm = normalize_and_torch_batch(np.array(faces))
                face_norm = F.interpolate(face_norm, scale_factor=0.5, mode='bilinear', align_corners=True)
                face_embeds = netArc(face_norm)
                face_embeds = F.normalize(face_embeds)

                # similarity has shape (num_detected_faces, num_targets);
                # for every target pick the detected face that matches best
                similarity = face_embeds @ target_embeds.T
                best_idxs = similarity.argmax(0).detach().cpu().numpy()
                for idx, best_idx in enumerate(best_idxs):
                    if similarity[best_idx][idx] > similarity_th:
                        kps_array[idx].append(kps[best_idx])
                    else:
                        kps_array[idx].append([])

            else:
                kps_array[0].append(kps[0])

        except TypeError:
            # no detections in this frame: record an empty entry per target
            for q in range(len(target_embeds)):
                kps_array[q].append([])

    smooth_kps = smooth_landmarks(kps_array, n=2)

    for i, frame in tqdm(enumerate(full_frames)):
        for q in range(len(target_embeds)):
            try:
                M, _ = face_align.estimate_norm(smooth_kps[q][i], crop_size, mode='None')
                align_img = cv2.warpAffine(frame, M, (crop_size, crop_size), borderValue=0.0)
                crop_frames[q].append(align_img)
                tfm_array[q].append(M)
            except Exception:
                crop_frames[q].append([])
                tfm_array[q].append([])

    torch.cuda.empty_cache()
    return crop_frames, tfm_array
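
# Usage sketch (hypothetical: target_embeds is the ArcFace embedding tensor
# computed by netArc from the crops that get_target returned; the threshold
# value is illustrative):
#
#   crop_frames, tfm_array = crop_frames_and_get_transforms(
#       full_frames, target_embeds, app, netArc,
#       crop_size=224, set_target=False, similarity_th=0.15)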


def resize_frames(crop_frames: List[np.ndarray], new_size=(256, 256)) -> Tuple[List[np.ndarray], np.ndarray]:
    """
    Resize frames to the new size, flagging frames without a face crop
    """

    resized_frs = []
    present = np.ones(len(crop_frames))

    for i, crop_fr in tqdm(enumerate(crop_frames)):
        try:
            resized_frs.append(cv2.resize(crop_fr, new_size))
        except Exception:
            # empty entries (frames without a face) cannot be resized
            present[i] = 0

    return resized_frs, present
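
# Usage sketch: `present` flags which entries actually held a face crop:
#
#   resized, present = resize_frames(crop_frames[0], new_size=(256, 256))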


def get_final_video(final_frames: List[np.ndarray],
                    crop_frames: List[np.ndarray],
                    full_frames: List[np.ndarray],
                    tfm_array: List[np.ndarray],
                    OUT_VIDEO_NAME: str,
                    fps: float,
                    handler) -> None:
    """
    Create the final video from frames
    """

    out = cv2.VideoWriter(f"{OUT_VIDEO_NAME}", cv2.VideoWriter_fourcc(*'mp4v'), fps, (full_frames[0].shape[1], full_frames[0].shape[0]))
    size = (full_frames[0].shape[0], full_frames[0].shape[1])
    params = [None for i in range(len(crop_frames))]
    result_frames = full_frames.copy()

    for i in tqdm(range(len(full_frames))):
        for j in range(len(crop_frames)):
            try:
                swap = cv2.resize(final_frames[j][i], (224, 224))

                if len(crop_frames[j][i]) == 0:
                    params[j] = None
                    continue

                landmarks = handler.get_without_detection_without_transform(swap)
                if params[j] is None:
                    landmarks_tgt = handler.get_without_detection_without_transform(crop_frames[j][i])
                    mask, params[j] = face_mask_static(swap, landmarks, landmarks_tgt, params[j])
                else:
                    # reuse the target landmarks and mask params from the start of this run
                    mask = face_mask_static(swap, landmarks, landmarks_tgt, params[j])

                # paste the swapped face back into the full frame on the GPU
                swap = torch.from_numpy(swap).cuda().permute(2, 0, 1).unsqueeze(0).type(torch.float32)
                mask = torch.from_numpy(mask).cuda().unsqueeze(0).unsqueeze(0).type(torch.float32)
                full_frame = torch.from_numpy(result_frames[i]).cuda().permute(2, 0, 1).unsqueeze(0)
                mat = torch.from_numpy(tfm_array[j][i]).cuda().unsqueeze(0).type(torch.float32)

                # invert the crop transform and warp swap and mask back to frame coordinates
                mat_rev = kornia.invert_affine_transform(mat)
                swap_t = kornia.warp_affine(swap, mat_rev, size)
                mask_t = kornia.warp_affine(mask, mat_rev, size)
                final = (mask_t * swap_t + (1 - mask_t) * full_frame).type(torch.uint8).squeeze().permute(1, 2, 0).cpu().detach().numpy()

                result_frames[i] = final
                torch.cuda.empty_cache()

            except Exception:
                # frames without a usable crop are written unchanged
                pass

        out.write(result_frames[i])

    out.release()
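
# Usage sketch (hypothetical: `handler` is the landmark model consumed by
# face_mask_static; audio can be restored afterwards):
#
#   get_final_video(final_frames, crop_frames, full_frames, tfm_array,
#                   'result.mp4', fps, handler)
#   add_audio_from_another_video('input.mp4', 'result.mp4', 'tmp_audio')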


class Frames(Dataset):
    def __init__(self, frames_list):
        self.frames_list = frames_list

        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, idx):
        # OpenCV frames are BGR; flip to RGB before converting to PIL
        frame = Image.fromarray(self.frames_list[idx][:, :, ::-1])

        return self.transforms(frame)

    def __len__(self):
        return len(self.frames_list)


def face_enhancement(final_frames: List[np.ndarray], model) -> List[np.ndarray]:
    """
    Run the enhancement model over every frame that contains a face crop
    """
    enhanced_frames_all = []
    for i in range(len(final_frames)):
        enhanced_frames = final_frames[i].copy()
        # empty list entries mark frames without a face; keep only real crops
        face_idx = [j for j, x in enumerate(final_frames[i]) if not isinstance(x, list)]
        face_frames = [x for x in final_frames[i] if not isinstance(x, list)]
        ff_i = 0

        dataset = Frames(face_frames)
        dataloader = DataLoader(dataset, batch_size=20, shuffle=False, num_workers=1, drop_last=False)

        for frames in tqdm(dataloader):
            data = {'image': frames, 'label': frames}
            generated = model(data, mode='inference2')
            generated = torch.clamp(generated * 255, 0, 255)
            generated = generated.type(torch.uint8).permute(0, 2, 3, 1).cpu().detach().numpy()
            for generated_frame in generated:
                # convert back from RGB to OpenCV's BGR ordering
                enhanced_frames[face_idx[ff_i]] = generated_frame[:, :, ::-1]
                ff_i += 1
        enhanced_frames_all.append(enhanced_frames)

    return enhanced_frames_all
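
# End-to-end sketch of how these helpers compose (hypothetical model handles:
# `app` detector, `netArc` embedder, `handler` landmark model, `enhancer`
# enhancement model; all are loaded elsewhere in this repo):
#
#   full_frames, fps = read_video('input.mp4')
#   target = get_target(full_frames, app, crop_size=224)
#   ...compute target_embeds from `target` with netArc...
#   crop_frames, tfm_array = crop_frames_and_get_transforms(
#       full_frames, target_embeds, app, netArc, 224, False, 0.15)
#   resized, present = resize_frames(crop_frames[0])
#   ...run the swapping model on `resized` to obtain final_frames...
#   final_frames = face_enhancement(final_frames, enhancer)
#   get_final_video(final_frames, crop_frames, full_frames, tfm_array,
#                   'result.mp4', fps, handler)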