Path: blob/main/models/networks/sync_batchnorm/batchnorm.py
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


def _sum_ft(tensor):
    """Sum over the first and last dimension."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean  # equals sum((x - mean) ** 2)
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
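

# Illustrative sketch, not part of the original library: _compute_mean_std above relies
# on the identity  sum((x - mean)**2) = sum(x**2) - sum(x) * mean.  The hypothetical
# helper below re-derives mean and unbiased variance from the same sum / square-sum
# quantities on a plain tensor, so the algebra can be checked in isolation.  Nothing
# in this module calls it.
def _example_mean_var(x):
    """Recompute mean and unbiased variance from sum and square-sum (illustration only)."""
    n = x.numel()
    assert n > 1, 'unbiased variance requires more than one element'
    s, ss = x.sum(), (x ** 2).sum()
    mean = s / n
    sumvar = ss - s * mean            # same identity as in _compute_mean_std
    return mean, sumvar / (n - 1)     # matches x.var(unbiased=True) up to rounding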


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it is common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)


@contextlib.contextmanager
def patch_sync_batchnorm():
    """Temporarily replace ``nn.BatchNorm1d/2d/3d`` with their synchronized counterparts
    inside a ``with`` block, restoring the originals on exit."""
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup


def convert_model(module):
    """Traverse the input module and its children recursively and replace all
    instances of torch.nn.modules.batchnorm.BatchNorm*N*d with
    SynchronizedBatchNorm*N*d.

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard PyTorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after conversion, m uses SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
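

# Usage sketch (not part of the original file): a minimal, hedged example of applying
# the helpers above to a small model.  The layer sizes and the `demo` name are
# arbitrary illustrations; the __main__ guard keeps this from running on import.
if __name__ == '__main__':
    import torch.nn as nn

    demo = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())

    # Option 1: convert an existing model; wrap it in DataParallel first when GPUs
    # are available so that convert_model also swaps in DataParallelWithCallback.
    if torch.cuda.is_available():
        converted = convert_model(nn.DataParallel(demo.cuda()))
    else:
        converted = convert_model(demo)  # sync layers fall back to F.batch_norm when not parallel

    # Option 2: temporarily patch nn.BatchNorm*d while building a new model.
    with patch_sync_batchnorm():
        patched = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))

    print(converted)
    print(patched)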