Path: blob/main/apex/tests/L0/run_amp/test_basic_casts.py
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
    # Run each fn on an input of every dtype in `expected` and check that the
    # output (and, optionally, the input gradient) has the expected tensor type.
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])

class TestBasicCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_linear_is_half(self):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))

    def test_conv2d_is_half(self):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))

    def test_softmax_is_float(self):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))

    def test_group_norm_is_float(self):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))

    def test_mse_loss_is_float(self):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        # NOTE: only the module form is exercised here; `f` is defined but unused.
        run_layer_test(self, [m], ALWAYS_FLOAT, shape)

    def test_relu_is_match(self):
        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))

    def test_batch_norm_is_match(self):
        m = nn.BatchNorm2d(num_features=self.c)
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=True)
        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
                       test_backward=False)

class TestBannedMethods(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion):
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
            x = torch.rand(shape, dtype=torch.half)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion)

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion)

class TestTensorCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_matmul_op_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_pow_method_is_float(self):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_pow_op_is_float(self):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_cpu_is_float(self):
        fn = lambda x: x.cpu()
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))

    def test_sum_is_float(self):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    # TODO: maybe more tests on disabled casting?

if __name__ == '__main__':
    unittest.main()
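For context, these tests exercise apex's legacy amp.init() handle API, which patches torch functions so that whitelisted ops (e.g. linear, conv2d, matmul) run in half precision while blacklisted ops (e.g. softmax, losses, pow, sum) run in float. The following is a minimal, illustrative training-step sketch of that API, not part of the test file: it assumes a CUDA device, and the model, optimizer, and data are hypothetical stand-ins; scale_loss is the handle's documented loss-scaling context manager.

    # Illustrative sketch only. Assumes CUDA and the legacy apex.amp handle API;
    # the model, optimizer, and data are hypothetical stand-ins.
    import torch
    from torch import nn
    from apex import amp

    amp_handle = amp.init(enabled=True)   # patch torch ops to insert casts

    model = nn.Linear(64, 64).cuda()      # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    x = torch.randn(8, 64, device='cuda')
    out = model(x)                        # whitelisted op: runs in half
    loss = out.pow(2).mean()              # blacklisted ops: run in float

    # Scaled backward pass via the handle's loss-scaling context manager
    with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()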