Path: blob/main/apex/tests/L0/run_optimizers/test_lamb.py
import unittest
import os

import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier
from itertools import product


class RefLAMB(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(RefLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
            # Skip buffer
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

        device = self.param_groups[0]["params"][0].device
        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]

        max_grad_norm = 1.0
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data *= clipped_ratio
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # m_t = beta1 * m + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1-beta1)
                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
                v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # Debiasing
                m_t_hat = m_t / (1.0 - beta1 ** state['step'])
                v_t_hat = v_t / (1.0 - beta2 ** state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

                trust_ratio = 1.0
                w_norm = p.data.pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr']

                p.data.add_(update, alpha=-step_size*trust_ratio)

        return loss


class TestFusedLAMB(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, lamb_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = RefLAMB(ref_param, **lamb_option)
        tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float, device="cuda"):
        nelem = 278011
        tensor = torch.rand(nelem, dtype=param_type, device=device)
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                torch.cuda.synchronize()
                tst_optim.step()
                torch.cuda.synchronize()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]

        for wd in weight_decay:
            lamb_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06, 'weight_decay': wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)


if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
    unittest.main()
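
The test file can be run directly (for example, python test_lamb.py), since it ends with unittest.main(). For readers who want to see the reference optimizer in isolation, below is a minimal usage sketch, not part of the apex test file: it drives RefLAMB on a made-up toy regression problem. It assumes a CUDA device and an apex build with its CUDA extensions (otherwise RefLAMB.__init__ raises RuntimeError); the model, data, and loss here are illustrative only.

# Hypothetical usage sketch (not part of the test file above).
# Assumes a CUDA device and apex built with CUDA extensions, so that
# multi_tensor_applier.available is True and RefLAMB can be constructed.
model = torch.nn.Linear(16, 4).cuda()
optimizer = RefLAMB(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

data = torch.randn(8, 16, device="cuda")      # toy inputs
target = torch.randn(8, 4, device="cuda")     # toy targets

for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    # Applies the clipped, bias-corrected, trust-ratio-scaled update from step()
    optimizer.step()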