Contact Us!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/train.py
Views: 804
1
# Progress marker printed before the library/project imports below are loaded.
print("started imports")
2
3
import sys
4
import argparse
5
import time
6
import cv2
7
import wandb
8
from PIL import Image
9
import os
10
11
from torch.utils.data import DataLoader
12
import torch.optim as optim
13
import torch.nn.functional as F
14
import torch
15
import torchvision.transforms as transforms
16
import torch.optim.lr_scheduler as scheduler
17
18
# custom imports
19
sys.path.append('./apex/')
20
21
from apex import amp
22
from network.AEI_Net import *
23
from network.MultiscaleDiscriminator import *
24
from utils.training.Dataset import FaceEmbedVGG2, FaceEmbed
25
from utils.training.image_processing import make_image_list, get_faceswap
26
from utils.training.losses import hinge_loss, compute_discriminator_loss, compute_generator_losses
27
from utils.training.detector import detect_landmarks, paint_eyes
28
from AdaptiveWingLoss.core import models
29
from arcface_model.iresnet import iresnet100
30
31
# Progress marker: reached only if every import above completed without raising.
print("finished imports")
32
33
34
def _save_checkpoints(G, D, run_name, epoch, iteration):
    """Save G/D state dicts as the run's "latest" weights and as a numbered snapshot."""
    torch.save(G.state_dict(), f'./saved_models_{run_name}/G_latest.pth')
    torch.save(D.state_dict(), f'./saved_models_{run_name}/D_latest.pth')

    torch.save(G.state_dict(), f'./current_models_{run_name}/G_{epoch}_{iteration:06}.pth')
    torch.save(D.state_dict(), f'./current_models_{run_name}/D_{epoch}_{iteration:06}.pth')


def _log_fixed_swaps(G, netArc, device, epoch, iteration):
    """Swap six fixed example pairs with ``G`` and log them to wandb as one image.

    The same pairs are used every time so swap quality can be compared over
    the course of training. ``G`` is put in eval mode for the preview and is
    always returned to train mode afterwards.
    """
    import numpy as np  # explicit import instead of relying on a star-import to provide `np`

    G.eval()
    try:
        results = [
            get_faceswap(f'examples/images/training//source{i}.png',
                         f'examples/images/training//target{i}.png',
                         G, netArc, device)
            for i in range(1, 7)
        ]
        # Pairs 1-3 are concatenated along axis 0, pairs 4-6 likewise, then the two
        # groups are joined along axis 1 (assumes HxWxC image arrays - TODO confirm).
        first_group = np.concatenate(results[:3], axis=0)
        second_group = np.concatenate(results[3:], axis=0)
        output = np.concatenate((first_group, second_group), axis=1)

        wandb.log({"our_images": wandb.Image(output, caption=f"{epoch:03}_{iteration:06}")})
    finally:
        G.train()


def train_one_epoch(G: 'generator model',
                    D: 'discriminator model',
                    opt_G: "generator opt",
                    opt_D: "discriminator opt",
                    scheduler_G: "scheduler G opt",
                    scheduler_D: "scheduler D opt",
                    netArc: 'ArcFace model',
                    model_ft: 'Landmark Detector',
                    args: 'Args Namespace',
                    dataloader: torch.utils.data.DataLoader,
                    device: 'torch device',
                    epoch: int,
                    loss_adv_accumulated: float) -> float:
    """Run one epoch of adversarial face-swap training.

    For every batch, the source identity embedding is computed with the frozen
    ``netArc``. The generator ``G`` is then updated on the losses returned by
    ``compute_generator_losses``. Finally the discriminator ``D`` is updated,
    except that its optimizer step is skipped while ``args.discr_force`` is set
    and ``loss_adv_accumulated`` is >= 4. Images and losses are logged, and
    checkpoints are saved, at fixed iteration intervals.

    Args:
        G, D: generator and discriminator being trained.
        opt_G, opt_D: their (apex-amp wrapped) optimizers.
        scheduler_G, scheduler_D: LR schedulers; only stepped if ``args.scheduler``.
        netArc: identity (ArcFace) network used for source/result embeddings.
        model_ft: eye-landmark detector; only used if ``args.eye_detector_loss``.
        args: parsed command-line namespace.
        dataloader: yields ``(Xs_orig, Xs, Xt, same_person)`` batches.
        device: device the batches are moved to.
        epoch: current epoch index (used for logging/checkpoint names).
        loss_adv_accumulated: running adversarial-loss value fed to and updated
            by ``compute_generator_losses``.

    Returns:
        The updated ``loss_adv_accumulated``, so the caller can carry it into
        the next epoch instead of resetting the ``discr_force`` gate.
    """
    for iteration, data in enumerate(dataloader):
        start_time = time.time()

        Xs_orig, Xs, Xt, same_person = data
        Xs_orig = Xs_orig.to(device)
        Xs = Xs.to(device)
        Xt = Xt.to(device)
        same_person = same_person.to(device)

        # Identity embeddings of the source; netArc is not trained, so no grads.
        with torch.no_grad():
            embed = netArc(F.interpolate(Xs_orig, [112, 112], mode='bilinear', align_corners=False))

        diff_person = torch.ones_like(same_person)
        if args.diff_eq_same:
            same_person = diff_person

        # ---- generator step ----
        opt_G.zero_grad()

        Y, Xt_attr = G(Xt, embed)
        Di = D(Y)
        ZY = netArc(F.interpolate(Y, [112, 112], mode='bilinear', align_corners=False))

        if args.eye_detector_loss:
            Xt_eyes, Xt_heatmap_left, Xt_heatmap_right = detect_landmarks(Xt, model_ft)
            Y_eyes, Y_heatmap_left, Y_heatmap_right = detect_landmarks(Y, model_ft)
            eye_heatmaps = [Xt_heatmap_left, Xt_heatmap_right, Y_heatmap_left, Y_heatmap_right]
        else:
            eye_heatmaps = None

        lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes = compute_generator_losses(
            G, Y, Xt, Xt_attr, Di, embed, ZY, eye_heatmaps, loss_adv_accumulated,
            diff_person, same_person, args)

        with amp.scale_loss(lossG, opt_G) as scaled_loss:
            scaled_loss.backward()
        opt_G.step()
        if args.scheduler:
            scheduler_G.step()

        # ---- discriminator step ----
        opt_D.zero_grad()
        lossD = compute_discriminator_loss(D, Y, Xs, diff_person)
        with amp.scale_loss(lossD, opt_D) as scaled_loss:
            scaled_loss.backward()

        # With discr_force, D only takes a step while the accumulated adversarial loss is low.
        if (not args.discr_force) or (loss_adv_accumulated < 4.):
            opt_D.step()
        if args.scheduler:
            scheduler_D.step()

        batch_time = time.time() - start_time

        # ---- logging ----
        if iteration % args.show_step == 0:
            images = [Xs, Xt, Y]
            if args.eye_detector_loss:
                Xt_eyes_img = paint_eyes(Xt, Xt_eyes)
                Yt_eyes_img = paint_eyes(Y, Y_eyes)
                images.extend([Xt_eyes_img, Yt_eyes_img])
            image = make_image_list(images)
            if args.use_wandb:
                wandb.log({"gen_images": wandb.Image(image, caption=f"{epoch:03}_{iteration:06}")})
            else:
                # Reverse the last axis (channel order) before writing with OpenCV.
                cv2.imwrite('./images/generated_image.jpg', image[:, :, ::-1])

        if iteration % 10 == 0:
            print(f'epoch: {epoch} {iteration} / {len(dataloader)}')
            print(f'lossD: {lossD.item()} lossG: {lossG.item()} batch_time: {batch_time}s')
            print(f'L_adv: {L_adv.item()} L_id: {L_id.item()} L_attr: {L_attr.item()} L_rec: {L_rec.item()}')
            if args.eye_detector_loss:
                print(f'L_l2_eyes: {L_l2_eyes.item()}')
            print(f'loss_adv_accumulated: {loss_adv_accumulated}')
            if args.scheduler:
                print(f'scheduler_G lr: {scheduler_G.get_last_lr()} scheduler_D lr: {scheduler_D.get_last_lr()}')

        if args.use_wandb:
            if args.eye_detector_loss:
                # commit=False: merged into the same wandb step as the log call below.
                wandb.log({"loss_eyes": L_l2_eyes.item()}, commit=False)
            wandb.log({"loss_id": L_id.item(),
                       "lossD": lossD.item(),
                       "lossG": lossG.item(),
                       "loss_adv": L_adv.item(),
                       "loss_attr": L_attr.item(),
                       "loss_rec": L_rec.item()})

        if iteration % 5000 == 0:
            _save_checkpoints(G, D, args.run_name, epoch, iteration)

        # Track swap dynamics on six fixed example pairs.
        if (iteration % 250 == 0) and args.use_wandb:
            _log_fixed_swaps(G, netArc, device, epoch, iteration)

    return loss_adv_accumulated
164
165
166
def train(args, device):
    """Build the models, optimizers and data pipeline, then train for ``args.max_epoch`` epochs.

    Creates the generator ``G`` and discriminator ``D``, the identity network
    ``netArc`` (eval mode), and optionally the eye-landmark detector. It then
    sets up Adam optimizers wrapped by apex ``amp``, optional StepLR
    schedulers, optional pretrained G/D weights and the dataloader, and finally
    calls ``train_one_epoch`` once per epoch.

    Args:
        args: parsed command-line namespace.
        device: torch device all models are moved to.
    """
    # training params
    batch_size = args.batch_size
    max_epoch = args.max_epoch

    # main models
    G = AEI_Net(args.backbone, num_blocks=args.num_blocks, c_id=512).to(device)
    D = MultiscaleDiscriminator(input_nc=3, n_layers=5, norm_layer=torch.nn.InstanceNorm2d).to(device)
    G.train()
    D.train()

    # Identity-extraction model. Load on CPU and then move to `device`, so the
    # script no longer hard-codes CUDA (main() allows a CPU fallback).
    # NOTE(review): apex amp is generally CUDA-only, so a CPU run may still fail
    # at amp.initialize - confirm.
    netArc = iresnet100(fp16=False)
    netArc.load_state_dict(torch.load('arcface_model/backbone.pth', map_location='cpu'))
    netArc = netArc.to(device)
    netArc.eval()

    if args.eye_detector_loss:
        # NOTE(review): "False" strings are truthy; kept as-is because the
        # architecture must match the WFLW_4HG checkpoint - confirm intent.
        model_ft = models.FAN(4, "False", "False", 98)
        checkpoint = torch.load('./AdaptiveWingLoss/AWL_detector/WFLW_4HG.pth', map_location='cpu')
        if 'state_dict' not in checkpoint:
            model_ft.load_state_dict(checkpoint)
        else:
            # Only copy checkpoint entries whose keys exist in this model.
            pretrained_weights = checkpoint['state_dict']
            model_weights = model_ft.state_dict()
            pretrained_weights = {k: v for k, v in pretrained_weights.items() if k in model_weights}
            model_weights.update(pretrained_weights)
            model_ft.load_state_dict(model_weights)
        model_ft = model_ft.to(device)
        model_ft.eval()
    else:
        model_ft = None

    opt_G = optim.Adam(G.parameters(), lr=args.lr_G, betas=(0, 0.999), weight_decay=1e-4)
    opt_D = optim.Adam(D.parameters(), lr=args.lr_D, betas=(0, 0.999), weight_decay=1e-4)

    G, opt_G = amp.initialize(G, opt_G, opt_level=args.optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=args.optim_level)

    if args.scheduler:
        scheduler_G = scheduler.StepLR(opt_G, step_size=args.scheduler_step, gamma=args.scheduler_gamma)
        scheduler_D = scheduler.StepLR(opt_D, step_size=args.scheduler_step, gamma=args.scheduler_gamma)
    else:
        scheduler_G = None
        scheduler_D = None

    if args.pretrained:
        try:
            G.load_state_dict(torch.load(args.G_path, map_location=torch.device('cpu')), strict=False)
            D.load_state_dict(torch.load(args.D_path, map_location=torch.device('cpu')), strict=False)
            print("Loaded pretrained weights for G and D")
        except FileNotFoundError:
            print("Not found pretrained weights. Continue without any pretrained weights.")

    if args.vgg:
        dataset = FaceEmbedVGG2(args.dataset_path, same_prob=args.same_person, same_identity=args.same_identity)
    else:
        dataset = FaceEmbed([args.dataset_path], same_prob=args.same_person)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=True)

    # Running adversarial loss. With discr_force=True the discriminator is only
    # trained while this value is below a threshold.
    loss_adv_accumulated = 20.

    for epoch in range(0, max_epoch):
        epoch_result = train_one_epoch(G,
                                       D,
                                       opt_G,
                                       opt_D,
                                       scheduler_G,
                                       scheduler_D,
                                       netArc,
                                       model_ft,
                                       args,
                                       dataloader,
                                       device,
                                       epoch,
                                       loss_adv_accumulated)
        # Keep the updated running value so the discr_force gate is not reset
        # every epoch; a None return leaves the current value unchanged.
        if epoch_result is not None:
            loss_adv_accumulated = epoch_result
245
246
def main(args):
    """Select the compute device and start training.

    Uses CUDA when available; otherwise prints a warning and uses the CPU.
    """
    has_cuda = torch.cuda.is_available()
    if not has_cuda:
        print("cuda is not available. using cpu. check if it's ok")
    device = torch.device('cuda' if has_cuda else 'cpu')

    print("Starting traing")
    train(args, device=device)
253
254
255
def str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"`` and ``"0"``). With that converter a
    flag such as ``--vgg False`` could never be switched off. This converter
    accepts the usual spellings instead.

    Args:
        value: the raw argument string; a bool is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # dataset params
    parser.add_argument('--dataset_path', default='/VggFace2-crop/', help='Path to the dataset. If not VGG2 dataset is used, param --vgg should be set False')
    parser.add_argument('--G_path', default='./saved_models/G.pth', help='Path to pretrained weights for G. Only used if pretrained=True')
    parser.add_argument('--D_path', default='./saved_models/D.pth', help='Path to pretrained weights for D. Only used if pretrained=True')
    parser.add_argument('--vgg', default=True, type=str2bool, help='When using VGG2 dataset (or any other dataset with several photos for one identity)')
    # weights for loss
    parser.add_argument('--weight_adv', default=1, type=float, help='Adversarial Loss weight')
    parser.add_argument('--weight_attr', default=10, type=float, help='Attributes weight')
    parser.add_argument('--weight_id', default=20, type=float, help='Identity Loss weight')
    parser.add_argument('--weight_rec', default=10, type=float, help='Reconstruction Loss weight')
    parser.add_argument('--weight_eyes', default=0., type=float, help='Eyes Loss weight')
    # training params you may want to change
    parser.add_argument('--backbone', default='unet', const='unet', nargs='?', choices=['unet', 'linknet', 'resnet'], help='Backbone for attribute encoder')
    parser.add_argument('--num_blocks', default=2, type=int, help='Numbers of AddBlocks at AddResblock')
    parser.add_argument('--same_person', default=0.2, type=float, help='Probability of using same person identity during training')
    parser.add_argument('--same_identity', default=True, type=str2bool, help='Using simswap approach, when source_id = target_id. Only possible with vgg=True')
    parser.add_argument('--diff_eq_same', default=False, type=str2bool, help='Don\'t use info about where is defferent identities')
    parser.add_argument('--pretrained', default=True, type=str2bool, help='If using the pretrained weights for training or not')
    parser.add_argument('--discr_force', default=False, type=str2bool, help='If True Discriminator would not train when adversarial loss is high')
    parser.add_argument('--scheduler', default=False, type=str2bool, help='If True decreasing LR is used for learning of generator and discriminator')
    parser.add_argument('--scheduler_step', default=5000, type=int)
    parser.add_argument('--scheduler_gamma', default=0.2, type=float, help='It is value, which shows how many times to decrease LR')
    parser.add_argument('--eye_detector_loss', default=False, type=str2bool, help='If True eye loss with using AdaptiveWingLoss detector is applied to generator')
    # info about this run
    parser.add_argument('--use_wandb', default=False, type=str2bool, help='Use wandb to track your experiments or not')
    parser.add_argument('--run_name', required=True, type=str, help='Name of this run. Used to create folders where to save the weights.')
    parser.add_argument('--wandb_project', default='your-project-name', type=str)
    parser.add_argument('--wandb_entity', default='your-login', type=str)
    # training params you probably don't want to change
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--lr_G', default=4e-4, type=float)
    parser.add_argument('--lr_D', default=4e-4, type=float)
    parser.add_argument('--max_epoch', default=2000, type=int)
    parser.add_argument('--show_step', default=500, type=int)
    parser.add_argument('--save_epoch', default=1, type=int)
    parser.add_argument('--optim_level', default='O2', type=str)

    args = parser.parse_args()

    if not args.vgg and args.same_identity:
        raise ValueError("Sorry, you can't use some other dataset than VGG2 Faces with param same_identity=True")

    if args.use_wandb:
        wandb.init(project=args.wandb_project, entity=args.wandb_entity, settings=wandb.Settings(start_method='fork'))

        config = wandb.config
        config.dataset_path = args.dataset_path
        config.weight_adv = args.weight_adv
        config.weight_attr = args.weight_attr
        config.weight_id = args.weight_id
        config.weight_rec = args.weight_rec
        config.weight_eyes = args.weight_eyes
        config.same_person = args.same_person
        config.Vgg2Face = args.vgg
        config.same_identity = args.same_identity
        config.diff_eq_same = args.diff_eq_same
        config.discr_force = args.discr_force
        config.scheduler = args.scheduler
        config.scheduler_step = args.scheduler_step
        config.scheduler_gamma = args.scheduler_gamma
        config.eye_detector_loss = args.eye_detector_loss
        config.pretrained = args.pretrained
        config.run_name = args.run_name
        config.G_path = args.G_path
        config.D_path = args.D_path
        config.batch_size = args.batch_size
        config.lr_G = args.lr_G
        config.lr_D = args.lr_D
    else:
        # Without wandb, preview images are written to ./images instead.
        os.makedirs('./images', exist_ok=True)

    # Create the folders for the latest model weights and for the numbered
    # weight snapshots. Each folder is created independently, so a partial
    # layout left by an earlier run is completed instead of crashing later.
    os.makedirs(f'./saved_models_{args.run_name}', exist_ok=True)
    os.makedirs(f'./current_models_{args.run_name}', exist_ok=True)

    main(args)
336
337