# Path: utils/inference/video_processing.py
import cv2
import numpy as np
import os
from PIL import Image
from typing import List, Tuple, Callable, Any
from tqdm import tqdm
import traceback
from insightface.utils import face_align
from scipy.spatial import distance

from .masks import face_mask_static
from .image_processing import normalize_and_torch, normalize_and_torch_batch, crop_face

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import kornia


def add_audio_from_another_video(video_with_sound: str,
                                 video_without_sound: str,
                                 audio_name: str,
                                 fast_cpu=True,
                                 gpu=False) -> None:
    """
    Extract the audio track from `video_with_sound` and mux it into
    `video_without_sound` (in place) using ffmpeg.
    """
    if not os.path.exists('./examples/audio/'):
        os.makedirs('./examples/audio/')
    fast_cmd = "-c:v libx264 -preset ultrafast -crf 18" if fast_cpu else ""
    gpu_cmd = "-c:v h264_nvenc" if gpu else ""
    # extract the audio stream only (-vn drops video)
    os.system(f"ffmpeg -v -8 -i {video_with_sound} -vn -c:a aac ./examples/audio/{audio_name}.m4a")
    # mux the extracted audio into the silent video
    os.system(f"ffmpeg -v -8 -i {video_without_sound} -i ./examples/audio/{audio_name}.m4a {fast_cmd} {gpu_cmd} {video_without_sound[:-4]}_audio.mp4 -y")
    os.system(f"rm -rf ./examples/audio/{audio_name}.m4a")
    os.system(f"mv {video_without_sound[:-4]}_audio.mp4 {video_without_sound}")


def read_video(path_to_video: str) -> Tuple[List[np.ndarray], float]:
    """
    Read a video frame by frame; return the frames and the FPS.
    """

    # load video
    cap = cv2.VideoCapture(path_to_video)

    width_original, height_original = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps, frames = cap.get(cv2.CAP_PROP_FPS), cap.get(cv2.CAP_PROP_FRAME_COUNT)

    full_frames = []
    i = 0  # current frame index

    while cap.isOpened():
        if i == frames:
            break

        ret, frame = cap.read()

        i += 1
        if ret:
            full_frames.append(frame)
        else:
            break

    cap.release()

    return full_frames, fps


def get_target(full_frames: List[np.ndarray],
               app: Callable,
               crop_size: int):
    """
    Return the first face crop found in the video, or None if no frame contains a face.
    """
    i = 0
    target = None
    while target is None:
        if i < len(full_frames):
            try:
                target = [crop_face(full_frames[i], app, crop_size)[0]]
            except TypeError:
                i += 1
        else:
            print("Video doesn't contain face!")
            break
    return target


def smooth_landmarks(kps_arr, n=2):
    """
    Temporally smooth per-frame keypoints with a centred moving average of
    half-width `n`. Frames without a detection (empty keypoints) or with an
    abrupt jump (> 5 px between consecutive frames) start a new segment, so
    the average never crosses detection gaps or identity switches.
    """
    kps_arr_smooth_final = []
    for ka in kps_arr:
        # split the sequence into segments of stable, contiguous detections
        kps_arr_s = [[ka[0]]]
        for i in range(1, len(ka)):
            if (len(ka[i]) == 0) or (len(ka[i - 1]) == 0):
                kps_arr_s.append([ka[i]])
            elif (distance.euclidean(ka[i][0], ka[i - 1][0]) > 5) or (distance.euclidean(ka[i][2], ka[i - 1][2]) > 5):
                kps_arr_s.append([ka[i]])
            else:
                kps_arr_s[-1].append(ka[i])

        # average each segment with a window that shrinks near its ends
        kps_arr_smooth = []
        for a in kps_arr_s:
            a_smooth = []
            for i in range(len(a)):
                q = min(i, len(a) - i - 1, n)
                a_smooth.append(np.mean(np.array(a[i - q:i + 1 + q]), axis=0))
            kps_arr_smooth += a_smooth
        kps_arr_smooth_final.append(kps_arr_smooth)
    return kps_arr_smooth_final
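
# A minimal sketch of the smoothing behaviour (hypothetical values, not part
# of the original pipeline): with n=2 and a stable track, each frame's
# keypoints become the mean over a centred window of up to 2*n + 1 = 5 frames.
#
#     kps = [np.full((5, 2), float(i)) for i in range(5)]  # 5 fake frames, 5 points each
#     smooth_landmarks([kps], n=2)[0][2]
#     # -> (5, 2) array of 2.0: the centre frame averages frames 0..4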


def crop_frames_and_get_transforms(full_frames: List[np.ndarray],
                                   target_embeds: List,
                                   app: Callable,
                                   netArc: Callable,
                                   crop_size: int,
                                   set_target: bool,
                                   similarity_th: float) -> Tuple[List[Any], List[Any]]:
    """
    Crop faces from frames and get the respective affine transforms.
    """

    crop_frames = [[] for _ in range(target_embeds.shape[0])]
    tfm_array = [[] for _ in range(target_embeds.shape[0])]
    kps_array = [[] for _ in range(target_embeds.shape[0])]

    target_embeds = F.normalize(target_embeds)
    for frame in tqdm(full_frames):
        try:
            kps = app.get(frame, crop_size)
            if len(kps) > 1 or set_target:
                # several faces (or explicit targets): match each detection
                # to the closest target identity via ArcFace similarity
                faces = []
                for p in kps:
                    M, _ = face_align.estimate_norm(p, crop_size, mode='None')
                    align_img = cv2.warpAffine(frame, M, (crop_size, crop_size), borderValue=0.0)
                    faces.append(align_img)

                face_norm = normalize_and_torch_batch(np.array(faces))
                face_norm = F.interpolate(face_norm, scale_factor=0.5, mode='bilinear', align_corners=True)
                face_embeds = netArc(face_norm)
                face_embeds = F.normalize(face_embeds)

                similarity = face_embeds @ target_embeds.T
                best_idxs = similarity.argmax(0).detach().cpu().numpy()
                for idx, best_idx in enumerate(best_idxs):
                    if similarity[best_idx][idx] > similarity_th:
                        kps_array[idx].append(kps[best_idx])
                    else:
                        kps_array[idx].append([])

            else:
                kps_array[0].append(kps[0])

        except TypeError:
            # no detections on this frame: keep every array aligned with the frames
            for q in range(len(target_embeds)):
                kps_array[q].append([])

    smooth_kps = smooth_landmarks(kps_array, n=2)

    for i, frame in tqdm(enumerate(full_frames)):
        for q in range(len(target_embeds)):
            try:
                M, _ = face_align.estimate_norm(smooth_kps[q][i], crop_size, mode='None')
                align_img = cv2.warpAffine(frame, M, (crop_size, crop_size), borderValue=0.0)
                crop_frames[q].append(align_img)
                tfm_array[q].append(M)
            except Exception:
                crop_frames[q].append([])
                tfm_array[q].append([])

    torch.cuda.empty_cache()
    return crop_frames, tfm_array


def resize_frames(crop_frames: List[np.ndarray], new_size=(256, 256)) -> Tuple[List[np.ndarray], np.ndarray]:
    """
    Resize frames to `new_size`; `present[i]` is set to 0 where frame i has no face crop.
    """

    resized_frs = []
    present = np.ones(len(crop_frames))

    for i, crop_fr in tqdm(enumerate(crop_frames)):
        try:
            resized_frs.append(cv2.resize(crop_fr, new_size))
        except Exception:
            present[i] = 0

    return resized_frs, present


def get_final_video(final_frames: List[np.ndarray],
                    crop_frames: List[np.ndarray],
                    full_frames: List[np.ndarray],
                    tfm_array: List[np.ndarray],
                    OUT_VIDEO_NAME: str,
                    fps: float,
                    handler) -> None:
    """
    Create the final video: paste every swapped face back into its full frame.
    """

    out = cv2.VideoWriter(f"{OUT_VIDEO_NAME}", cv2.VideoWriter_fourcc(*'mp4v'), fps,
                          (full_frames[0].shape[1], full_frames[0].shape[0]))
    size = (full_frames[0].shape[0], full_frames[0].shape[1])
    params = [None for _ in range(len(crop_frames))]
    result_frames = full_frames.copy()

    for i in tqdm(range(len(full_frames))):
        for j in range(len(crop_frames)):
            try:
                swap = cv2.resize(final_frames[j][i], (224, 224))

                if len(crop_frames[j][i]) == 0:
                    # no face for this target on this frame; reset blending params
                    params[j] = None
                    continue

                landmarks = handler.get_without_detection_without_transform(swap)
                if params[j] is None:
                    landmarks_tgt = handler.get_without_detection_without_transform(crop_frames[j][i])
                    mask, params[j] = face_mask_static(swap, landmarks, landmarks_tgt, params[j])
                else:
                    mask = face_mask_static(swap, landmarks, landmarks_tgt, params[j])

                swap = torch.from_numpy(swap).cuda().permute(2, 0, 1).unsqueeze(0).type(torch.float32)
                mask = torch.from_numpy(mask).cuda().unsqueeze(0).unsqueeze(0).type(torch.float32)
                full_frame = torch.from_numpy(result_frames[i]).cuda().permute(2, 0, 1).unsqueeze(0)
                mat = torch.from_numpy(tfm_array[j][i]).cuda().unsqueeze(0).type(torch.float32)

                # warp the swap and its mask back into full-frame coordinates,
                # then alpha-blend: final = mask * swap + (1 - mask) * frame
                mat_rev = kornia.invert_affine_transform(mat)
                swap_t = kornia.warp_affine(swap, mat_rev, size)
                mask_t = kornia.warp_affine(mask, mat_rev, size)
                final = (mask_t * swap_t + (1 - mask_t) * full_frame).type(torch.uint8).squeeze().permute(1, 2, 0).cpu().detach().numpy()

                result_frames[i] = final
                torch.cuda.empty_cache()

            except Exception:
                # frame has no usable crop for this face; leave it unchanged
                pass

        out.write(result_frames[i])

    out.release()
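
# A minimal sketch of the paste-back step in isolation (hypothetical shapes
# and transform, not part of the original pipeline): a 2x3 affine M maps
# full-frame pixels into the crop, so the inverted warp puts the crop back
# where it came from, and a warped all-ones mask drives the alpha blend.
#
#     frame  = torch.zeros(1, 3, 720, 1280)               # full frame, (B, C, H, W)
#     crop   = torch.ones(1, 3, 224, 224)                 # swapped face crop
#     M = torch.tensor([[[0.5, 0.0, -100.0],
#                        [0.0, 0.5,  -50.0]]])            # frame -> crop affine
#     M_rev  = kornia.invert_affine_transform(M)          # crop -> frame
#     back   = kornia.warp_affine(crop, M_rev, (720, 1280))
#     mask_t = kornia.warp_affine(torch.ones(1, 1, 224, 224), M_rev, (720, 1280))
#     final  = mask_t * back + (1 - mask_t) * frame       # alpha blend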
__init__(self, frames_list):247self.frames_list = frames_list248249self.transforms = transforms.Compose([250transforms.ToTensor()251])252253def __getitem__(self, idx):254frame = Image.fromarray(self.frames_list[idx][:,:,::-1])255256return self.transforms(frame)257258def __len__(self):259return len(self.frames_list)260261262def face_enhancement(final_frames: List[np.ndarray], model) -> List[np.ndarray]:263enhanced_frames_all = []264for i in range(len(final_frames)):265enhanced_frames = final_frames[i].copy()266face_idx = [i for i, x in enumerate(final_frames[i]) if not isinstance(x, list)]267face_frames = [x for i, x in enumerate(final_frames[i]) if not isinstance(x, list)]268ff_i = 0269270dataset = Frames(face_frames)271dataloader = DataLoader(dataset, batch_size=20, shuffle=False, num_workers=1, drop_last=False)272273for iteration, data in tqdm(enumerate(dataloader)):274frames = data275data = {'image': frames, 'label': frames}276generated = model(data, mode='inference2')277generated = torch.clamp(generated*255, 0, 255)278generated = (generated).type(torch.uint8).permute(0,2,3,1).cpu().detach().numpy()279for generated_frame in generated:280enhanced_frames[face_idx[ff_i]] = generated_frame[:,:,::-1]281ff_i+=1282enhanced_frames_all.append(enhanced_frames)283284return enhanced_frames_all285286287