Path: blob/main/models/networks/sync_batchnorm/replicate.py
880 views
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    """Empty attribute container handed to replication callbacks.

    One instance is created per sub-module position and passed to every
    replica's callback for that position, so the copies can exchange
    information by setting attributes on it.
    """
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.

    Args:
        modules: sequence of replicas; each must provide a `modules()` method
            that yields its sub-modules in the same order on every replica.
            `modules[0]` is treated as the master copy.

    Returns:
        None. An empty `modules` sequence is a no-op.
    """
    # Nothing was replicated, so there is no master copy and nothing to call.
    if not modules:
        return

    master_copy = modules[0]
    # One shared context per sub-module position of the master copy; the
    # position in the `modules()` traversal identifies "the same" sub-module
    # across replicas.
    ctxs = [CallbackContext() for _ in master_copy.modules()]

    # The outer loop visits replica 0 first, which is what guarantees the
    # master copy's callbacks run before those of any other copy.
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        """Replicate `module` via the parent class, then run the replication callbacks.

        Args:
            module: the module to replicate.
            device_ids: devices to replicate onto; forwarded unchanged to the
                parent `replicate`.

        Returns:
            The replicas returned by the parent `replicate`.
        """
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

    Args:
        data_parallel: the `DataParallel` instance whose `replicate` attribute
            is replaced in place.

    Returns:
        None. Only this instance is patched, not its class.
    """

    assert isinstance(data_parallel, DataParallel)

    # Capture the bound method before overwriting it so the wrapper can
    # delegate to the original implementation.
    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    # Stored as an instance attribute, so it is called without `self`.
    data_parallel.replicate = new_replicate