Real-time collaboration for Jupyter Notebooks, Linux terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/utils/training/losses.py
Views: 813
import torch

# Shared criteria. l1_loss is not referenced in this module; it is kept because
# other modules may import it — NOTE(review): confirm before removing.
l1_loss = torch.nn.L1Loss()
l2_loss = torch.nn.MSELoss()


def hinge_loss(X, positive=True):
    """Element-wise hinge loss on discriminator scores.

    Args:
        X: tensor of discriminator scores.
        positive: if true, return ``relu(1 - X)`` (penalises scores below +1);
            otherwise return ``relu(X + 1)`` (penalises scores above -1).

    Returns:
        A tensor with the same shape as ``X``.
    """
    if positive:
        return torch.relu(1 - X)
    return torch.relu(X + 1)


def _masked_mean(values, mask, eps):
    """Return ``sum(values * mask) / (sum(mask) + eps)``, a mask-weighted mean."""
    return torch.sum(values * mask) / (mask.sum() + eps)


def _per_sample_mse(a, b):
    """Return the squared error of ``a - b`` averaged over every dim except dim 0.

    The leading size is read from the tensor itself rather than from a configured
    batch size, so a final partial batch is handled correctly.
    """
    sq = torch.pow(a - b, 2)
    return torch.mean(sq.reshape(sq.shape[0], -1), dim=1)


def compute_generator_losses(G, Y, Xt, Xt_attr, Di, embed, ZY, eye_heatmaps, loss_adv_accumulated,
                             diff_person, same_person, args):
    """Compute the weighted generator loss and each of its terms.

    Args:
        G: generator. ``G.get_attr(Y)`` must return a sequence indexable in
            parallel with ``Xt_attr``.
        Y: generated images.
        Xt: target images, compared with ``Y`` by the reconstruction term.
        Xt_attr: sequence of attribute tensors for ``Xt``.
        Di: discriminator outputs for ``Y``. Each ``di[0]`` must have at least
            4 dims because it is averaged over dims 1, 2 and 3.
        embed, ZY: embeddings compared by cosine similarity along dim 1
            (presumably identity embeddings — confirm against caller).
        eye_heatmaps: 4-tuple ``(Xt_left, Xt_right, Y_left, Y_right)``. It is
            only read when ``args.eye_detector_loss`` is true.
        loss_adv_accumulated: running value, updated as ``0.98 * old + 0.02 * L_adv``.
        diff_person: per-sample weights for the adversarial term.
        same_person: per-sample weights for the reconstruction term.
        args: reads ``optim_level``, ``eye_detector_loss`` and
            ``weight_adv`` / ``weight_attr`` / ``weight_id`` / ``weight_rec`` /
            ``weight_eyes``. Per-sample sizes come from the tensors, not
            ``args.batch_size``.

    Returns:
        ``(lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes)``.
    """
    # adversarial loss: hinge per discriminator output, summed over outputs,
    # then averaged over the samples selected by diff_person
    L_adv = 0.
    for di in Di:
        L_adv += hinge_loss(di[0], True).mean(dim=[1, 2, 3])
    L_adv = _masked_mean(L_adv, diff_person, 1e-4)

    # id loss
    L_id = (1 - torch.cosine_similarity(embed, ZY, dim=1)).mean()

    # attr loss; for "O2"/"O3" (presumably Apex AMP opt levels) G expects half input
    if args.optim_level in ("O2", "O3"):
        Y_attr = G.get_attr(Y.type(torch.half))
    else:
        Y_attr = G.get_attr(Y)

    L_attr = 0
    for i, xt_attr in enumerate(Xt_attr):
        L_attr += _per_sample_mse(xt_attr, Y_attr[i]).mean()
    L_attr /= 2.0

    # reconstruction loss, averaged over the samples selected by same_person
    L_rec = _masked_mean(0.5 * _per_sample_mse(Y, Xt), same_person, 1e-6)

    # l2 eyes loss
    if args.eye_detector_loss:
        Xt_heatmap_left, Xt_heatmap_right, Y_heatmap_left, Y_heatmap_right = eye_heatmaps
        L_l2_eyes = l2_loss(Xt_heatmap_left, Y_heatmap_left) + l2_loss(Xt_heatmap_right, Y_heatmap_right)
    else:
        L_l2_eyes = 0

    # final loss of generator
    lossG = (args.weight_adv * L_adv + args.weight_attr * L_attr + args.weight_id * L_id
             + args.weight_rec * L_rec + args.weight_eyes * L_l2_eyes)
    loss_adv_accumulated = loss_adv_accumulated * 0.98 + L_adv.item() * 0.02

    return lossG, loss_adv_accumulated, L_adv, L_attr, L_id, L_rec, L_l2_eyes


def compute_discriminator_loss(D, Y, Xs, diff_person):
    """Compute the hinge discriminator loss.

    ``D`` is called first on the detached generated images ``Y`` (scored as
    fake) and then on ``Xs`` (scored as real). For each returned ``di``,
    ``di[0]`` is averaged over dims 1, 2 and 3, then over the samples selected
    by ``diff_person``. The per-output values are summed.

    Returns:
        ``0.5 * (loss_true + loss_fake)``.
    """
    # fake part
    fake_D = D(Y.detach())
    loss_fake = 0
    for di in fake_D:
        loss_fake += _masked_mean(hinge_loss(di[0], False).mean(dim=[1, 2, 3]), diff_person, 1e-4)

    # ground truth part
    true_D = D(Xs)
    loss_true = 0
    for di in true_D:
        loss_true += _masked_mean(hinge_loss(di[0], True).mean(dim=[1, 2, 3]), diff_person, 1e-4)

    lossD = 0.5 * (loss_true.mean() + loss_fake.mean())

    return lossD