Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/apex/tests/L0/run_amp/test_promotion.py
Views: 794
import unittest

import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT, DTYPES


class TestPromotion(unittest.TestCase):
    """Check the result/grad types amp produces for mixed half/float ops.

    Out-of-place ops are expected to return the widest input type; in-place
    ops are expected to keep the type of the tensor they modify.
    """

    def setUp(self):
        self.handle = amp.init(enabled=True)
        try:
            # NOTE(review): common_init is defined in utils (not visible);
            # it presumably sets attributes such as self.b used below.
            common_init(self)
        except BaseException:
            # unittest does not run tearDown when setUp raises, so undo
            # amp.init here; otherwise its state would leak into later tests.
            self.handle._deactivate()
            raise

    def tearDown(self):
        self.handle._deactivate()

    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
        """Run every binary fn over every (xtype, ytype) pair from DTYPES.

        Args:
            fns: callables taking ``(x, y)`` and returning a tensor.
            input_shape: shape used for both operands.
            x_inplace: if True, ``fns`` modify ``x`` in place and the result
                must keep x's type; otherwise the result must be FLOAT when
                either operand is ``torch.float`` and HALF otherwise.

        Also checks that the gradient reaching the leaf ``x`` has ``xtype``.
        Each combination runs in its own ``subTest`` so a failure reports
        which fn (by index) and which dtype pair broke.
        """
        cases = it.product(enumerate(fns), it.product(DTYPES, DTYPES))
        for (fn_idx, fn), (xtype, ytype) in cases:
            with self.subTest(fn=fn_idx, xtype=xtype, ytype=ytype):
                x = torch.randn(input_shape, dtype=xtype).requires_grad_()
                x_leaf = x
                if x_inplace:
                    # We need a non-leaf to call in place on.
                    x = x.clone()
                y = torch.randn(input_shape, dtype=ytype)
                out = fn(x, y)
                if x_inplace:
                    # In place: always match xtype.
                    self.assertEqual(out.type(), x.type())
                elif xtype == torch.float or ytype == torch.float:
                    # Out of place: match the widest input type.
                    self.assertEqual(out.type(), FLOAT)
                else:
                    self.assertEqual(out.type(), HALF)
                out.float().sum().backward()
                self.assertEqual(x_leaf.grad.dtype, xtype)

    def test_atan2_matches_widest(self):
        """Function and method forms of atan2 promote to the widest type."""
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_mul_matches_widest(self):
        """Function and method forms of mul promote to the widest type."""
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_cat_matches_widest(self):
        """cat of half tensors yields FLOAT if any input is float, else HALF."""
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)
        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)

    def test_inplace_exp_is_error_for_half(self):
        """In-place exp_ works on float but raises NotImplementedError on half."""
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)
        xs = torch.randn(self.b, dtype=torch.half)
        with self.assertRaises(NotImplementedError):
            xs.exp_()

    def test_inplace_add_matches_self(self):
        """In-place add_ keeps the type of the tensor being modified."""
        def add_inplace(x, y):
            return x.add_(y)

        self.run_binary_promote_test([add_inplace], (self.b,), x_inplace=True)


if __name__ == '__main__':
    unittest.main()