Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. A commercial alternative to JupyterHub.
Path: blob/main/utils/inference/util.py
Views: 792
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

Helpers used at inference time: object pickling, option copying,
tensor-to-image conversion, label colorisation, image saving, directory
creation, natural sorting and network checkpoint save/load.
"""

import re
import importlib
import torch
from argparse import Namespace
import numpy as np
import os
import argparse

try:
    from PIL import Image
except ImportError:  # Pillow is only needed by save_image(); keep the rest importable.
    Image = None

try:
    # dill can serialise objects (lambdas, closures, ...) that stdlib pickle cannot.
    import dill as pickle
except ImportError:  # fall back so the non-pickling helpers stay importable
    import pickle


def save_obj(obj, name):
    """Serialise ``obj`` to the file path ``name`` with the highest pickle protocol.

    Uses ``dill`` when it is installed, otherwise the standard ``pickle`` module.
    """
    with open(name, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    """Load and return an object previously written by ``save_obj``.

    Unpickling can execute arbitrary code: only call this on trusted files.
    """
    with open(name, 'rb') as f:
        return pickle.load(f)


def copyconf(default_opt, **kwargs):
    """Return a copy of the option namespace ``default_opt`` with overrides applied.

    ``default_opt`` should be the opt of the current experiment; it is not
    modified. Every keyword in ``kwargs`` overrides (or adds) the attribute of
    the same name on the copy, and each override is printed.
    """
    conf = argparse.Namespace(**vars(default_opt))
    for key, value in kwargs.items():
        print(key, value)
        setattr(conf, key, value)
    return conf


def tile_images(imgs, picturesPerRow=4):
    """Tile a stack of images into a single grid image.

    Code borrowed from
    https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997

    Args:
        imgs: array whose first axis indexes the images; images are joined
            side by side along their own axis 1 (the width for (H, W[, C])
            images) and rows are stacked along axis 0.
        picturesPerRow: number of images per grid row.

    Returns:
        The tiled array. Empty slots in the last row are filled with zeros.
    """
    # Pad with blank images so the count is a multiple of picturesPerRow.
    rowPadding = -imgs.shape[0] % picturesPerRow
    if rowPadding > 0:
        padding = np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)
        imgs = np.concatenate([imgs, padding], axis=0)

    rows = [
        np.concatenate(list(imgs[i:i + picturesPerRow]), axis=1)
        for i in range(0, imgs.shape[0], picturesPerRow)
    ]
    return np.concatenate(rows, axis=0)


def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    """Convert an image tensor (or a list of them) into a NumPy array.

    Args:
        image_tensor: a list of tensors (converted element-wise, returned as a
            list), a 4-D batch, a 3-D channel-first image, or a 2-D image.
        imtype: dtype of the returned array (also applied to every image of a
            4-D batch).
        normalize: accepted for API compatibility; it currently has no effect
            (see NOTE below).
        tile: for 4-D input, tile the batch into one grid image with
            ``tile_images`` instead of returning the stacked batch.

    Returns:
        A channel-last array (channel axis dropped when there is only one
        channel) whose values were multiplied by 255 and clipped to [0, 255]
        before the cast to ``imtype``. 4-D input yields a stack of those.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(t, imtype, normalize) for t in image_tensor]

    if image_tensor.dim() == 4:
        # Convert each image of the batch with the caller's options.
        images_np = np.stack(
            [tensor2im(image_tensor[b], imtype, normalize)
             for b in range(image_tensor.size(0))],
            axis=0,
        )
        return tile_images(images_np) if tile else images_np

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = image_tensor.detach().cpu().float().numpy()
    # NOTE(review): ``normalize`` has no effect - values are always scaled by
    # 255. An earlier variant mapped [-1, 1] to [0, 255] via (x + 1) / 2 * 255;
    # the current scaling suggests inputs in [0, 1] - confirm with callers.
    image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)


def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    """Convert a one-hot (or label-id) tensor into a colourful label map.

    Args:
        label_tensor: a 4-D batch, a 3-D channel-first tensor, or a 1-D tensor.
            With more than one channel the per-pixel argmax over channels is
            used as the label id; with one channel the values are the ids.
        n_label: number of labels passed to ``Colorize``; 0 falls back to
            ``tensor2im``.
        imtype: dtype of the returned array.
        tile: for 4-D input, tile the whole batch into one grid image.

    Returns:
        A channel-last colour image. For 4-D input without ``tile`` only the
        FIRST image of the batch is returned. A 1-D input yields a blank
        64x64x3 uint8 image.
    """
    if label_tensor.dim() == 4:
        images_np = np.stack(
            [tensor2label(label_tensor[b], n_label, imtype)
             for b in range(label_tensor.size(0))],
            axis=0,
        )
        if tile:
            return tile_images(images_np)
        # NOTE(review): unlike tensor2im, only the first image of the batch is
        # returned here; kept as-is for existing callers.
        return images_np[0]

    if label_tensor.dim() == 1:
        # Placeholder blank image with a hard-coded 64x64 size.
        return np.zeros((64, 64, 3), dtype=np.uint8)
    if n_label == 0:
        return tensor2im(label_tensor, imtype)
    label_tensor = label_tensor.cpu().float()
    if label_tensor.size()[0] > 1:
        # Multi-channel input: per-pixel argmax over the channel axis.
        label_tensor = label_tensor.max(0, keepdim=True)[1]
    label_tensor = Colorize(n_label)(label_tensor)
    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
    return label_numpy.astype(imtype)


def save_image(image_numpy, image_path, create_dir=False):
    """Save a 2-D or channel-last 3-D array as an image file.

    Single-channel images are replicated to 3 channels. A trailing ``.jpg``
    extension is replaced by ``.png`` so the file is written as PNG.

    Args:
        image_numpy: array accepted by ``PIL.Image.fromarray``.
        image_path: destination path.
        create_dir: create the parent directory first if it does not exist.

    Raises:
        ImportError: if Pillow is not installed.
    """
    if Image is None:
        raise ImportError('save_image() requires Pillow (PIL) to be installed.')
    if create_dir:
        parent = os.path.dirname(image_path)
        if parent:  # a bare filename has no directory to create (makedirs('') raises)
            os.makedirs(parent, exist_ok=True)
    if image_numpy.ndim == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)

    # save to png: only the extension is swapped, never a ".jpg" that appears
    # elsewhere in the path (e.g. inside a directory name).
    root, ext = os.path.splitext(image_path)
    if ext == '.jpg':
        image_path = root + '.png'
    image_pil.save(image_path)


def mkdirs(paths):
    """Create every directory in ``paths`` (a list/tuple of paths, or one path)."""
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create ``path`` (and its parents) unless something already exists there."""
    if not os.path.exists(path):
        # exist_ok guards against another process creating it after the check.
        os.makedirs(path, exist_ok=True)


def atoi(text):
    """Return ``int(text)`` for a run of decimal digits, otherwise ``text`` unchanged."""
    # isdecimal(), not isdigit(): isdigit() accepts e.g. '²', which int() rejects.
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    return [atoi(c) for c in re.split(r'(\d+)', text)]


def natural_sort(items):
    """Sort the list ``items`` in place in human ("natural") order; returns None."""
    items.sort(key=natural_keys)


def str2bool(v):
    """argparse-style converter from a yes/no string to a bool.

    Accepts, case-insensitively, yes/true/t/y/1 and no/false/f/n/0. A value
    that is already a bool is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    elif lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def find_class_in_module(target_cls_name, module):
    """Return the attribute of ``module`` whose name matches ``target_cls_name``.

    Underscores are stripped from ``target_cls_name`` and matching is
    case-insensitive, so ``'ordered_dict'`` matches ``OrderedDict``. If several
    attributes match, the last one in the module namespace wins.

    Prints a message and exits the process with status 1 if nothing matches.
    """
    target_cls_name = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        # Non-zero status signals failure; SystemExit does not depend on the
        # site-provided exit() builtin.
        raise SystemExit(1)

    return cls


def save_network(net, label, epoch, opt):
    """Save ``net``'s state dict to ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth``.

    The network is moved to the CPU to be saved (``net.cpu()`` moves it in
    place) and moved back with ``net.cuda()`` when ``opt.gpu_ids`` is
    non-empty and CUDA is available.
    """
    save_filename = '%s_net_%s.pth' % (epoch, label)
    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
    torch.save(net.cpu().state_dict(), save_path)
    if len(opt.gpu_ids) and torch.cuda.is_available():
        net.cuda()


def load_network(net, label, epoch, opt):
    """Load the checkpoint written by ``save_network`` into ``net`` and return it.

    The state dict is loaded onto the CPU first; ``load_state_dict`` then
    copies it into the parameters wherever they live, so GPU-saved checkpoints
    also load on CPU-only machines. ``torch.load`` unpickles the file: only
    load trusted checkpoints.
    """
    save_filename = '%s_net_%s.pth' % (epoch, label)
    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    save_path = os.path.join(save_dir, save_filename)
    weights = torch.load(save_path, map_location='cpu')
    net.load_state_dict(weights)
    print('Load checkpoint from path: ', save_path)
    return net


###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Citscape label map colors
###############################################################################
def uint82bin(n, count=8):
    """returns the binary of integer n, count refers to amount of bits"""
    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])


def labelcolormap(N):
    """Return an (N, 3) uint8 RGB colour map for ``N`` labels.

    ``N == 35`` uses the fixed Cityscapes palette. Any other ``N`` uses a
    bit-interleaving scheme; for ``N == 182`` (COCO) a few named classes are
    then overridden with hand-picked colours.
    """
    if N == 35:  # cityscape
        cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                         (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                         (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                         (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                         (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
                        dtype=np.uint8)
    else:
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            cid = i + 1  # let's give 0 a color
            # Bits 0/1/2 of each successive 3-bit group of the id become
            # successively lower bits (7, 6, ..., 1) of red/green/blue.
            for j in range(7):
                shift = 7 - j
                r |= (cid & 1) << shift
                g |= ((cid >> 1) & 1) << shift
                b |= ((cid >> 2) & 1) << shift
                cid >>= 3
            cmap[i] = (r, g, b)

        if N == 182:  # COCO
            # Imported lazily because only the COCO palette needs it.
            # NOTE(review): assumes an importable top-level `util.coco` module
            # providing `id2label(i)` - confirm it exists in this repository.
            import util.coco
            important_colors = {
                'sea': (54, 62, 167),
                'sky-other': (95, 219, 255),
                'tree': (140, 104, 47),
                'clouds': (170, 170, 170),
                'grass': (29, 195, 49)
            }
            for i in range(N):
                name = util.coco.id2label(i)
                if name in important_colors:
                    cmap[i] = np.array(list(important_colors[name]))

    return cmap


class Colorize:
    """Callable that paints label ids with the colours of ``labelcolormap(n)``."""

    def __init__(self, n=35):
        # Keep exactly the first n colours as a uint8 torch tensor of shape (n, 3).
        self.cmap = torch.from_numpy(labelcolormap(n)[:n])

    def __call__(self, gray_image):
        """Return a (3, H, W) uint8 tensor coloured from channel 0 of ``gray_image``.

        ``gray_image`` is indexed as (channel, H, W); pixels whose value is not
        a label id in [0, n) stay zero (black).
        """
        size = gray_image.size()
        color_image = torch.zeros((3, size[1], size[2]), dtype=torch.uint8)

        for label, color in enumerate(self.cmap):
            mask = (label == gray_image[0]).cpu()
            color_image[0][mask] = color[0]
            color_image[1][mask] = color[1]
            color_image[2][mask] = color[2]

        return color_image