CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
ai-forever

Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.

GitHub Repository: ai-forever/sber-swap
Path: blob/main/utils/inference/util.py
Views: 792
1
"""
2
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
"""
5
6
import re
7
import importlib
8
import torch
9
from argparse import Namespace
10
import numpy as np
11
from PIL import Image
12
import os
13
import argparse
14
import dill as pickle
15
#import util.coco
16
17
18
def save_obj(obj, name):
    """Serialize *obj* into the file at path *name*.

    Uses the module-level ``pickle`` object with its highest protocol; the
    file can be read back with ``load_obj``.
    """
    with open(name, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
21
22
23
def load_obj(name):
    """Deserialize and return the object stored in the file at path *name*.

    Counterpart of ``save_obj``. Security: unpickling can execute arbitrary
    code embedded in the file, so only load files from trusted sources.
    """
    with open(name, 'rb') as handle:
        restored = pickle.load(handle)
    return restored
26
27
def copyconf(default_opt, **kwargs):
    """Return a copy of *default_opt* with some attributes overridden.

    *default_opt* is typically the option namespace of the current experiment
    (anything ``vars()`` accepts). Each keyword override is printed as
    ``key value`` and then set on the copy. *default_opt* itself is not
    modified; the copy is shallow, so attribute values are shared.
    """
    conf = argparse.Namespace(**vars(default_opt))
    for key, value in kwargs.items():
        print(key, value)
        setattr(conf, key, value)
    return conf
38
39
40
def tile_images(imgs, picturesPerRow=4):
    """Arrange a batch of images into a single grid image.

    Code borrowed from
    https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997

    Args:
        imgs: array where ``imgs[k]`` is one image. Images in a row are joined
            along their axis 1, and rows are stacked along axis 0, so
            (N, H, W[, C]) becomes (rows * H, picturesPerRow * W[, C]).
        picturesPerRow: number of images per grid row. The batch is padded
            with all-zero images until its length is a multiple of this.

    Returns:
        The tiled array, with the same dtype as *imgs*.
    """
    per_row = picturesPerRow
    # Number of blank images needed to complete the last row (0 if full).
    missing = -imgs.shape[0] % per_row
    if missing:
        filler = np.zeros((missing,) + tuple(imgs.shape[1:]), dtype=imgs.dtype)
        imgs = np.concatenate([imgs, filler], axis=0)

    rows = [
        np.concatenate(list(imgs[start:start + per_row]), axis=1)
        for start in range(0, imgs.shape[0], per_row)
    ]
    return np.concatenate(rows, axis=0)
60
61
62
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    """Convert an image tensor (or a list of them) into a NumPy array.

    Args:
        image_tensor: a torch tensor, or a list of tensors (each converted
            recursively). A 4-D tensor is treated as a batch and converted
            element by element. For a 3-D tensor, axis 0 is moved last
            (channel-first -> channel-last). A 2-D tensor first gets a
            leading axis of size 1.
        imtype: dtype of the returned array(s).
        normalize: accepted for API compatibility. Both settings scale values
            by 255 (the (x + 1) / 2 mapping for [-1, 1] inputs is not
            applied), so the flag currently has no effect.
        tile: for 4-D input, combine the batch into one mosaic with
            ``tile_images``. It is not forwarded to list elements.

    Returns:
        Values multiplied by 255, clipped to [0, 255] and cast to *imtype*.
        Single-channel results drop the channel axis. A list input returns a
        list of arrays.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(item, imtype, normalize) for item in image_tensor]

    if image_tensor.dim() == 4:
        # Forward imtype/normalize so each batch element is converted exactly
        # like a single image would be. Dropping them made batches always
        # come back as uint8, whatever imtype was requested.
        images_np = np.stack(
            [tensor2im(one_image, imtype, normalize) for one_image in image_tensor],
            axis=0,
        )
        if tile:
            return tile_images(images_np)
        return images_np

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = image_tensor.detach().cpu().float().numpy()
    # Same scaling for normalize=True and normalize=False (see docstring);
    # inputs are presumably in [0, 1] -- TODO confirm with callers.
    image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
97
98
99
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    """Render a label tensor as a colour image (NumPy array).

    Args:
        label_tensor: torch tensor.
            * 4-D input is handled per batch element.
            * If the first dimension has more than one entry, it is reduced
              by argmax over dim 0 (e.g. one-hot or score maps -- presumably
              (C, H, W); TODO confirm).
            * Otherwise the tensor is used directly as the label map.
        n_label: palette size passed to ``Colorize``. ``0`` instead falls back
            to plain ``tensor2im`` conversion.
        imtype: dtype of the returned array.
        tile: for 4-D input, tile all rendered images with ``tile_images``.

    Returns:
        The colourised array, transposed to channel-last.
        * For 4-D input without ``tile``, only the FIRST batch element is
          returned.
        * 1-D input returns a fixed 64x64x3 uint8 zero array, ignoring
          *imtype*.
    """
    ndim = label_tensor.dim()

    if ndim == 4:
        rendered = [
            tensor2label(sample, n_label, imtype)[np.newaxis]
            for sample in label_tensor
        ]
        stacked = np.concatenate(rendered, axis=0)
        # NOTE(review): without `tile` only the first batch element is kept,
        # unlike tensor2im, which returns the whole batch.
        return tile_images(stacked) if tile else stacked[0]

    if ndim == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)

    if n_label == 0:
        return tensor2im(label_tensor, imtype)

    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # Index of the maximum along dim 0, keeping a leading axis of size 1.
        labels = labels.max(0, keepdim=True)[1]
    colored = Colorize(n_label)(labels)
    return np.transpose(colored.numpy(), (1, 2, 0)).astype(imtype)
127
128
129
def save_image(image_numpy, image_path, create_dir=False):
    """Save an image array to disk, redirecting ``.jpg`` paths to ``.png``.

    Args:
        image_numpy: array of shape (H, W) or (H, W, C). Single-channel
            images are replicated to 3 channels before saving.
        image_path: destination path. A trailing ``.jpg`` extension is
            swapped for ``.png``; any other path is used unchanged.
        create_dir: create the parent directory of *image_path* if missing.
    """
    if create_dir:
        parent = os.path.dirname(image_path)
        # A bare filename has no parent directory; os.makedirs('') raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)

    # Save to png. Swap only the final extension: str.replace('.jpg', '.png')
    # would also rewrite '.jpg' occurring inside directory names, making the
    # save target a directory that was never created.
    root, ext = os.path.splitext(image_path)
    if ext == '.jpg':
        image_path = root + '.png'
    image_pil.save(image_path)
140
141
142
def mkdirs(paths):
    """Create one directory or several.

    Args:
        paths: a single path, or a list/tuple of paths. Each one is passed to
            ``mkdir``.
    """
    # A list is never a str, so the old extra str check was dead code.
    # Tuples are accepted as well (previously they were passed to mkdir whole).
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
148
149
150
def mkdir(path):
    """Create directory *path*, including missing parents, unless it exists.

    This is a no-op when *path* already exists (as a directory or any other
    entry), including when another process creates it concurrently.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present. Same outcome as an exists() pre-check, but without
        # the window between checking and creating.
        pass
153
154
155
def atoi(text):
    """Return *text* as an int if it consists of decimal digits, else unchanged.

    Uses ``str.isdecimal`` rather than ``str.isdigit``. ``isdigit`` also
    accepts characters such as superscripts ('²') that ``int()`` rejects,
    which made this raise ValueError.
    """
    return int(text) if text.isdecimal() else text
157
158
159
def natural_keys(text):
    """Build a sort key that orders embedded numbers numerically.

    ``alist.sort(key=natural_keys)`` sorts in human order, e.g. 'img2'
    before 'img10'.
    See http://nedbatchelder.com/blog/200712/human_sorting.html
    (Toothy's implementation in the comments).

    The text is split around runs of digits. Each piece is passed through
    ``atoi``, so digit runs become ints and the rest stays as strings.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]
166
167
168
def natural_sort(items):
    """Sort *items* in place in human order, using ``natural_keys`` as the key.

    Returns None, like ``list.sort``.
    """
    key_func = natural_keys
    items.sort(key=key_func)
170
171
172
def str2bool(v):
    """Parse a yes/no style string into a bool, e.g. for argparse ``type=``.

    Accepts yes/true/t/y/1 and no/false/f/n/0, case-insensitively. A value
    that is already a bool is returned unchanged, so callers may pass an
    already-parsed value.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
179
180
181
def find_class_in_module(target_cls_name, module):
    """Import *module* and return its attribute matching *target_cls_name*.

    Underscores are stripped from *target_cls_name*, and the comparison is
    case-insensitive. For example 'pix2pix_model' matches 'Pix2PixModel'.
    Underscores in the module's own names are NOT stripped. If several names
    match, the last one in the module namespace wins.

    Raises:
        SystemExit: with status 1 (after printing a message) when nothing
            matches.
    """
    target_cls_name = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        # A missing class is a failure, so exit non-zero. exit(0) reported
        # success to shells and CI. SystemExit is also what exit() raises,
        # so callers see the same exception type.
        raise SystemExit(1)

    return cls
194
195
196
def save_network(net, label, epoch, opt):
    """Save ``net.state_dict()`` to ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth``.

    *net* is moved to the CPU before saving, so the checkpoint holds CPU
    tensors. Afterwards it is moved back with ``net.cuda()`` only when
    ``opt.gpu_ids`` is non-empty and CUDA is available; otherwise it stays on
    the CPU. The target directory is assumed to exist.
    """
    checkpoint_path = os.path.join(
        opt.checkpoints_dir, opt.name, '%s_net_%s.pth' % (epoch, label)
    )
    cpu_state = net.cpu().state_dict()
    torch.save(cpu_state, checkpoint_path)
    use_gpu = len(opt.gpu_ids) and torch.cuda.is_available()
    if use_gpu:
        net.cuda()
202
203
204
def load_network(net, label, epoch, opt):
    """Load ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth`` into *net*.

    The checkpoint is deserialized onto the CPU (``map_location='cpu'``) and
    then copied into *net*'s existing parameters by ``load_state_dict``. This
    lets checkpoints saved from a GPU load on CPU-only machines, and avoids
    allocating a temporary GPU copy. *net* keeps its current device.

    Returns:
        *net*, modified in place.
    """
    save_filename = '%s_net_%s.pth' % (epoch, label)
    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
    weights = torch.load(save_path, map_location='cpu')
    net.load_state_dict(weights)
    print('Load checkpoint from path: ', save_path)
    return net
212
213
214
###############################################################################
215
# Code from
216
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
217
# Modified so it complies with the Cityscapes label map colors
218
###############################################################################
219
def uint82bin(n, count=8):
    """Return the *count* lowest bits of integer *n* as a '0'/'1' string, MSB first."""
    bits = []
    for shift in reversed(range(count)):
        bits.append('1' if (n >> shift) & 1 else '0')
    return ''.join(bits)
222
223
224
def labelcolormap(N):
    """Return an (N, 3) uint8 RGB palette; row k is the colour of label k.

    * N == 35: the fixed Cityscapes palette below.
    * Any other N: a generated palette. Label i reads the bits of i + 1
      (so label 0 is not black) three at a time -- bit 3k goes to red,
      3k + 1 to green, 3k + 2 to blue -- each placed on bit 7 - k of its
      channel.
    * N == 182 (COCO): the generated palette is additionally overridden for
      a few named classes, looked up via ``util.coco.id2label``.
    """
    if N == 35:  # cityscape
        return np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                         (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                         (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                         (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                         (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
                        dtype=np.uint8)

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        code = i + 1  # offset by one so that label 0 gets a colour
        # Plain integer bit operations. This replaces round-tripping through
        # uint82bin strings and np.uint8, and avoids shadowing builtin `id`.
        # Each j writes a distinct output bit, so OR equals the old XOR.
        for j in range(7):
            shift = 7 - j
            r |= (code & 1) << shift
            g |= ((code >> 1) & 1) << shift
            b |= ((code >> 2) & 1) << shift
            code >>= 3
        cmap[i] = (r, g, b)

    if N == 182:  # COCO
        # NOTE(review): `util` is never imported in this file (the
        # `import util.coco` line at the top is commented out), so this
        # branch raises NameError as written. Confirm whether COCO palettes
        # are needed before relying on it.
        important_colors = {
            'sea': (54, 62, 167),
            'sky-other': (95, 219, 255),
            'tree': (140, 104, 47),
            'clouds': (170, 170, 170),
            'grass': (29, 195, 49)
        }
        for i in range(N):
            name = util.coco.id2label(i)
            if name in important_colors:
                cmap[i] = np.array(list(important_colors[name]))

    return cmap
262
263
264
class Colorize(object):
    """Turn a label map into an RGB image using the ``labelcolormap(n)`` palette."""

    def __init__(self, n=35):
        # (n, 3) uint8 palette as a torch tensor; row k is the colour of label k.
        palette = labelcolormap(n)[:n]
        self.cmap = torch.from_numpy(palette)

    def __call__(self, gray_image):
        """Colour the labels held in ``gray_image[0]``.

        *gray_image* is indexed as [0, H, W] (its size()[1] and size()[2]
        give the output height and width). The result is a (3, H, W) uint8
        CPU tensor. Pixels whose value matches no palette index stay
        (0, 0, 0).
        """
        dims = gray_image.size()
        height, width = dims[1], dims[2]
        color_image = torch.zeros(3, height, width, dtype=torch.uint8, device='cpu')

        label_plane = gray_image[0]
        for label, rgb in enumerate(self.cmap):
            # The mask must live on the CPU to index the CPU output tensor.
            mask = (label_plane == label).cpu()
            for channel in range(3):
                color_image[channel][mask] = rgb[channel]

        return color_image
280
281