Real-time collaboration for Jupyter notebooks, Linux terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Real-time collaboration for Jupyter notebooks, Linux terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/apex/tests/L1/common/compare.py
Views: 794
"""Compare per-iteration training logs saved by two (optionally three) runs.

The dicts are loaded with ``torch.load`` from these files:

* ``True_<base>``
* ``False_<base>``
* ``baselines/True_<base>`` (only with ``--use_baseline``)

``<base>`` is ``<opt_level>_<loss_scale>_<keep_batchnorm_fp32>_<fused_adam>``,
built from the command-line flags. Each dict is expected to hold parallel,
indexable sequences under the keys ``"Iteration"``, ``"Loss"`` and ``"Speed"``.

For every iteration:

* the iteration numbers of the two runs must agree;
* the losses must be exactly equal (``==``), including against the baseline
  when one is given;
* both runs must have logged the same number of iterations.

Any violation raises AssertionError. One formatted row of losses and speeds is
printed per iteration.

NOTE(review): what the "True_"/"False_" prefixes denote is decided by the
script that writes these files, not here -- confirm there.
"""
import argparse

import torch


def parse_args(argv=None):
    """Parse command-line options (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(description='Compare')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--fused-adam', action='store_true')
    parser.add_argument('--use_baseline', action='store_true')
    return parser.parse_args(argv)


def build_file_names(args):
    """Return ``(file_e, file_p, file_b)`` for the parsed *args*.

    ``file_b`` is None unless ``args.use_baseline`` is set.
    """
    # str() every field, so unset options and flags appear literally as
    # "None" / "True" / "False" in the name.
    base_file = "_".join(str(field) for field in (
        args.opt_level,
        args.loss_scale,
        args.keep_batchnorm_fp32,
        args.fused_adam,
    ))
    file_e = "True_" + base_file
    file_p = "False_" + base_file
    file_b = "baselines/True_" + base_file if args.use_baseline else None
    return file_e, file_p, file_b


def compare_logs(dict_e, dict_p, dict_b=None):
    """Check two (or three) log dicts and yield one formatted row per iteration.

    Args:
        dict_e, dict_p: log dicts with "Iteration", "Loss" and "Speed"
            sequences.
        dict_b: optional baseline dict with "Loss" and "Speed" sequences,
            indexed in step with ``dict_e``.

    Yields:
        Without a baseline: ``iteration loss_e loss_p speed_e speed_p``.
        With a baseline:
        ``iteration loss_b loss_e loss_p speed_b speed_e speed_p``.

    Raises:
        AssertionError: on an iteration-number or loss mismatch, or when the
            two runs logged different numbers of iterations. Explicit raises
            are used instead of ``assert`` so the checks survive ``python -O``.
    """
    iters_e = dict_e["Iteration"]
    iters_p = dict_p["Iteration"]
    for n, (i_e, i_p) in enumerate(zip(iters_e, iters_p)):
        if i_e != i_p:
            raise AssertionError("i_e = {}, i_p = {}".format(i_e, i_p))

        loss_e = dict_e["Loss"][n]
        loss_p = dict_p["Loss"][n]
        if loss_e != loss_p:
            raise AssertionError("Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p))

        if dict_b is None:
            values = (loss_e, loss_p, dict_e["Speed"][n], dict_p["Speed"][n])
        else:
            loss_b = dict_b["Loss"][n]
            if loss_e != loss_b:
                raise AssertionError("Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b))
            values = (loss_b, loss_e, loss_p,
                      dict_b["Speed"][n], dict_e["Speed"][n], dict_p["Speed"][n])

        # Same layout as the original two format strings: a width-4
        # iteration column, then one 15.10f column per value.
        yield ("{:4}" + " {:15.10f}" * len(values)).format(i_e, *values)

    # zip() stops at the shorter sequence. Without this check, a run that
    # stopped early would silently "match" on the common prefix.
    if len(iters_e) != len(iters_p):
        raise AssertionError("Iteration count mismatch: len_e = {}, len_p = {}".format(
            len(iters_e), len(iters_p)))


def main(argv=None):
    """Load the log files named by the command line, then print and check them."""
    args = parse_args(argv)
    file_e, file_p, file_b = build_file_names(args)

    dict_e = torch.load(file_e)
    dict_p = torch.load(file_p)
    dict_b = torch.load(file_b) if file_b is not None else None

    # Applies to any tensors printed below, including in mismatch messages.
    torch.set_printoptions(precision=10)

    print(file_e)
    print(file_p)
    if file_b is not None:
        print(file_b)

    # Print row by row, so iterations before a mismatch are still shown.
    for row in compare_logs(dict_e, dict_p, dict_b):
        print(row)


if __name__ == "__main__":
    main()