Path: blob/main/apex/tests/L0/run_amp/test_checkpointing.py
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from apex import amp


from utils import common_init, FLOAT


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(6)
        self.param = nn.Parameter(torch.randn(1))

    def forward(self, x):
        x = x * self.param
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        return x


class TestCheckpointing(unittest.TestCase):
    def setUp(self):
        self.initial_lr = 1e-3
        self.test_opt_levels = ("O0", "O1", "O2", "O3")

    def seed(self):
        torch.manual_seed(2809)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    def check_state_dict_fp32(self, state_dict):
        for key in state_dict:
            if 'num_batches_tracked' in key:
                continue
            param = state_dict[key]
            self.assertEqual(param.type(), FLOAT,
                             'Parameter in state_dict not FLOAT')

    def train_step(self, model, optimizer, data, loss_ids):
        optimizer.zero_grad()

        output = model(data)

        # Call backward for num_losses-1
        for idx in loss_ids:
            loss = output.mean()
            with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                scaled_loss.backward(retain_graph=True)

        optimizer.step()
        return output

    def compare_models(self, modelA, modelB, test_setup=''):
        state_dictA = modelA.state_dict()
        state_dictB = modelB.state_dict()
        self.assertEqual(len(state_dictA), len(state_dictB),
                         'state_dicts have different lengths' + test_setup)
        for key in state_dictA:
            paramA = state_dictA[key]
            paramB = state_dictB[key]
            self.assertTrue((paramA == paramB).all(),
                msg='Parameters in state_dicts not equal.' +
                    'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
                        key, paramA, paramB, paramA - paramB, test_setup))

    def test_restoring(self):
        nb_epochs = 10
        nb_epochs_restore = nb_epochs // 2
        for opt_level in self.test_opt_levels:
            for res_opt_level in self.test_opt_levels:
                for amp_before_load in [True, False]:
                    for num_losses in range(1, 3):
                        test_setup = ('#' * 75 + '\n' + \
                            f'opt_level {opt_level}\n' + \
                            f'restore_opt_level {res_opt_level}\n' + \
                            f'amp_before_load {amp_before_load}\n' + \
                            f'num_losses {num_losses}\n')

                        self.seed()

                        # Create reference model
                        model = MyModel().to('cuda')

                        optimizer = optim.SGD(model.parameters(),
                                              lr=self.initial_lr)

                        # Initialize with num_losses*2 for the original model and the restored one
                        model, optimizer = amp.initialize(
                            model, optimizer, opt_level=opt_level,
                            num_losses=num_losses*2, verbosity=0)

                        # Compare training behavior for same restore option
                        # We cannot really generalize it, since a saved model in O0
                        # would introduce a skipped step in O1, which will raise an error
                        if opt_level == res_opt_level:
                            # train for nb_epochs and restore after nb_epochs_restore
                            for epoch in range(nb_epochs):

                                x = torch.randn(16, 3, 24, 24, device='cuda')
                                output = self.train_step(
                                    model, optimizer, x, range(num_losses))
                                # Initialize model one step before comparing.
                                # Otherwise the batchnorm layers will be updated
                                # additionally in restore_model
                                if epoch == (nb_epochs_restore - 1):
                                    # Load model and optimizer
                                    checkpoint = {
                                        'model': model.state_dict(),
                                        'optimizer': optimizer.state_dict(),
                                        'amp': amp.state_dict()
                                    }
                                    # Check state_dict for FP32 tensors
                                    self.check_state_dict_fp32(checkpoint['model'])

                                    # Restore model
                                    restore_model = MyModel().to('cuda')
                                    restore_optimizer = optim.SGD(
                                        restore_model.parameters(),
                                        lr=self.initial_lr)

                                    if amp_before_load:
                                        restore_model, restore_optimizer = amp.initialize(
                                            restore_model,
                                            restore_optimizer,
                                            opt_level=res_opt_level,
                                            num_losses=num_losses*2,
                                            verbosity=0)

                                    restore_model.load_state_dict(checkpoint['model'])
                                    restore_optimizer.load_state_dict(checkpoint['optimizer'])
                                    # FIXME: We cannot test the amp.state_dict in the same script
                                    # amp.load_state_dict(checkpoint['amp'])

                                    if not amp_before_load:
                                        restore_model, restore_optimizer = amp.initialize(
                                            restore_model,
                                            restore_optimizer,
                                            opt_level=res_opt_level,
                                            num_losses=num_losses*2,
                                            verbosity=0)

                                elif epoch >= nb_epochs_restore:
                                    restore_output = self.train_step(
                                        restore_model,
                                        restore_optimizer,
                                        x,
                                        range(num_losses, num_losses*2))
                                    self.assertTrue(
                                        torch.allclose(output.float(), restore_output.float()),
                                        'Output of reference and restored models differ for ' + test_setup)
                                    self.compare_models(model, restore_model, test_setup)
                        # if opt_level != res_opt_level
                        else:
                            # skip tests for different opt_levels
                            continue

    def test_loss_scale_decrease(self):
        num_losses = 3
        nb_decrease_loss_scales = [0, 1, 2]
        for opt_level in self.test_opt_levels:
            #print('#' * 75 + f'\n opt_level {opt_level}\n')
            # Create new tmp copy for this run
            nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)

            model = MyModel().to('cuda')

            optimizer = optim.SGD(model.parameters(),
                                  lr=self.initial_lr)

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, num_losses=num_losses,
                verbosity=0)

            if amp._amp_state.opt_properties.loss_scale != 'dynamic':
                #print('Static loss scale set. Skipping opt_level.')
                continue

            # force to skip some updates to decrease the loss_scale
            initial_loss_scales = []
            for idx in range(num_losses):
                initial_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())

            for _ in range(len(nb_decrease_loss_scales)):
                x = torch.randn(16, 3, 24, 24, device='cuda')
                for idx in range(num_losses):
                    while nb_decrease_loss_scales_tmp[idx] > 0:
                        optimizer.zero_grad()
                        output = model(x * 2**17)
                        loss = output.mean()

                        with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
                            scaled_loss.backward(retain_graph=True)
                        optimizer.step()
                        nb_decrease_loss_scales_tmp[idx] -= 1

            # Check loss scales afterwards
            updated_loss_scales = []
            for idx in range(num_losses):
                updated_loss_scales.append(
                    amp._amp_state.loss_scalers[idx].loss_scale())
            for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
                                                  updated_loss_scales,
                                                  initial_loss_scales):
                self.assertEqual(update_ls, init_ls / 2**factor)

            # Check state dict
            amp_state_dict = amp.state_dict()
            for scaler_idx, factor, init_ls in zip(amp_state_dict,
                                                   nb_decrease_loss_scales,
                                                   initial_loss_scales):
                scaler = amp_state_dict[scaler_idx]
                self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
                unskipped_target = 0
                self.assertEqual(scaler['unskipped'], unskipped_target)

    def test_state_dict(self):
        for opt_level in self.test_opt_levels:
            # Skip O3
            if opt_level == 'O3':
                continue

            model = MyModel().to('cuda')
            optimizer = optim.Adam(model.parameters(), lr=1e-3)
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0)

            # Export state_dict and check for Half
            state_dict = model.state_dict()
            for key in state_dict:
                self.assertFalse('Half' in state_dict[key].type())

            # Check if model is still trainable
            # Create dummy data
            data = torch.randn(10, 3, 4, 4, device='cuda')
            target = torch.randn(10, 6, 4, 4, device='cuda')

            # Get initial loss
            optimizer.zero_grad()
            output = model(data)
            loss = F.mse_loss(output, target)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
            last_loss = loss.item()

            # train for some epochs
            for epoch in range(10):
                optimizer.zero_grad()
                output = model(data)
                loss = F.mse_loss(output, target)
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()
                self.assertTrue(loss.item() < last_loss)
                last_loss = loss.item()


if __name__ == '__main__':
    unittest.main()