
GitHub Repository: ai-forever/sber-swap
Path: blob/main/apex/tests/L0/run_amp/test_rnn.py
import unittest

from apex import amp
import random
import torch
from torch import nn

from utils import common_init, HALF

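# Note (assumptions, not confirmed by this file): common_init and HALF come from the
# test suite's sibling utils module; common_init presumably sets the batch size (self.b),
# hidden size (self.h), and sequence length (self.t) used below, and HALF is presumably
# the expected CUDA half-precision tensor type string (e.g. 'torch.cuda.HalfTensor').
# TestRnnCells checks that, with amp enabled, single RNN/GRU/LSTM cells return
# half-precision outputs regardless of the input dtype.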
class TestRnnCells(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

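    # run_cell_test feeds a sequence of random inputs through a single cell, starting from
    # both float32 and float16 inputs, and checks that every output is half precision
    # (HALF) while the input gradients keep the original input dtype.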
    def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            for i, x in enumerate(xs):
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)

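# TestRnns repeats the same dtype checks for full nn.RNN / nn.GRU / nn.LSTM modules,
# covering single-layer, multi-layer, and bidirectional configurations as well as a
# packed-sequence input.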
class TestRnns(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

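    # run_rnn_test runs one forward/backward pass for float32 and float16 inputs and
    # asserts the output is half precision while the input gradient keeps the input dtype.
    # The initial hidden state's first dimension is num_layers * num_directions, written
    # here as layers + layers * bidir since bidir is a bool.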
    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         nonlinearity='relu', bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
                          bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)

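    # The packed-sequence test draws random lengths and sorts them in decreasing order
    # before packing. Lengths must be CPU int64 tensors, so the default tensor type is
    # temporarily switched to CPU floats around pack_padded_sequence and then restored to
    # torch.cuda.FloatTensor (which common_init is presumed to have set originally).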
    def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
            lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if default tensor type is non-CPU
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

if __name__ == '__main__':
    unittest.main()