Path: blob/main/inference.py
import sys
import argparse
import cv2
import torch
import time
import os

from utils.inference.image_processing import crop_face, get_final_image
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement
from utils.inference.core import model_inference

from network.AEI_Net import AEI_Net
from coordinate_reg.image_infer import Handler
from insightface_func.face_detect_crop_multi import Face_detect_crop
from arcface_model.iresnet import iresnet100
from models.pix2pix_model import Pix2PixModel
from models.config_sr import TestOptions


def init_models(args):
    # model for face detection and cropping
    app = Face_detect_crop(name='antelope', root='./insightface_func/models')
    app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

    # main generator model; weights are loaded on CPU, then moved to GPU in half precision
    G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512)
    G.eval()
    G.load_state_dict(torch.load(args.G_path, map_location=torch.device('cpu')))
    G = G.cuda()
    G = G.half()

    # arcface model to get face embeddings
    netArc = iresnet100(fp16=False)
    netArc.load_state_dict(torch.load('arcface_model/backbone.pth'))
    netArc = netArc.cuda()
    netArc.eval()

    # model to get face landmarks
    handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

    # face super-resolution model; loaded only when use_sr is set
    if args.use_sr:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        torch.backends.cudnn.benchmark = True
        opt = TestOptions()
        # opt.which_epoch = '10_7'
        model = Pix2PixModel(opt)
        model.netG.train()
    else:
        model = None

    return app, G, netArc, handler, model


def main(args):
    app, G, netArc, handler, model = init_models(args)

    # get crops from source images
    print('List of source paths: ', args.source_paths)
    source = []
    try:
        for source_path in args.source_paths:
            img = cv2.imread(source_path)
            img = crop_face(img, app, args.crop_size)[0]
            source.append(img[:, :, ::-1])  # BGR -> RGB
    except TypeError:
        print("Bad source images!")
        sys.exit()

    # get full frames from the target video (or a single frame for image-to-image)
    if not args.image_to_image:
        full_frames, fps = read_video(args.target_video)
    else:
        target_full = cv2.imread(args.target_image)
        full_frames = [target_full]

    # get the target faces used for the swap
    set_target = True
    print('List of target paths: ', args.target_faces_paths)
    if not args.target_faces_paths:
        target = get_target(full_frames, app, args.crop_size)
        set_target = False
    else:
        target = []
        try:
            for target_faces_path in args.target_faces_paths:
                img = cv2.imread(target_faces_path)
                img = crop_face(img, app, args.crop_size)[0]
                target.append(img)
        except TypeError:
            print("Bad target images!")
            sys.exit()

    start = time.time()
    final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(full_frames,
                                                                                       source,
                                                                                       target,
                                                                                       netArc,
                                                                                       G,
                                                                                       app,
                                                                                       set_target,
                                                                                       similarity_th=args.similarity_th,
                                                                                       crop_size=args.crop_size,
                                                                                       BS=args.batch_size)
    if args.use_sr:
        final_frames_list = face_enhancement(final_frames_list, model)

    if not args.image_to_image:
        get_final_video(final_frames_list,
                        crop_frames_list,
                        full_frames,
                        tfm_array_list,
                        args.out_video_name,
                        fps,
                        handler)

        add_audio_from_another_video(args.target_video, args.out_video_name, "audio")
        print(f"Video saved with path {args.out_video_name}")
    else:
        result = get_final_image(final_frames_list, crop_frames_list, full_frames[0], tfm_array_list, handler)
        cv2.imwrite(args.out_image_name, result)
        print(f'Swapped Image saved with path {args.out_image_name}')

    print('Total time: ', time.time() - start)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Generator params
    parser.add_argument('--G_path', default='weights/G_unet_2blocks.pth', type=str, help='Path to weights for G')
    parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
    parser.add_argument('--num_blocks', default=2, type=int, help='Number of AddBlocks in AddResblock')

    parser.add_argument('--batch_size', default=40, type=int)
    parser.add_argument('--crop_size', default=224, type=int, help="Don't change this")
    parser.add_argument('--use_sr', default=False, type=bool, help='True for super resolution on swap images')
    parser.add_argument('--similarity_th', default=0.15, type=float, help='Threshold for selecting a face similar to the target')

    parser.add_argument('--source_paths', default=['examples/images/mark.jpg', 'examples/images/elon_musk.jpg'], nargs='+')
    parser.add_argument('--target_faces_paths', default=[], nargs='+', help="Faces in the video onto which the source faces are swapped. If this parameter is skipped, any face found in the target video is selected for the swap.")

    # parameters for image-to-video swap
    parser.add_argument('--target_video', default='examples/videos/nggyup.mp4', type=str, help="Used for image-to-video swap")
    parser.add_argument('--out_video_name', default='examples/results/result.mp4', type=str, help="Used for image-to-video swap")

    # parameters for image-to-image swap
    parser.add_argument('--image_to_image', default=False, type=bool, help='True for image-to-image swap, False for swap on video')
    parser.add_argument('--target_image', default='examples/images/beckham.jpg', type=str, help="Used for image-to-image swap")
    parser.add_argument('--out_image_name', default='examples/results/result.png', type=str, help="Used for image-to-image swap")

    args = parser.parse_args()
    main(args)
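One caveat when invoking the script: --use_sr and --image_to_image are declared with type=bool, and bool() applied to any non-empty command-line string (including "False") returns True, so passing --image_to_image False actually enables image-to-image mode; only omitting the flag keeps the default. Below is a minimal sketch of an explicit converter that could replace type=bool; the str2bool helper is an illustration only, not something this file defines.

import argparse

def str2bool(v):
    # argparse hands the raw command-line token to the `type` callable;
    # map the usual spellings explicitly instead of relying on bool(),
    # which treats every non-empty string as True.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}')

# e.g. parser.add_argument('--use_sr', default=False, type=str2bool,
#                          help='True for super resolution on swap images')

With the defaults above, a typical video swap looks like: python inference.py --source_paths examples/images/elon_musk.jpg --target_video examples/videos/nggyup.mp4 --out_video_name examples/results/result.mp4. Since --source_paths takes nargs='+', several source images can be passed to swap multiple faces in one run.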