Path: apex/tests/L0/run_optimizers/test_fused_optimizer.py
import unittest
import os
import random

import math
import torch
import apex
from itertools import product
from torch.optim import Optimizer


class TestFusedOptimizer(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, options, tst_options=None):

        # Adding this to make backward compatible with existing tests. Just in
        # case "tst_options" are not provided, it gets a copy of options
        # which contains the parameters for the reference optimizer
        if tst_options is None:
            tst_options = options

        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = self.ref_optim(ref_param, **options)
        tst_optim = self.fused_optim(tst_param, **tst_options)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, p_tst in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float, device='cuda'):
        nelem = 278011

        # Some ref and test optimizers may require different set of options.
        # This is a quick workaround to add that functionality while making
        # minimum changes in existing code.
        # If there is no "tst_options" field provided, safe to initialize
        # the test optimizer with the parameters of reference optimizer.
        if not hasattr(self, 'tst_options'):
            self.tst_options = self.options

        tensor = torch.rand(nelem, dtype=param_type, device=device)

        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options, self.tst_options)

        for i in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)


class TestFusedAdam(TestFusedOptimizer):

    def __init__(self, *args, **kwargs):
        super(TestFusedAdam, self).__init__(*args, **kwargs)
        self.options = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08,
                        'weight_decay': 0, 'amsgrad': False}
        self.ref_optim = torch.optim.Adam
        self.fused_optim = apex.optimizers.FusedAdam

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]

        tensors = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim(tensors, self.options)

        for i in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skip('No longer support fuse scaling')
    def test_scale(self):
        nelem = 278011
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options)

        for i in range(self.iters):
            scale = random.random() * 1000
            half_grads = self.gen_mixed_grad(ref_param, tst_param, scale)
            ref_optim.step()
            tst_optim.step(grads=half_grads, scale=scale)
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skip('No longer support output fp16 param')
    def test_fp16_output(self):
        nelem = 278011

        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options)

        fp16_param = torch.nn.Parameter(tensor.clone().half())

        for i in range(self.iters):
            half_grads = self.gen_mixed_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step(grads=half_grads, output_params=[fp16_param])

            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param,
                                                           [fp16_param.float()])
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_adam_option(self):
        nelem = 1
        adam_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06,
                       'weight_decay': 0, 'amsgrad': False}

        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], adam_option)

        for i in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)


class TestFusedAdagrad(TestFusedOptimizer):
    def __init__(self, *args, **kwargs):
        super(TestFusedAdagrad, self).__init__(*args, **kwargs)
        self.options = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}
        self.ref_optim = torch.optim.Adagrad
        self.fused_optim = apex.optimizers.FusedAdagrad

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_params_different_devices_throws(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = []
        for i, size in enumerate(sizes):
            tensors.append(torch.rand(size, dtype=torch.float, device="cuda:" + str(i % 2)))
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )
        self.gen_grad(ref_param, tst_param)
        with self.assertRaisesRegex(RuntimeError, "not on the same device"):
            tst_optim.step()

    def test_adagrad_option(self):
        nelem = 1
        adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0}

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)


class TestFusedSGD(TestFusedOptimizer):
    def __init__(self, *args, **kwargs):
        super(TestFusedSGD, self).__init__(*args, **kwargs)
        self.options = {"lr": .25, "momentum": .125}
        self.ref_optim = torch.optim.SGD
        self.fused_optim = apex.optimizers.FusedSGD

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)


if __name__ == '__main__':
    unittest.main()