GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_amp/test_basic_casts.py

import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
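
# Shared harness: for each (callable, input dtype) pair, run a forward pass and
# check that the output tensor type matches the `expected` table; optionally
# backprop through a float-cast sum and check that the gradient dtype matches
# the input dtype.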
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
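
# For reference: the expectation tables imported from utils (not shown in this
# file) are, assuming apex's standard L0 test utilities, roughly:
#   HALF  = 'torch.cuda.HalfTensor'
#   FLOAT = 'torch.cuda.FloatTensor'
#   ALWAYS_HALF  = {torch.float: HALF,  torch.half: HALF}
#   ALWAYS_FLOAT = {torch.float: FLOAT, torch.half: FLOAT}
#   MATCH_INPUT  = {torch.float: FLOAT, torch.half: HALF}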

class TestBasicCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
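
    # common_init (from utils) presumably sets the small dimensions used below
    # (self.b = batch size, self.h = hidden/spatial size, self.c = channels,
    # self.k = kernel size) and selects a CUDA default tensor type; amp.init
    # patches torch entry points so casts are inserted automatically.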

    def test_linear_is_half(self):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))

    def test_conv2d_is_half(self):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))

    def test_softmax_is_float(self):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))

    def test_group_norm_is_float(self):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))

    def test_mse_loss_is_float(self):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        run_layer_test(self, [m], ALWAYS_FLOAT, shape)
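        # Note: the functional variant f is defined above but not passed to
        # run_layer_test; only the module wrapper m is exercised here.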

    def test_relu_is_match(self):
        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))

    def test_batch_norm_is_match(self):
        m = nn.BatchNorm2d(num_features=self.c)
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=True)
        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
                       test_backward=False)
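
# amp treats F.binary_cross_entropy as "banned": it is considered numerically
# unsafe in fp16, so by default the patched function raises NotImplementedError;
# passing allow_banned=True to amp.init lets it run, cast up to fp32.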
class TestBannedMethods(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion):
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
            x = torch.rand(shape, dtype=torch.half)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion)

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion)
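
# The remaining tests cover casts that amp patches onto torch.Tensor methods
# and operators (matmul / @, pow / **, cpu, sum) rather than nn or
# nn.functional entry points.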
class TestTensorCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_matmul_op_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_pow_method_is_float(self):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_pow_op_is_float(self):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_cpu_is_float(self):
        fn = lambda x: x.cpu()
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))

    def test_sum_is_float(self):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

# TODO: maybe more tests on disabled casting?

if __name__ == '__main__':
    unittest.main()
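
# A hypothetical invocation (assuming a CUDA-enabled PyTorch build with apex
# installed and the sibling utils.py on sys.path):
#   python test_basic_casts.py -v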