GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/sync_batchnorm/batchnorm.py
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)
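
# Shape sketch for the two helpers above: for an input of shape (B, C, L),
# _sum_ft returns a (C,) vector (summed over the batch and spatial dimensions),
# and _unsqueeze_ft maps a (C,) vector back to (1, C, 1) for broadcasting.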


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
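
# Message protocol between replicas: each copy sends a _ChildMessage holding its
# per-device sum, sum of squares, and element count; the master copy reduces
# them, computes the global mean and inverse standard deviation, and returns a
# _MasterMessage whose 'sum' field carries the mean and 'inv_std' the inverse std.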


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not a parallel computation or it is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
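
# How the statistics above are recovered from the reduced sums, for n elements:
#     mean             = (sum x) / n
#     sum (x - mean)^2 = sum x^2 - (sum x) * mean     (the `sumvar` term)
#     unbiased var     = sumvar / (n - 1)   -> used to update running_var
#     biased var       = sumvar / n         -> used for the actual normalization
# The second value returned by _compute_mean_std is 1 / sqrt(max(biased var, eps)),
# i.e. the inverse standard deviation applied in forward().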


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
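
# A minimal multi-GPU usage sketch (assumes at least two visible CUDA devices; the
# small model below is purely illustrative). DataParallelWithCallback is needed so
# that the per-device replicas get wired to the shared SyncMaster:
#
#     import torch.nn as nn
#     net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1),
#                         SynchronizedBatchNorm2d(16),
#                         nn.ReLU())
#     net = DataParallelWithCallback(net, device_ids=[0, 1]).cuda()
#     out = net(torch.randn(8, 3, 32, 32).cuda())  # statistics reduced across GPUs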


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


@contextlib.contextmanager
def patch_sync_batchnorm():
    """Context manager that temporarily replaces torch.nn.BatchNorm1d/2d/3d with
    their synchronized counterparts, so that models built inside the ``with``
    block use SyncBN layers. The original classes are restored on exit."""
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
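
# Usage sketch for the context manager (the tiny model built inside the block is
# purely illustrative): layers constructed via nn.BatchNorm*d while the patch is
# active come out as their Synchronized counterparts.
#
#     import torch.nn as nn
#     with patch_sync_batchnorm():
#         net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
#     assert isinstance(net[1], SynchronizedBatchNorm2d)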


def convert_model(module):
    """Traverse the input module and its children recursively and replace all
    instances of torch.nn.modules.batchnorm.BatchNorm*N*d with
    SynchronizedBatchNorm*N*d.

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after conversion, m uses SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
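

if __name__ == '__main__':
    # Single-device sanity check (a sketch): without DataParallelWithCallback the
    # layer falls back to F.batch_norm, so it should match nn.BatchNorm2d exactly,
    # as claimed in the class docstrings above.
    torch.manual_seed(0)
    x = torch.randn(8, 4, 16, 16)
    sync_bn = SynchronizedBatchNorm2d(4)
    ref_bn = torch.nn.BatchNorm2d(4)
    ref_bn.load_state_dict(sync_bn.state_dict())
    print(torch.allclose(sync_bn(x), ref_bn(x), atol=1e-6))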