Path: blob/main/models/networks/sync_batchnorm/comm.py
880 views
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : [email protected]
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        # Explicit "a value is waiting" flag, so that ``None`` is a legal
        # payload and ``get`` can re-check its predicate after every wakeup.
        self._ready = False
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store ``result`` and wake up a consumer blocked in :meth:`get`.

        Args:
            result: the value to hand over to the consumer.

        Raises:
            AssertionError: if a previously stored (non-``None``) result has
                not been fetched yet.
        """
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._ready = True
            self._cond.notify()

    def get(self):
        """Block until a result is available, then return and clear it.

        Returns: the value passed to the matching :meth:`put` call.
        """
        with self._lock:
            # ``Condition.wait`` may return without a matching ``notify``;
            # loop on the predicate rather than testing it only once.
            while not self._ready:
                self._cond.wait()

            res = self._result
            self._result = None
            self._ready = False
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        """Send ``msg`` to the master and wait for its reply.

        Args:
            msg: the message this slave contributes to the current round.

        Returns: the value the master put into this slave's future.
        """
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        # Acknowledge receipt; the master waits for one ``True`` per slave
        # before it returns from ``run_master``.
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is kept; the queue and the registry are rebuilt
        # from scratch by ``__setstate__``.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            # First registration after a completed round: drop the previous
            # round's registry before recording new slaves.
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        # The master's own message always comes first, under identifier 0.
        intermediates = [(0, master_msg)]
        intermediates.extend(self._queue.get() for _ in range(self.nr_slaves))

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Drain the acknowledgement outside the ``assert`` so the queue
            # is still emptied when assertions are disabled (``python -O``).
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slaves currently registered."""
        return len(self._registry)