Path: blob/main/apex/tests/L0/run_amp/test_cache.py
import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

def get_reference_grad(i, w, ops):
    # Creating new tensors ensures, among other things, that the new tensors are not in the cache.
    # In fact, they are guaranteed not to use the cache because they are not torch.nn.Parameters.
    fp32_i = i.detach().clone().float()
    fp32_w = w.detach().clone().float().requires_grad_()
    loss = ops(fp32_i, fp32_w)
    loss.backward()
    return fp32_w.grad

class WhitelistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(WhitelistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8,8))

    @staticmethod
    def ops(input, weight):
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class BlacklistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(BlacklistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class PromoteModule(torch.nn.Module):
    def __init__(self, dtype):
        super(PromoteModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8))

    @staticmethod
    def ops(input, weight):
        return ((input*weight)*weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)

class TestCache(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def train_eval_train_test(self, module, t):
        model = module(t).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

        _amp_state.allow_incoming_model_not_fp32 = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
        _amp_state.allow_incoming_model_not_fp32 = False

        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
            _amp_state.loss_scalers[0]._loss_scale = 4.0
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())

            reference_grad = get_reference_grad(self.x, model.weight, model.ops)

            # Currently there's no difference in the allclose calls, so no need for branching,
            # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))

            model.weight.data -= 1.

        # Simulates first epoch
        training_step()

        # Simulates eval
        with torch.no_grad():
            loss = model(self.x).sum()

        # Simulates resuming training after eval
        training_step()

        _amp_state.handle._deactivate()

    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16)

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32)

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16)

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32)

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16)

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32)


if __name__ == '__main__':
    unittest.main()
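
For context, below is a minimal sketch of the train/eval/train pattern these tests exercise, assuming a CUDA device and an apex installation. Under opt_level "O1", amp patches torch functions so fp32 parameters are cast to fp16 on the fly, and those casts are cached per parameter; the tests above check that gradients after this pattern still match an fp32 reference computed from fresh tensors. The torch.nn.Linear model, learning rate, and input shape here are arbitrary stand-ins chosen for illustration, not part of the test file.

# Illustrative sketch only -- assumes apex is installed and a CUDA device is available.
import torch
from apex import amp

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
x = torch.ones((2, 8), device='cuda')

def train_step():
    # The fp16 cast of model.weight is created (and cached) during this forward pass.
    optimizer.zero_grad()
    loss = model(x).sum()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

train_step()               # first epoch: populates the cast cache
with torch.no_grad():      # eval: forward passes without autograd
    model(x).sum()
train_step()               # resumed training must see casts of the updated fp32 weights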