CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_amp/test_promotion.py
Views: 794
1
import unittest
2
3
import itertools as it
4
5
from apex import amp
6
import torch
7
from torch import nn
8
import torch.nn.functional as F
9
10
from utils import common_init, HALF, FLOAT, DTYPES
11
12
class TestPromotion(unittest.TestCase):
    """Check amp's dtype-promotion behavior for mixed half/float operations.

    Every test runs with amp patching active: ``setUp`` calls
    ``amp.init(enabled=True)`` and ``tearDown`` deactivates the handle again.
    """

    def setUp(self):
        # Activate amp's function patching for the duration of one test.
        self.handle = amp.init(enabled=True)
        # Presumably sets shared fixtures such as self.b used below -- TODO
        # confirm against utils.common_init.
        common_init(self)

    def tearDown(self):
        # Undo the patching installed in setUp so tests stay independent.
        self.handle._deactivate()

    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
        """Apply each binary ``fn`` to every ``(xtype, ytype)`` pair in DTYPES.

        Args:
            fns: callables ``fn(x, y)`` returning a tensor.
            input_shape: shape used for both operands.
            x_inplace: if True, ``fn`` modifies ``x`` in place and the result
                must keep ``x``'s dtype. Otherwise the result must have the
                widest of the two input dtypes (FLOAT if either input is
                ``torch.float``, else HALF).

        In every case, the gradient flowing back to ``x`` must have ``x``'s
        original dtype. Each combination runs in its own ``subTest`` so a
        failure report names the offending function and dtype pair.
        """
        type_pairs = it.product(DTYPES, DTYPES)
        for fn, (xtype, ytype) in it.product(fns, type_pairs):
            with self.subTest(fn=fn, xtype=xtype, ytype=ytype):
                x = torch.randn(input_shape, dtype=xtype).requires_grad_()
                x_leaf = x
                if x_inplace:
                    # autograd rejects in-place ops on a leaf that requires
                    # grad, so operate on a non-leaf clone instead.
                    x = x.clone()
                y = torch.randn(input_shape, dtype=ytype)
                out = fn(x, y)
                if x_inplace:
                    # In place: the result must keep x's original dtype.
                    # Compare against xtype rather than x.type(). In-place
                    # torch methods such as add_ return the tensor they
                    # modify, so comparing with x itself would always pass.
                    self.assertEqual(out.dtype, xtype)
                elif torch.float in (xtype, ytype):
                    # Out of place: the result matches the widest input type.
                    self.assertEqual(out.type(), FLOAT)
                else:
                    self.assertEqual(out.type(), HALF)
                out.float().sum().backward()
                # The gradient is cast back to the leaf's own dtype.
                self.assertEqual(x_leaf.grad.dtype, xtype)

    def test_atan2_matches_widest(self):
        """Check that atan2 (function and method forms) promotes to the widest dtype."""
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_mul_matches_widest(self):
        """Check that mul (function and method forms) promotes to the widest dtype."""
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_cat_matches_widest(self):
        """Check that torch.cat yields float if any input is float, half if all are half."""
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)
        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)

    def test_inplace_exp_is_error_for_half(self):
        """Check that in-place exp_ works on float but raises NotImplementedError on half."""
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)
        xs = torch.randn(self.b, dtype=torch.half)
        with self.assertRaises(NotImplementedError):
            xs.exp_()

    def test_inplace_add_matches_self(self):
        """Check that in-place add_ keeps the dtype of the tensor being modified."""
        fn = lambda x, y: x.add_(y)
        self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
73
74
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
76
77