Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
ai-forever
GitHub Repository: ai-forever/sber-swap
Path: blob/main/models/networks/sync_batchnorm/comm.py
880 views
1
# -*- coding: utf-8 -*-
2
# File : comm.py
3
# Author : Jiayuan Mao
4
# Email : [email protected]
5
# Date : 27/01/2018
6
#
7
# This file is part of Synchronized-BatchNorm-PyTorch.
8
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9
# Distributed under MIT License.
10
11
import queue
12
import collections
13
import threading
14
15
# Public names exported by a star-import of this module.
__all__ = [
    "FutureResult",
    "SlavePipe",
    "SyncMaster",
]
16
17
18
class FutureResult(object):
    """A thread-safe, single-slot future. Used only as a one-to-one pipe.

    One thread calls `put` to deposit a value; another calls `get`, which blocks
    until a value is available and then empties the slot again, so the same
    object can carry one value per round trip.
    """

    # Private marker meaning "slot is empty". Unlike ``None`` it can never
    # collide with a legitimate result, so ``put(None)`` is delivered correctly.
    _EMPTY = object()

    def __init__(self):
        self._result = self._EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def _has_result(self):
        # Predicate for Condition.wait_for; only called with self._lock held.
        return self._result is not self._EMPTY

    def put(self, result):
        """Store `result` and wake a thread blocked in `get`.

        Raises:
            AssertionError: if the previously stored result has not been fetched
                yet (only checked when assertions are enabled).
        """
        with self._lock:
            assert not self._has_result(), 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then remove and return it."""
        with self._lock:
            # wait_for re-checks the predicate after every wake-up, so a
            # spurious or stale notification cannot make us return early.
            self._cond.wait_for(self._has_result)
            res = self._result
            self._result = self._EMPTY
            return res
40
41
42
# Per-slave entry kept by the master: `result` is the future through which
# the master hands that slave its reply.
_MasterRegistry = collections.namedtuple(typename='MasterRegistry', field_names=['result'])

# Field layout of `SlavePipe`: the slave's identifier, the queue shared with
# the master, and the future on which the slave receives its reply.
_SlavePipeBase = collections.namedtuple(
    typename='_SlavePipeBase',
    field_names=['identifier', 'queue', 'result'],
)
44
45
46
class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication.

    Bundles the slave's ``identifier``, the ``queue`` shared with the master,
    and the ``result`` future on which the master deposits this slave's reply.
    """

    def run_slave(self, msg):
        """Send `msg` to the master, block for the reply, then acknowledge it.

        Returns: the reply the master produced for this slave.
        """
        outbox = self.queue
        outbox.put((self.identifier, msg))
        reply = self.result.get()
        # Acknowledge receipt so the master knows this slave consumed its reply.
        outbox.put(True)
        return reply
54
55
56
class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as data parallelism triggers a callback on each module, every
      slave device should call `register_slave(id)` and obtain a `SlavePipe` to
      communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from
      slave devices are collected and passed to the registered callback.
    - After receiving the messages, the master callback gathers the information and
      determines the message passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback invoked with the list of collected
                ``(identifier, message)`` pairs, the master's own message first under
                identifier 0. It must return a list of ``(identifier, reply)`` pairs whose
                first entry belongs to the master (identifier 0).
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; the queue and slave registry are runtime
        # state and are rebuilt empty by __setstate__.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        If `run_master` has run since the last registration, the registry is cleared
        first so that a fresh set of slaves can register.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        Messages are first collected from every device (including the master device), then
        the master callback computes the message to be sent back to each device
        (including the master device). This call blocks until every registered slave has
        sent its message and, later, acknowledged its reply.

        Args:
            master_msg: the message that the master wants to send to itself. It is placed
                first, under identifier 0, in the list passed to `master_callback`. For
                detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for _ in range(self.nr_slaves):
            # Consume each slave's acknowledgement OUTSIDE the assert: under
            # ``python -O`` asserts are stripped, and a get() inside one would leave
            # stale acks in the queue to be misread as slave messages next round.
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slave devices currently registered."""
        return len(self._registry)
138
139