CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L1/common/compare.py
Views: 794
1
import argparse
2
import torch
3
4
# Compare the per-iteration records saved by two runs (and, optionally, a
# stored baseline) and fail on the first iteration whose values differ.
parser = argparse.ArgumentParser(description='Compare')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--fused-adam', action='store_true')
parser.add_argument('--use_baseline', action='store_true')
args = parser.parse_args()

# The option values are encoded in the file names; unset options appear as
# the literal string "None".
base_file = str(args.opt_level) + "_" +\
            str(args.loss_scale) + "_" +\
            str(args.keep_batchnorm_fp32) + "_" +\
            str(args.fused_adam)

# NOTE(review): the "True_"/"False_" prefixes presumably identify the two run
# configurations being compared (suffixes _e and _p) -- confirm against the
# script that writes these files.
file_e = "True_" + base_file
file_p = "False_" + base_file
if args.use_baseline:
    file_b = "baselines/True_" + base_file

dict_e = torch.load(file_e)
dict_p = torch.load(file_p)
if args.use_baseline:
    dict_b = torch.load(file_b)

torch.set_printoptions(precision=10)

print(file_e)
print(file_p)
if args.use_baseline:
    print(file_b)

# One integer column for the iteration, then one float column per loss and
# per speed value: 2 + 2 without a baseline, 3 + 3 with one.
row_format = "{:4}" + " {:15.10f}" * (6 if args.use_baseline else 4)

# A single loop serves both modes; the baseline checks and columns are only
# added when --use_baseline is given.
# NOTE: zip() stops at the shorter "Iteration" list, so extra trailing
# iterations in one of the runs are not compared.
for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])):
    assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p)

    loss_e = dict_e["Loss"][n]
    loss_p = dict_p["Loss"][n]
    if args.use_baseline:
        loss_b = dict_b["Loss"][n]

    # Losses are compared for exact equality; no tolerance is applied.
    assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p)

    if args.use_baseline:
        assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b)
        # Baseline values are printed first within each column group.
        losses = (loss_b, loss_e, loss_p)
        speeds = (dict_b["Speed"][n], dict_e["Speed"][n], dict_p["Speed"][n])
    else:
        losses = (loss_e, loss_p)
        speeds = (dict_e["Speed"][n], dict_p["Speed"][n])

    print(row_format.format(i_e, *(losses + speeds)))
65
66