Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/sync_batchnorm/unittest.py
880 views
1
# -*- coding: utf-8 -*-
2
# File : unittest.py
3
# Author : Jiayuan Mao
4
# Email : [email protected]
5
# Date : 27/01/2018
6
#
7
# This file is part of Synchronized-BatchNorm-PyTorch.
8
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
# Distributed under MIT License.
10
11
import unittest
12
import torch
13
14
15
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` with a tensor-aware closeness assertion."""

    def assertTensorClose(self, x, y, rtol=1e-05, atol=1e-08):
        """Assert that tensors ``x`` and ``y`` are element-wise close.

        The pass/fail decision is made by ``torch.allclose``; the absolute
        and relative differences are only computed to enrich the failure
        message.

        Args:
            x: Tensor under test.
            y: Reference tensor.
            rtol: Relative tolerance forwarded to ``torch.allclose``.
            atol: Absolute tolerance forwarded to ``torch.allclose``.

        Raises:
            AssertionError: If ``torch.allclose(x, y, rtol, atol)`` is false.
                The message reports ``adiff`` (largest absolute difference)
                and ``rdiff`` (largest difference relative to ``|y|`` over
                the non-zero entries of ``y``, or ``'NaN'`` when there are
                none).
        """
        abs_err = (x - y).abs()

        if abs_err.numel() == 0:
            # ``max()`` raises on empty tensors, while ``torch.allclose``
            # accepts them, so report neutral diagnostics instead.
            adiff = 0.0
            rdiff = 'NaN'
        else:
            adiff = float(abs_err.max())
            # Expand the reference so the mask lines up with ``abs_err``
            # even when ``x`` and ``y`` were broadcast against each other.
            ref_mag = y.abs().expand_as(abs_err)
            nonzero = ref_mag != 0
            if nonzero.any():
                # Relative error is taken element by element; zero reference
                # entries are skipped so they cannot turn the result into
                # ``inf``.
                rel_err = (
                    abs_err.masked_select(nonzero)
                    / ref_mag.masked_select(nonzero)
                )
                rdiff = float(rel_err.max())
            else:
                rdiff = 'NaN'

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y, rtol=rtol, atol=atol), message)
29
30
31