CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/inference.py
Views: 792
1
import sys
2
import argparse
3
import cv2
4
import torch
5
import time
6
import os
7
8
from utils.inference.image_processing import crop_face, get_final_image
9
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement
10
from utils.inference.core import model_inference
11
12
from network.AEI_Net import AEI_Net
13
from coordinate_reg.image_infer import Handler
14
from insightface_func.face_detect_crop_multi import Face_detect_crop
15
from arcface_model.iresnet import iresnet100
16
from models.pix2pix_model import Pix2PixModel
17
from models.config_sr import TestOptions
18
19
20
def init_models(args):
    """Instantiate every network used by the face-swap pipeline.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``backbone``, ``num_blocks``, ``G_path`` and ``use_sr``.

    Returns:
        A tuple ``(app, G, netArc, handler, model)`` holding the face
        detector/cropper, the generator, the ArcFace embedding network, the
        landmark handler and the super-resolution model. The last element is
        ``None`` when ``args.use_sr`` is falsy.
    """
    # Face detection model, used for cropping faces out of frames.
    face_app = Face_detect_crop(name='antelope', root='./insightface_func/models')
    face_app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

    # Main generation model. The checkpoint is deserialized onto the CPU and
    # only afterwards is the network moved to the GPU and cast to half precision.
    generator = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512)
    generator.eval()
    generator_state = torch.load(args.G_path, map_location=torch.device('cpu'))
    generator.load_state_dict(generator_state)
    generator = generator.cuda().half()

    # ArcFace model that produces the face embedding.
    arcface = iresnet100(fp16=False)
    arcface.load_state_dict(torch.load('arcface_model/backbone.pth'))
    arcface = arcface.cuda()
    arcface.eval()

    # Model that predicts face landmarks.
    landmark_handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

    # Optional face super-resolution model; stays None unless use_sr is set.
    sr_model = None
    if args.use_sr:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        torch.backends.cudnn.benchmark = True
        sr_model = Pix2PixModel(TestOptions())
        # NOTE(review): the SR generator is deliberately left in train mode,
        # exactly as before - confirm this is intended for inference.
        sr_model.netG.train()

    return face_app, generator, arcface, landmark_handler, sr_model
53
54
55
def _load_face_crops(paths, app, crop_size, error_message, reverse_channels=False):
    """Read every image in *paths* and return the first face crop of each.

    Args:
        paths: iterable of image file paths.
        app: face detector passed through to ``crop_face``.
        crop_size: crop size passed through to ``crop_face``.
        error_message: text printed when ``crop_face`` raises ``TypeError``.
        reverse_channels: when true, the last axis of each crop is reversed
            (presumably BGR -> RGB, since ``cv2.imread`` yields BGR - confirm
            against what ``model_inference`` expects).

    Returns:
        A list with one crop per input path.

    Exits the process with status 1 if an image cannot be read or if
    cropping raises ``TypeError``.
    """
    crops = []
    for path in paths:
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            print(f"Could not read image: {path}")
            sys.exit(1)
        try:
            crop = crop_face(image, app, crop_size)[0]
            if reverse_channels:
                crop = crop[:, :, ::-1]
        except TypeError:
            print(error_message)
            sys.exit(1)
        crops.append(crop)
    return crops


def main(args):
    """Run the face swap described by *args* and write the result to disk.

    Depending on ``args.image_to_image`` the target is either a single image
    (``args.target_image`` -> ``args.out_image_name``) or a video
    (``args.target_video`` -> ``args.out_video_name``, with the audio track
    copied from the target video).

    Args:
        args: parsed command-line namespace. Besides the attributes consumed
            by ``init_models`` this reads ``source_paths``,
            ``target_faces_paths``, ``crop_size``, ``similarity_th``,
            ``batch_size``, ``use_sr``, ``image_to_image``, ``target_video``,
            ``target_image``, ``out_video_name`` and ``out_image_name``.

    Exits with status 1 when a source/target image is unreadable or cannot
    be cropped.
    """
    app, G, netArc, handler, model = init_models(args)

    # get crops from source images
    print('List of source paths: ',args.source_paths)
    source = _load_face_crops(args.source_paths, app, args.crop_size,
                              "Bad source images!", reverse_channels=True)

    # get full frames from video (or wrap the single target image in a list)
    if not args.image_to_image:
        full_frames, fps = read_video(args.target_video)
    else:
        target_full = cv2.imread(args.target_image)
        if target_full is None:
            print(f"Could not read image: {args.target_image}")
            sys.exit(1)
        full_frames = [target_full]

    # get target faces that are used for swap
    print('List of target paths: ', args.target_faces_paths)
    if args.target_faces_paths:
        set_target = True
        target = _load_face_crops(args.target_faces_paths, app, args.crop_size,
                                  "Bad target images!")
    else:
        # No explicit target faces given: let get_target pick from the frames.
        set_target = False
        target = get_target(full_frames, app, args.crop_size)

    # The reported total time covers inference and writing the output only;
    # model initialisation and input loading above are not included.
    start = time.time()
    final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(
        full_frames,
        source,
        target,
        netArc,
        G,
        app,
        set_target,
        similarity_th=args.similarity_th,
        crop_size=args.crop_size,
        BS=args.batch_size)

    if args.use_sr:
        final_frames_list = face_enhancement(final_frames_list, model)

    if not args.image_to_image:
        get_final_video(final_frames_list,
                        crop_frames_list,
                        full_frames,
                        tfm_array_list,
                        args.out_video_name,
                        fps,
                        handler)

        add_audio_from_another_video(args.target_video, args.out_video_name, "audio")
        print(f"Video saved with path {args.out_video_name}")
    else:
        result = get_final_image(final_frames_list, crop_frames_list, full_frames[0], tfm_array_list, handler)
        cv2.imwrite(args.out_image_name, result)
        print(f'Swapped Image saved with path {args.out_image_name}')

    print('Total time: ', time.time()-start)
125
126
127
def str2bool(value):
    """Convert a command-line token to a real ``bool``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because any non-empty string is truthy, so a flag could never be
    switched off from the command line.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"True"``, ``"false"``, ``"1"``, ``"no"`` (case-insensitive).
            The empty string maps to ``False``, as ``bool("")`` did.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def build_parser():
    """Create the command-line parser for the face-swap inference script."""
    parser = argparse.ArgumentParser()

    # Generator params
    parser.add_argument('--G_path', default='weights/G_unet_2blocks.pth', type=str, help='Path to weights for G')
    parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
    parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')

    parser.add_argument('--batch_size', default=40, type=int)
    parser.add_argument('--crop_size', default=224, type=int, help="Don't change this")
    parser.add_argument('--use_sr', default=False, type=str2bool, help='True for super resolution on swap images')
    parser.add_argument('--similarity_th', default=0.15, type=float, help='Threshold for selecting a face similar to the target')

    parser.add_argument('--source_paths', default=['examples/images/mark.jpg', 'examples/images/elon_musk.jpg'], nargs='+')
    parser.add_argument('--target_faces_paths', default=[], nargs='+', help="It's necessary to set the face/faces in the video to which the source face/faces is swapped. You can skip this parametr, and then any face is selected in the target video for swap.")

    # parameters for image to video
    parser.add_argument('--target_video', default='examples/videos/nggyup.mp4', type=str, help="It's necessary for image to video swap")
    parser.add_argument('--out_video_name', default='examples/results/result.mp4', type=str, help="It's necessary for image to video swap")

    # parameters for image to image
    parser.add_argument('--image_to_image', default=False, type=str2bool, help='True for image to image swap, False for swap on video')
    parser.add_argument('--target_image', default='examples/images/beckham.jpg', type=str, help="It's necessary for image to image swap")
    parser.add_argument('--out_image_name', default='examples/results/result.png', type=str,help="It's necessary for image to image swap")

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)
154