Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/tests/L0/run_amp/test_rnn.py
Views: 794
import unittest

from apex import amp
import random
import torch
from torch import nn

from utils import common_init, HALF


class TestRnnCells(unittest.TestCase):
    """Verify that amp casts RNN *cell* outputs to half precision.

    Each test builds one cell type (RNN/GRU/LSTM), steps it through time
    manually, and checks output dtype and gradient dtype behavior under amp.
    """

    def setUp(self):
        # Enable amp so cell forward passes run with automatic casting.
        self.handle = amp.init(enabled=True)
        # Populates self.b (batch), self.h (hidden), self.t (timesteps), etc.
        common_init(self)

    def tearDown(self):
        # Tear amp back down so other test cases start from a clean state.
        self.handle._deactivate()

    def run_cell_test(self, cell, state_tuple=False):
        """Step ``cell`` over ``self.t`` timesteps for float and half inputs.

        Asserts every per-step output is half precision (amp casts the cell),
        and that each input's gradient comes back in that input's own dtype.

        :param cell: an ``nn.RNNCellBase`` instance sized (self.h, self.h)
        :param state_tuple: True for LSTM-style ``(h, c)`` hidden state
        """
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]

            def hidden_fn():
                # Fresh zero state in the current input dtype.
                return torch.zeros(shape, dtype=typ)

            hidden = (hidden_fn(), hidden_fn()) if state_tuple else hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                # LSTM cells return (h, c); the visible output is h.
                output = hidden[0] if state_tuple else hidden
                outputs.append(output)
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            for x in xs:
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)


class TestRnns(unittest.TestCase):
    """Verify that amp casts full RNN *module* outputs to half precision.

    Covers uni/bidirectional, multi-layer RNN/GRU/LSTM, plus PackedSequence
    input for nn.RNN.
    """

    def setUp(self):
        # Enable amp so RNN forward passes run with automatic casting.
        self.handle = amp.init(enabled=True)
        # Populates self.b (batch), self.h (hidden), self.t (timesteps), etc.
        common_init(self)

    def tearDown(self):
        # Tear amp back down so other test cases start from a clean state.
        self.handle._deactivate()

    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        """Run ``rnn`` on a (t, b, h) input for float and half dtypes.

        Asserts the output is half precision and that the input gradient
        comes back in the input's own dtype.

        :param rnn: an ``nn.RNNBase`` instance sized (self.h, self.h)
        :param layers: number of stacked layers in ``rnn``
        :param bidir: True if ``rnn`` is bidirectional
        :param state_tuple: True for LSTM-style ``(h, c)`` hidden state
        """
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h),
                            dtype=typ).requires_grad_()

            def hidden_fn():
                # One state slab per layer, doubled when bidirectional
                # (bool `bidir` multiplies as 0/1).
                return torch.zeros((layers + (layers * bidir),
                                    self.b, self.h), dtype=typ)

            hidden = (hidden_fn(), hidden_fn()) if state_tuple else hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h,
                         num_layers=layers, nonlinearity='relu',
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h,
                         num_layers=layers, bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h,
                          num_layers=layers, bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)

    def test_rnn_packed_sequence(self):
        num_layers = 2
        rnn = nn.RNN(input_size=self.h, hidden_size=self.h,
                     num_layers=num_layers)
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h),
                            dtype=typ).requires_grad_()
            # Descending lengths, as `pack_padded_sequence` requires.
            lens = sorted([random.randint(self.t // 2, self.t)
                           for _ in range(self.b)],
                          reverse=True)
            # `pack_padded_sequence` breaks if default tensor type is non-CPU
            torch.set_default_tensor_type(torch.FloatTensor)
            lens = torch.tensor(lens, dtype=torch.int64,
                                device=torch.device('cpu'))
            packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
            torch.set_default_tensor_type(torch.cuda.FloatTensor)
            hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
            output, _ = rnn(packed_seq, hidden)
            self.assertEqual(output.data.type(), HALF)
            output.data.float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)


if __name__ == '__main__':
    unittest.main()