Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Path: blob/main/apex/tests/L0/run_mlp/test_mlp.py
Views: 794
"""Tests for c++ MLP"""
import unittest
from time import time

import numpy as np
import torch
from torch import nn

from apex.mlp import MLP

batch_size = 1024
mlp_sizes = [480, 1024, 1024, 512, 256, 1]
num_iters = 10


class TestMLP(unittest.TestCase):
    """Compare the fused C++ MLP extension against an equivalent
    ``torch.nn.Sequential`` reference, both numerically and for speed.

    All tests require a CUDA device, since the fused MLP runs on GPU.
    """

    def _build_ref_mlp(self, mlp, bias=True, activation='relu'):
        """Build a CUDA ``nn.Sequential`` reference matching *mlp*.

        Fresh ``nn.Linear`` layers are created (picking up PyTorch's
        default initialization) and their parameters are copied INTO the
        fused ``mlp``, so both models compute the same function.

        Args:
            mlp: The fused MLP whose weights/biases get overwritten.
            bias: Whether the linear layers carry a bias term.
            activation: One of 'none', 'relu', 'sigmoid'; appended after
                every linear layer (including the last, matching the
                fused MLP's behavior).

        Returns:
            The reference ``nn.Sequential`` moved to CUDA.
        """
        layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=bias)
            # Copy the reference initialization into the fused MLP so the
            # two models start from identical parameters.
            mlp.weights[i].data.copy_(linear.weight)
            if bias:
                mlp.biases[i].data.copy_(linear.bias)
            layers.append(linear)
            if activation == 'relu':
                layers.append(nn.ReLU(inplace=True))
            elif activation == 'sigmoid':
                layers.append(nn.Sigmoid())
        return nn.Sequential(*layers).cuda()

    @staticmethod
    def _make_inputs(requires_grad=True):
        """Return ``(test_input, ref_input)``: identical random CUDA tensors.

        ``ref_input`` is a detached clone of ``test_input`` so each model
        accumulates its input gradient independently.
        """
        test_input = torch.empty(
            batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.)
        if requires_grad:
            test_input.requires_grad_()
            ref_input = test_input.clone().detach().requires_grad_()
        else:
            ref_input = test_input.clone().detach()
        return test_input, ref_input

    def test_creation(self):
        """Smoke test: constructing an MLP must not raise."""
        MLP(mlp_sizes)

    def test_numeric(self):
        """Forward and backward results match the reference (default MLP)."""
        mlp = MLP(mlp_sizes).cuda()
        ref_mlp = self._build_ref_mlp(mlp)

        test_input, ref_input = self._make_inputs()
        mlp_out = mlp(test_input)
        ref_out = ref_mlp(ref_input)
        np.testing.assert_allclose(
            mlp_out.detach().cpu().numpy(),
            ref_out.detach().cpu().numpy(),
            atol=1e-7, rtol=1e-5)

        # Use mean value as scalar loss. Multiply by 10 to make it big
        # enough not to zero out.
        mlp_out.mean().mul(10.).backward()
        ref_out.mean().mul(10.).backward()
        np.testing.assert_allclose(
            test_input.grad.detach().cpu().numpy(),
            ref_input.grad.detach().cpu().numpy(),
            atol=0, rtol=1e-5)
        np.testing.assert_allclose(
            mlp.biases[0].grad.detach().cpu().numpy(),
            ref_mlp[0].bias.grad.detach().cpu().numpy(),
            atol=1e-7, rtol=1e-5)

    def test_no_bias(self):
        """Bias-free MLP matches the reference for each activation choice."""
        for use_activation in ['none', 'relu', 'sigmoid']:
            mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda()
            ref_mlp = self._build_ref_mlp(
                mlp, bias=False, activation=use_activation)

            test_input, ref_input = self._make_inputs()
            mlp_out = mlp(test_input)
            ref_out = ref_mlp(ref_input)
            np.testing.assert_allclose(
                mlp_out.detach().cpu().numpy(),
                ref_out.detach().cpu().numpy(),
                atol=1e-7, rtol=1e-5)

            # Use mean value as scalar loss. Multiply by 10 to make it big
            # enough not to zero out.
            mlp_out.mean().mul(10.).backward()
            ref_out.mean().mul(10.).backward()
            # NOTE(review): very loose rtol here matches upstream — the
            # gradients are only sanity-checked, not tightly compared.
            np.testing.assert_allclose(
                test_input.grad.detach().cpu().numpy(),
                ref_input.grad.detach().cpu().numpy(),
                atol=0, rtol=100)
            np.testing.assert_allclose(
                mlp.weights[0].grad.detach().cpu().numpy(),
                ref_mlp[0].weight.grad.detach().cpu().numpy(),
                atol=1e-7, rtol=100)

    def test_with_bias(self):
        """Biased MLP matches the reference for each activation choice."""
        for use_activation in ['none', 'relu', 'sigmoid']:
            mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda()
            ref_mlp = self._build_ref_mlp(
                mlp, bias=True, activation=use_activation)

            test_input, ref_input = self._make_inputs()
            mlp_out = mlp(test_input)
            ref_out = ref_mlp(ref_input)
            np.testing.assert_allclose(
                mlp_out.detach().cpu().numpy(),
                ref_out.detach().cpu().numpy(),
                atol=1e-7, rtol=1e-5)

            # Use mean value as scalar loss. Multiply by 10 to make it big
            # enough not to zero out.
            mlp_out.mean().mul(10.).backward()
            ref_out.mean().mul(10.).backward()
            np.testing.assert_allclose(
                test_input.grad.detach().cpu().numpy(),
                ref_input.grad.detach().cpu().numpy(),
                atol=0, rtol=1)
            np.testing.assert_allclose(
                mlp.weights[0].grad.detach().cpu().numpy(),
                ref_mlp[0].weight.grad.detach().cpu().numpy(),
                atol=1e-7, rtol=1)
            np.testing.assert_allclose(
                mlp.biases[0].grad.detach().cpu().numpy(),
                ref_mlp[0].bias.grad.detach().cpu().numpy(),
                atol=1e-7, rtol=1e-5)

    def test_no_grad(self):
        """Backward works when the INPUT does not require grad
        (parameter gradients must still match the reference)."""
        mlp = MLP(mlp_sizes).cuda()
        ref_mlp = self._build_ref_mlp(mlp)

        test_input, ref_input = self._make_inputs(requires_grad=False)
        mlp_out = mlp(test_input)
        ref_out = ref_mlp(ref_input)
        np.testing.assert_allclose(
            mlp_out.detach().cpu().numpy(),
            ref_out.detach().cpu().numpy(),
            atol=1e-7, rtol=1e-5)

        # Use mean value as scalar loss. Multiply by 10 to make it big
        # enough not to zero out.
        mlp_out.mean().mul(10.).backward()
        ref_out.mean().mul(10.).backward()
        np.testing.assert_allclose(
            mlp.weights[0].grad.detach().cpu().numpy(),
            ref_mlp[0].weight.grad.detach().cpu().numpy(),
            atol=1e-7, rtol=1e-5)

    def test_performance_half(self):
        """Time fused vs. reference MLP in fp16 (prints, does not assert)."""
        mlp = MLP(mlp_sizes).cuda().half()
        ref_mlp = self._build_ref_mlp(mlp).half()

        # Inputs are created independently here (constant fill, so the
        # two tensors are equal anyway); timing, not numerics, is tested.
        test_input = torch.empty(
            batch_size, mlp_sizes[0], device="cuda",
            dtype=torch.half).fill_(10.).requires_grad_()
        ref_input = torch.empty(
            batch_size, mlp_sizes[0], device="cuda",
            dtype=torch.half).fill_(10.).requires_grad_()

        # Warm up GPU (kernel compilation, caching allocator, clocks).
        for _ in range(100):
            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean()
            ref_mlp.zero_grad()
            ref_loss.backward()
            mlp_out = mlp(test_input)
            test_loss = mlp_out.mean()
            mlp.zero_grad()
            test_loss.backward()

        torch.cuda.profiler.start()
        # Synchronize so the timer brackets only the measured kernels.
        torch.cuda.synchronize()
        start_time = time()
        for _ in range(num_iters):
            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean()
            ref_mlp.zero_grad()
            ref_loss.backward()
        torch.cuda.synchronize()
        stop_time = time()
        print(f"\nPytorch MLP time {(stop_time - start_time) * 1000. / num_iters:.4f} ms")

        torch.cuda.synchronize()
        start_time = time()
        for _ in range(num_iters):
            mlp_out = mlp(test_input)
            test_loss = mlp_out.mean()
            mlp.zero_grad()
            test_loss.backward()
        torch.cuda.synchronize()
        stop_time = time()
        print(f"C++ MLP time {(stop_time - start_time) * 1000. / num_iters:.4f} ms")
        torch.cuda.profiler.stop()


if __name__ == '__main__':
    unittest.main()