Path: blob/main/apex/tests/L0/run_amp/test_larc.py
1103 views
import unittest

import torch
from torch import nn
from torch.nn import Parameter

from apex import amp
from apex.parallel.LARC import LARC
from utils import common_init


class MyModel(torch.nn.Module):
    """Minimal CUDA module with a single learnable 2-element weight vector.

    ``unique`` offsets the initial weights so separate instances start
    from distinct parameter values.
    """

    def __init__(self, unique):
        super().__init__()
        self.weight0 = Parameter(
            unique + torch.arange(2, device="cuda", dtype=torch.float32)
        )

    def forward(self, input):
        # Elementwise product reduced to a scalar, usable directly as a loss.
        return (input * self.weight0).sum()


class TestLARC(unittest.TestCase):
    """Smoke-test LARC-wrapped SGD under every AMP optimization level."""

    def setUp(self):
        self.x = torch.ones((2), device="cuda", dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def test_larc_mixed_precision(self):
        # One full forward/backward/step cycle per opt level; the test
        # passes if no level raises (no numeric assertions are made).
        for opt_level in ["O0", "O1", "O2", "O3"]:
            model = MyModel(1)

            optimizer = LARC(
                torch.optim.SGD(
                    [{"params": model.parameters(), "lr": 0.25}], momentum=0.125
                )
            )

            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0
            )

            optimizer.zero_grad()
            loss = model(self.x)
            # Loss scaling is required for the mixed-precision levels.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()


if __name__ == "__main__":
    unittest.main()