# Path: blob/main/utils/inference/faceshifter_run.py
# 1108 views
import torch
import numpy as np


def faceshifter_batch(source_emb: torch.Tensor,
                      target: torch.Tensor,
                      G: torch.nn.Module) -> np.ndarray:
    """Apply the FaceShifter generator ``G`` to a batch of target images.

    Args:
        source_emb: Source-identity embedding passed to ``G``. When the batch
            has more than one image it is concatenated with itself ``bs`` times
            along dim 0, so it presumably has a leading batch dim of 1
            (e.g. ``1 x D``) -- TODO confirm against callers.
        target: Batch of target images, ``B x C x H x W``.
        G: Generator called as ``G(target, source_emb)``. It must return a
            2-tuple whose first element is the output image batch
            (``B x C x H x W``). Values are assumed to lie in ``[-1, 1]``.

    Returns:
        ``uint8`` numpy array of shape ``B x H x W x 3``. The channel order is
        reversed relative to ``G``'s output (presumably RGB -> BGR for OpenCV;
        verify).

    Raises:
        AssertionError: if ``target`` is not 4-dimensional.
    """
    # Validate the rank before reading ``target.shape[0]``. An explicit raise
    # (instead of ``assert``) keeps the check active under ``python -O`` and
    # preserves the AssertionError type callers may already handle.
    if target.ndim != 4:
        raise AssertionError("target should have 4 dimentions -- B x C x H x W")

    bs = target.shape[0]

    # Repeat the source embedding once per target image along dim 0.
    if bs > 1:
        source_emb = torch.cat([source_emb] * bs)

    with torch.no_grad():
        Y_st, _ = G(target, source_emb)
        # B x C x H x W -> B x H x W x C, then map [-1, 1] -> [0, 255].
        Y_st = (Y_st.permute(0, 2, 3, 1) * 0.5 + 0.5) * 255
        # Saturate out-of-range values: casting floats outside [0, 255] to
        # uint8 is undefined and in practice wraps (e.g. -1 -> 255).
        Y_st = Y_st.clamp(0, 255)
        # Keep channels 2, 1, 0 (reverses a 3-channel image); truncating cast.
        Y_st = Y_st[:, :, :, [2, 1, 0]].type(torch.uint8)
        # Already outside autograd (no_grad), so no detach is needed.
        Y_st = Y_st.cpu().numpy()
    return Y_st