import torch
import numpy as np
def faceshifter_batch(source_emb: torch.Tensor,
                      target: torch.Tensor,
                      G: torch.nn.Module) -> np.ndarray:
    """
    Apply the FaceShifter generator to a batch of target images.

    Args:
        source_emb: source identity embedding. When it is not already a
            per-sample batch (ndim >= 2 with leading dim equal to the target
            batch size), it is concatenated ``B`` times along dim 0 so every
            target image is paired with the same identity.
        target: target images, shape B x C x H x W. Presumably normalized to
            [-1, 1] to match the output de-normalization below -- TODO confirm.
        G: generator called as ``G(target, source_emb)``; only the first
            element of the returned pair is used.

    Returns:
        uint8 array of shape B x H x W x 3, channels-last, with the first
        three output channels in reversed order (presumably RGB -> BGR for
        OpenCV -- confirm against callers).

    Raises:
        ValueError: if ``target`` is not 4-dimensional.
    """
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    # Done before reading target.shape[0] so a 0-dim tensor gets this message.
    if target.ndim != 4:
        raise ValueError(
            "target should have 4 dimensions -- B x C x H x W, "
            f"got shape {tuple(target.shape)}"
        )
    bs = target.shape[0]

    # Broadcast a single identity across the batch. Embeddings that already
    # carry one row per target image are used as-is (concatenating them would
    # produce bs * bs rows).
    per_sample = source_emb.ndim >= 2 and source_emb.shape[0] == bs
    if bs > 1 and not per_sample:
        source_emb = torch.cat([source_emb] * bs)

    with torch.no_grad():
        Y_st, _ = G(target, source_emb)

    # B x C x H x W in [-1, 1]  ->  B x H x W x C in [0, 255].
    Y_st = (Y_st.permute(0, 2, 3, 1) * 0.5 + 0.5) * 255
    # Saturate before casting: float -> uint8 conversion of out-of-range
    # values wraps (or is undefined) rather than clamping, so a small
    # overshoot past +/-1 would flip bright pixels dark and vice versa.
    Y_st = Y_st.clamp(0, 255)
    # Reverse the first three channels, then cast (truncation, as before).
    Y_st = Y_st[:, :, :, [2, 1, 0]].type(torch.uint8)
    return Y_st.detach().cpu().numpy()