Real-time collaboration for Jupyter Notebooks, Linux terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/utils/inference/image_processing.py
Views: 792
import base64
from io import BytesIO
from typing import Callable, List, Union

import numpy as np
import torch
import cv2
from .masks import face_mask_static
from matplotlib import pyplot as plt
from insightface.utils import face_align


def crop_face(image_full: np.ndarray, app: Callable, crop_size: int) -> List[np.ndarray]:
    """
    Crop the first detected face from an image and align it to a square crop.

    Args:
        image_full: Source image to crop from.
        app: Detector object exposing ``get(image, crop_size)``; the first
            element of its result is used as the keypoints of the face to align.
        crop_size: Side length, in pixels, of the square output crop.

    Returns:
        A one-element list holding the aligned ``crop_size`` x ``crop_size`` crop.
    """
    kps = app.get(image_full, crop_size)
    # NOTE(review): mode='None' is forwarded verbatim to insightface's estimate_norm;
    # presumably it selects a non-arcface template in the pinned insightface version — confirm.
    M, _ = face_align.estimate_norm(kps[0], crop_size, mode='None')
    align_img = cv2.warpAffine(image_full, M, (crop_size, crop_size), borderValue=0.0)
    return [align_img]


def normalize_and_torch(image: np.ndarray,
                        device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """
    Normalize a single image to [-1, 1] and convert it to a batched torch tensor.

    Args:
        image: Image array; its last axis is moved to the front
            (assumed HWC layout — TODO confirm for all callers).
        device: Device the tensor is created on (defaults to the current CUDA device).

    Returns:
        float32 tensor of shape (1, image.shape[2], image.shape[0], image.shape[1]).
    """
    tensor = torch.tensor(image.copy(), dtype=torch.float32, device=device)
    # Heuristic: any value above 1 means the image is in [0, 255] rather than [0, 1].
    if tensor.max() > 1.:
        tensor = tensor / 255.

    # Move the channel axis first and add a batch dimension of 1.
    tensor = tensor.permute(2, 0, 1).unsqueeze(0)
    # Map [0, 1] to [-1, 1].
    return (tensor - 0.5) / 0.5


def normalize_and_torch_batch(frames: np.ndarray,
                              device: Union[str, torch.device] = 'cuda') -> torch.Tensor:
    """
    Normalize a batch of images to [-1, 1] and convert it to a torch tensor.

    Args:
        frames: 4-D image batch whose last axis is moved to position 1
            (assumed NHWC layout — TODO confirm for all callers).
        device: Device the tensor is moved to (defaults to the current CUDA device).

    Returns:
        Tensor with axes permuted as (0, 3, 1, 2). The dtype follows torch type
        promotion from ``frames`` (float32 for uint8 input).
    """
    batch_frames = torch.from_numpy(frames.copy()).to(device)
    # Same [0, 255] detection heuristic as normalize_and_torch.
    if batch_frames.max() > 1.:
        batch_frames = batch_frames / 255.

    batch_frames = batch_frames.permute(0, 3, 1, 2)
    return (batch_frames - 0.5) / 0.5


def get_final_image(final_frames: List[np.ndarray],
                    crop_frames: List[np.ndarray],
                    full_frame: np.ndarray,
                    tfm_arrays: List[np.ndarray],
                    handler) -> np.ndarray:
    """
    Paste swapped face crops back into the full frame using soft masks.

    For each face, the swapped crop is resized to 224x224. Landmarks are
    computed on it and on the matching target crop, and a mask is built from
    them. Both crop and mask are warped back into full-frame coordinates with
    the inverse of the crop's affine transform, then alpha-blended over the
    running result.

    Args:
        final_frames: Per-face swapped outputs; element ``[0]`` of each entry is the image.
        crop_frames: Per-face target crops; element ``[0]`` of each entry is the image.
        full_frame: Original frame the faces are pasted into (left unmodified).
        tfm_arrays: Per-face transforms; element ``[0]`` of each entry is the 2x3
            affine matrix that produced the crop.
        handler: Landmark model exposing ``get_without_detection_without_transform``.

    Returns:
        The blended frame as a uint8 array (values truncated, not rounded).
    """
    final = full_frame.copy()
    # cv2 expects the output size as (width, height).
    out_size = (full_frame.shape[1], full_frame.shape[0])

    # Index the parallel lists explicitly so a shorter crop_frames/tfm_arrays
    # still raises instead of being silently truncated.
    for i, swapped in enumerate(final_frames):
        # NOTE(review): 224 presumably equals the crop size the transforms were
        # estimated for — confirm before changing crop sizes elsewhere.
        frame = cv2.resize(swapped[0], (224, 224))

        landmarks = handler.get_without_detection_without_transform(frame)
        landmarks_tgt = handler.get_without_detection_without_transform(crop_frames[i][0])

        mask, _ = face_mask_static(crop_frames[i][0], landmarks, landmarks_tgt, None)
        mat_rev = cv2.invertAffineTransform(tfm_arrays[i][0])

        # Replicate borders so the warped face has no black rim at the crop edge.
        swap_t = cv2.warpAffine(frame, mat_rev, out_size, borderMode=cv2.BORDER_REPLICATE)
        mask_t = cv2.warpAffine(mask, mat_rev, out_size)
        # Add a channel axis so the single-channel mask broadcasts over colour channels.
        mask_t = np.expand_dims(mask_t, 2)

        # Alpha blend; mask values are presumably in [0, 1] — confirm in face_mask_static.
        final = mask_t * swap_t + (1 - mask_t) * final

    return np.array(final, dtype='uint8')


def show_images(images: List[np.ndarray],
                titles=None,
                figsize=(20, 5),
                fontsize=15):
    """
    Display images side by side in a single matplotlib row.

    Args:
        images: Images to show; the last axis is reversed before display
            (presumably BGR as produced by OpenCV, shown as RGB).
        titles: Optional per-image titles; must match ``images`` in length.
        figsize: Figure size passed to ``plt.subplots``.
        fontsize: Font size of the titles.
    """
    if titles:
        assert len(titles) == len(images), "Amount of images should be the same as the amount of titles"

    # squeeze=False always yields a 2-D array of Axes. Without it, a single
    # image yields a bare Axes object that cannot be zipped over.
    _, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for idx, (ax, image) in enumerate(zip(axes[0], images)):
        ax.imshow(image[:, :, ::-1])
        if titles:
            ax.set_title(titles[idx], fontsize=fontsize)
        ax.axis("off")