Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/training/losses.py
Views: 813
1
import torch
2
3
# Shared criterion instances. Both use PyTorch's default reduction="mean",
# so each call returns a single averaged value.
# NOTE(review): l1_loss is not referenced by the functions in this module;
# presumably it is imported elsewhere — verify before removing.
l1_loss, l2_loss = torch.nn.L1Loss(), torch.nn.MSELoss()


def hinge_loss(X, positive=True):
    """Element-wise hinge penalty used by the adversarial terms.

    Args:
        X: tensor of scores (presumably raw discriminator outputs).
        positive: True -> penalize values below +1, i.e. ``relu(1 - X)``;
            False -> penalize values above -1, i.e. ``relu(X + 1)``.

    Returns:
        Tensor with the same shape as ``X``; no reduction is applied.
    """
    margin = 1 - X if positive else X + 1
    return torch.relu(margin)


def _per_sample_mse(a, b):
    """Return mean((a - b) ** 2) over all non-batch dims: one value per sample.

    The batch size is taken from dim 0 of the tensor itself (not from a config
    value), so a short final batch is flattened per sample correctly.
    """
    sq_err = torch.pow(a - b, 2)
    return torch.mean(sq_err.reshape(sq_err.shape[0], -1), dim=1)


def compute_generator_losses(G, Y, Xt, Xt_attr, Di, embed, ZY, eye_heatmaps, loss_adv_accumulated,
                             diff_person, same_person, args):
    """Compute the weighted generator objective and each of its terms.

    Args:
        G: generator; only ``G.get_attr`` is used here, to extract attribute
            features of the generated image ``Y``.
        Y: generated images; dim 0 is treated as the batch dimension.
        Xt: target images, compared with ``Y`` in the reconstruction loss.
        Xt_attr: sequence of target attribute feature tensors; entry ``i`` is
            compared with entry ``i`` of ``G.get_attr(Y)``.
        Di: iterable of discriminator outputs for ``Y``; ``di[0]`` of each is
            averaged over dims 1-3, so presumably (B, C, H, W) — confirm.
        embed, ZY: identity embeddings compared by cosine similarity along
            dim 1 (presumably result vs. source embeddings — verify at caller).
        eye_heatmaps: 4-tuple ``(Xt_left, Xt_right, Y_left, Y_right)``; only
            read when ``args.eye_detector_loss`` is truthy.
        loss_adv_accumulated: running average of the adversarial loss.
        diff_person: per-sample weights for the adversarial term.
        same_person: per-sample weights for the reconstruction term
            (presumably 0/1 masks — confirm).
        args: config providing ``optim_level``, ``eye_detector_loss`` and the
            ``weight_adv``/``weight_attr``/``weight_id``/``weight_rec``/
            ``weight_eyes`` coefficients. ``args.batch_size`` is no longer
            read; the batch size comes from the tensors themselves.

    Returns:
        ``(lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes)``.
    """
    # adversarial loss: hinge "real" target on di[0] of every discriminator
    # output, averaged per sample, then a diff_person-weighted mean
    L_adv = 0.
    for di in Di:
        L_adv += hinge_loss(di[0], True).mean(dim=[1, 2, 3])
    L_adv = torch.sum(L_adv * diff_person) / (diff_person.sum() + 1e-4)

    # id loss: 1 - cosine similarity, averaged over the batch
    L_id = (1 - torch.cosine_similarity(embed, ZY, dim=1)).mean()

    # attr loss. NOTE(review): "O2"/"O3" look like apex AMP levels; the
    # generator presumably expects half-precision input there — confirm.
    if args.optim_level in ("O2", "O3"):
        Y_attr = G.get_attr(Y.type(torch.half))
    else:
        Y_attr = G.get_attr(Y)

    L_attr = 0
    for i, xt_feat in enumerate(Xt_attr):
        # indexing (not zip) keeps the original IndexError if Y_attr is shorter
        L_attr += _per_sample_mse(xt_feat, Y_attr[i]).mean()
    L_attr /= 2.0

    # reconstruction loss: 0.5 * per-sample MSE, same_person-weighted mean
    L_rec = torch.sum(0.5 * _per_sample_mse(Y, Xt) * same_person) / (same_person.sum() + 1e-6)

    # l2 eyes loss
    if args.eye_detector_loss:
        Xt_heatmap_left, Xt_heatmap_right, Y_heatmap_left, Y_heatmap_right = eye_heatmaps
        L_l2_eyes = l2_loss(Xt_heatmap_left, Y_heatmap_left) + l2_loss(Xt_heatmap_right, Y_heatmap_right)
    else:
        L_l2_eyes = 0

    # final loss of generator
    lossG = (args.weight_adv * L_adv + args.weight_attr * L_attr + args.weight_id * L_id
             + args.weight_rec * L_rec + args.weight_eyes * L_l2_eyes)
    # exponential moving average (decay 0.98); .item() detaches it from the graph
    loss_adv_accumulated = loss_adv_accumulated * 0.98 + L_adv.item() * 0.02

    return lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes


def compute_discriminator_loss(D, Y, Xs, diff_person):
    """Hinge loss for the discriminator.

    Args:
        D: discriminator; ``D(images)`` must return an iterable whose items are
            indexed with ``[0]``. Each ``item[0]`` is averaged over dims 1-3,
            so it is presumably a (B, C, H, W) score map — confirm.
        Y: generated images; detached so no gradient flows into the generator.
        Xs: images scored with the "real" (positive) hinge target.
        diff_person: per-sample weights; both terms are diff_person-weighted
            means with a 1e-4 stabilizer in the denominator.

    Returns:
        ``0.5 * (loss_true + loss_fake)``, where each term is summed over the
        discriminator's outputs.
    """
    denom = diff_person.sum() + 1e-4

    def _weighted_hinge(outputs, positive):
        # sum over discriminator outputs of a diff_person-weighted mean hinge
        total = 0
        for out in outputs:
            per_sample = hinge_loss(out[0], positive).mean(dim=[1, 2, 3])
            total += torch.sum(per_sample * diff_person) / denom
        return total

    # fake part is evaluated before the ground-truth part, as D may be stateful
    loss_fake = _weighted_hinge(D(Y.detach()), False)
    loss_true = _weighted_hinge(D(Xs), True)

    # both terms are already scalars; .mean() is kept for parity with the original
    return 0.5 * (loss_true.mean() + loss_fake.mean())