CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/distributed/amp_master_params/compare.py
Views: 794
import torch


def _onto_gpu0(storage, loc):
    """``torch.load`` map_location hook: move every storage onto CUDA device 0.

    ``loc`` (the device tag recorded when the file was saved) is ignored, so
    all four snapshots below end up on the same GPU regardless of which rank
    wrote them.
    """
    return storage.cuda(0)


# Parameter snapshots written by rank 0 and rank 1. The "model" files hold the
# model's parameters and the "master" files the corresponding master copies
# (presumably amp's higher-precision master weights -- confirm against the
# script that writes these files).
model_params_rank0 = torch.load("rank0model.pth", map_location=_onto_gpu0)
model_params_rank1 = torch.load("rank1model.pth", map_location=_onto_gpu0)
master_params_rank0 = torch.load("rank0master.pth", map_location=_onto_gpu0)
master_params_rank1 = torch.load("rank1master.pth", map_location=_onto_gpu0)
# Verify that:
#   1. both ranks hold identical model params,
#   2. both ranks hold identical master params, and
#   3. each model param matches its master copy rounded to half precision.
# Assumes each loaded snapshot is a sequence of tensors in matching order.

# zip() silently stops at the shortest input, so a rank that saved fewer
# params (or a model/master length mismatch) would otherwise go unchecked
# and the script would still report success.
_param_counts = {
    "rank0 model": len(model_params_rank0),
    "rank1 model": len(model_params_rank1),
    "rank0 master": len(master_params_rank0),
    "rank1 master": len(master_params_rank1),
}
assert len(set(_param_counts.values())) == 1, \
    "Param count mismatch: {}".format(_param_counts)

for model_rank0, model_rank1, master_rank0, master_rank1 in zip(
        model_params_rank0,
        model_params_rank1,
        master_params_rank0,
        master_params_rank1):
    assert torch.allclose(model_rank0, model_rank1), "Model param mismatch"
    assert torch.allclose(master_rank0, master_rank1), "Master param mismatch"
    # Some debugging/investigation assistance code:
    # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0)
    # offending_val_half = model_rank0.view(-1)[maxind.item()]
    # offending_val_float = master_rank0.view(-1)[maxind.item()]
    # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(),
    #       offending_val_float.half().item())
    # rtol must be well above fp16's unit roundoff (2**-11): subnormal half
    # values carry fewer significant bits, so rounding a master value to half
    # can produce a larger relative error.
    assert torch.allclose(model_rank0, master_rank0.half(), rtol=.005), "Model-master mismatch"

print("OK: Model and master params match across ranks.")