Path: blob/main/coordinate_reg/image_infer.py
import cv2
import numpy as np
import os
import mxnet as mx
from skimage import transform as trans
import insightface
import sys
# sys.path.append('/home/jovyan/FaceShifter-2/FaceShifter3/')
from insightface_func.face_detect_crop_single import Face_detect_crop
import kornia

# Fixed 2x3 affine used by the no-detection paths: maps a 224x224 crop into
# the 192x192 network input (scale 4/7 ~= 0.5714, offset 32). IM is its
# inverse in batch form (scale 1.75, offset -56).
M = np.array([[0.57142857, 0., 32.], [0., 0.57142857, 32.]])
IM = np.array([[[1.75, -0., -56.], [-0., 1.75, -56.]]])


def square_crop(im, S):
    """Resize `im` so its longer side equals S, then pad to an SxS canvas."""
    if im.shape[0] > im.shape[1]:
        height = S
        width = int(float(im.shape[1]) / im.shape[0] * S)
        scale = float(S) / im.shape[0]
    else:
        width = S
        height = int(float(im.shape[0]) / im.shape[1] * S)
        scale = float(S) / im.shape[1]
    resized_im = cv2.resize(im, (width, height))
    det_im = np.zeros((S, S, 3), dtype=np.uint8)
    det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
    return det_im, scale


def transform(data, center, output_size, scale, rotation):
    """Scale, recenter, and rotate `data` into an output_size x output_size
    crop. Returns the crop and the 2x3 affine matrix that produced it."""
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0
    # translation = (output_size/2 - center[0]*scale_ratio, output_size/2 - center[1]*scale_ratio)
    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2,
                                                output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,
                             M, (output_size, output_size),
                             borderValue=0.0)
    return cropped, M


def trans_points2d_batch(pts, M):
    """Apply the per-sample 2x3 affine M[j] to each sample's 2D points."""
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for j in range(pts.shape[0]):
        for i in range(pts.shape[1]):
            pt = pts[j][i]
            new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
            new_pt = np.dot(M[j], new_pt)
            new_pts[j][i] = new_pt[0:2]
    return new_pts


def trans_points2d(pts, M):
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        # print('new_pt', new_pt.shape, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts


def trans_points3d(pts, M):
    # Recover the isotropic scale of the affine to rescale the z coordinate.
    scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
    # print(scale)
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        # print('new_pt', new_pt.shape, new_pt)
        new_pts[i][0:2] = new_pt[0:2]
        new_pts[i][2] = pts[i][2] * scale

    return new_pts


def trans_points(pts, M):
    if pts.shape[1] == 2:
        return trans_points2d(pts, M)
    else:
        return trans_points3d(pts, M)


class Handler:
    def __init__(self, prefix, epoch, im_size=192, det_size=224, ctx_id=0,
                 root='./insightface_func/models'):
        print('loading', prefix, epoch)
        if ctx_id >= 0:
            ctx = mx.gpu(ctx_id)
        else:
            ctx = mx.cpu()
        image_size = (im_size, im_size)
        # self.detector = insightface.model_zoo.get_model(
        #     'retinaface_mnet025_v2')  # can replace with your own face detector
        self.detector = Face_detect_crop(name='antelope', root=root)
        self.detector.prepare(ctx_id=ctx_id, det_thresh=0.6, det_size=(640, 640))
        # self.detector = insightface.model_zoo.get_model('retinaface_r50_v1')
        # self.detector.prepare(ctx_id=ctx_id)
        self.det_size = det_size
        # Load the landmark-regression checkpoint and keep only the fc1 output.
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        self.image_size = image_size
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        model.bind(for_training=False,
                   data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
        model.set_params(arg_params, aux_params)
        self.model = model
        self.image_size = image_size

    def get_without_detection_batch(self, img, M, IM):
        # `img` is a BCHW torch tensor: warp the whole batch with kornia,
        # then feed it to the MXNet module.
        rimg = kornia.warp_affine(img, M.repeat(img.shape[0], 1, 1), (192, 192),
                                  padding_mode='zeros')
        rimg = kornia.bgr_to_rgb(rimg)

        data = mx.nd.array(rimg)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()
        pred = pred.reshape((pred.shape[0], -1, 2))
        # Map predictions from [-1, 1] back to crop pixel coordinates.
        pred[:, :, 0:2] += 1
        pred[:, :, 0:2] *= (self.image_size[0] // 2)

        pred = trans_points2d_batch(pred, IM.repeat(img.shape[0], 1, 1).numpy())

        return pred

    def get_without_detection_without_transform(self, img):
        input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)
        # Uses the module-level constants M and IM defined at the top.
        rimg = cv2.warpAffine(img, M, self.image_size, borderValue=0.0)
        rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
        rimg = np.transpose(rimg, (2, 0, 1))  # HWC -> CHW, RGB

        input_blob[0] = rimg
        data = mx.nd.array(input_blob)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()[0]
        pred = pred.reshape((-1, 2))
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.image_size[0] // 2)
        pred = trans_points2d(pred, IM)

        return pred

    def get_without_detection(self, img):
        # Treat the whole image as the face bounding box.
        bbox = [0, 0, img.shape[0], img.shape[1]]
        input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)

        w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
        center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
        rotate = 0
        _scale = self.image_size[0] * 2 / 3.0 / max(w, h)

        rimg, M = transform(img, center, self.image_size[0], _scale, rotate)
        rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
        rimg = np.transpose(rimg, (2, 0, 1))  # HWC -> CHW, RGB

        input_blob[0] = rimg
        data = mx.nd.array(input_blob)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()[0]
        # A large output vector (>= 3000 values) means 3D (x, y, z) landmarks.
        if pred.shape[0] >= 3000:
            pred = pred.reshape((-1, 3))
        else:
            pred = pred.reshape((-1, 2))
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.image_size[0] // 2)
        if pred.shape[1] == 3:
            pred[:, 2] *= (self.image_size[0] // 2)

        # Invert the crop transform to map landmarks back onto the input image.
        IM = cv2.invertAffineTransform(M)
        pred = trans_points(pred, IM)

        return pred

    def get(self, img, get_all=False):
        out = []
        det_im, det_scale = square_crop(img, self.det_size)
        bboxes, _ = self.detector.detect(det_im)
        if bboxes.shape[0] == 0:
            return out
        bboxes /= det_scale
        if not get_all:
            # Keep only the largest detected face.
            areas = []
            for i in range(bboxes.shape[0]):
                x = bboxes[i]
                area = (x[2] - x[0]) * (x[3] - x[1])
                areas.append(area)
            m = np.argsort(areas)[-1]
            bboxes = bboxes[m:m + 1]
        for i in range(bboxes.shape[0]):
            bbox = bboxes[i]
            input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)
            w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
            center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
            rotate = 0
            _scale = self.image_size[0] * 2 / 3.0 / max(w, h)
            rimg, M = transform(img, center, self.image_size[0], _scale, rotate)
            rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
            rimg = np.transpose(rimg, (2, 0, 1))  # HWC -> CHW, RGB
            input_blob[0] = rimg
            data = mx.nd.array(input_blob)
            db = mx.io.DataBatch(data=(data, ))
            self.model.forward(db, is_train=False)
            pred = self.model.get_outputs()[-1].asnumpy()[0]
            if pred.shape[0] >= 3000:
                pred = pred.reshape((-1, 3))
            else:
                pred = pred.reshape((-1, 2))
            pred[:, 0:2] += 1
            pred[:, 0:2] *= (self.image_size[0] // 2)
            if pred.shape[1] == 3:
                pred[:, 2] *= (self.image_size[0] // 2)

            IM = cv2.invertAffineTransform(M)
            pred = trans_points(pred, IM)
            out.append(pred)
        return out
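
For reference, a minimal usage sketch of the detect-then-regress path. The checkpoint prefix and image paths below are hypothetical placeholders, and it assumes a 106-point 2D landmark checkpoint (so Handler.get returns one (106, 2) array of image-space coordinates per detected face):

import cv2

# Hypothetical checkpoint prefix and epoch; adjust to your local files.
handler = Handler('./checkpoints/2d106det', 0, ctx_id=0, det_size=640)
img = cv2.imread('./test.jpg')        # BGR image, as Handler.get expects
for pts in handler.get(img, get_all=True):
    for x, y in pts.astype(int):      # landmarks already mapped back to img
        cv2.circle(img, (int(x), int(y)), 2, (0, 0, 255), -1)
cv2.imwrite('./out.jpg', img)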