# GitHub Repository: ai-forever/sber-swap
# Path: blob/main/coordinate_reg/image_infer.py

import cv2
import numpy as np
import os
import mxnet as mx
from skimage import transform as trans
import insightface
import sys
# sys.path.append('/home/jovyan/FaceShifter-2/FaceShifter3/')
from insightface_func.face_detect_crop_single import Face_detect_crop
import kornia


# M shrinks a crop by 4/7 (0.5714...) and shifts it by +32 px (a 224 px box
# lands on the central 128 px of the 192 px model input); IM is its inverse
# affine (scale 1.75, shift -56 px). IM is kept as a plain 2x3 matrix: the
# original triple-nested brackets gave it a leading batch axis that np.dot
# inside trans_points2d cannot consume.
M = np.array([[0.57142857, 0., 32.], [0., 0.57142857, 32.]])
IM = np.array([[1.75, -0., -56.], [-0., 1.75, -56.]])


def square_crop(im, S):
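    """Letterbox `im` into an S x S canvas, anchored at the top-left.

    Returns the padded square image and the resize scale, so detections on
    the square image can be mapped back to the original by dividing by it.
    """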
    if im.shape[0] > im.shape[1]:
        height = S
        width = int(float(im.shape[1]) / im.shape[0] * S)
        scale = float(S) / im.shape[0]
    else:
        width = S
        height = int(float(im.shape[0]) / im.shape[1] * S)
        scale = float(S) / im.shape[1]
    resized_im = cv2.resize(im, (width, height))
    det_im = np.zeros((S, S, 3), dtype=np.uint8)
    det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
    return det_im, scale


def transform(data, center, output_size, scale, rotation):
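    """Scale, rotate, and center-crop `data` into an output_size square.

    Composes scale -> translate-to-origin -> rotate -> translate-to-center
    similarity transforms and applies the resulting 2x3 affine with
    cv2.warpAffine. Returns the cropped image and the forward matrix M.
    """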
    scale_ratio = scale
    rot = float(rotation) * np.pi / 180.0
    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0] * scale_ratio
    cy = center[1] * scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1 * cx, -1 * cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size / 2,
                                                output_size / 2))
    t = t1 + t2 + t3 + t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,
                             M, (output_size, output_size),
                             borderValue=0.0)
    return cropped, M


def trans_points2d_batch(pts, M):
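    """Apply a per-sample 2x3 affine M[j] to a batch of 2D points.

    pts: (B, N, 2) array; M: (B, 2, 3) array. Returns (B, N, 2) float32.
    """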
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for j in range(pts.shape[0]):
        for i in range(pts.shape[1]):
            pt = pts[j][i]
            new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
            new_pt = np.dot(M[j], new_pt)
            new_pts[j][i] = new_pt[0:2]
    return new_pts


def trans_points2d(pts, M):
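    """Apply a single 2x3 affine M to an (N, 2) array of points."""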
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts


def trans_points3d(pts, M):
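    """Apply a 2x3 affine M to (N, 3) points: x and y are transformed,
    z is rescaled by the isotropic scale recovered from M's first row."""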
    scale = np.sqrt(M[0][0] * M[0][0] + M[0][1] * M[0][1])
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        new_pts[i][0:2] = new_pt[0:2]
        new_pts[i][2] = pts[i][2] * scale

    return new_pts


def trans_points(pts, M):
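    """Dispatch to the 2D or 3D point transform based on pts.shape[1]."""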
    if pts.shape[1] == 2:
        return trans_points2d(pts, M)
    else:
        return trans_points3d(pts, M)


class Handler:
    def __init__(self, prefix, epoch, im_size=192, det_size=224, ctx_id=0,
                 root='./insightface_func/models'):
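        """Landmark inference wrapper.

        Combines an insightface face detector (the 'antelope' pack) with an
        MXNet landmark model loaded from the checkpoint `prefix`/`epoch`;
        the network is truncated at its `fc1_output` layer.
        """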
        print('loading', prefix, epoch)
        if ctx_id >= 0:
            ctx = mx.gpu(ctx_id)
        else:
            ctx = mx.cpu()
        image_size = (im_size, im_size)
        # self.detector = insightface.model_zoo.get_model(
        #     'retinaface_mnet025_v2')  # can replace with your own face detector
        self.detector = Face_detect_crop(name='antelope', root=root)
        self.detector.prepare(ctx_id=ctx_id, det_thresh=0.6, det_size=(640, 640))
        # self.detector = insightface.model_zoo.get_model('retinaface_r50_v1')
        # self.detector.prepare(ctx_id=ctx_id)
        self.det_size = det_size
        sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
        all_layers = sym.get_internals()
        sym = all_layers['fc1_output']
        self.image_size = image_size
        model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        model.bind(for_training=False,
                   data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
        model.set_params(arg_params, aux_params)
        self.model = model

    def get_without_detection_batch(self, img, M, IM):
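        """Batched landmark inference on pre-cropped faces (no detector).

        `img` is expected to be a BCHW torch tensor of BGR crops; the `M`
        and `IM` arguments (torch affine matrices) shadow the module-level
        numpy constants of the same names.
        """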
        rimg = kornia.warp_affine(img, M.repeat(img.shape[0], 1, 1),
                                  (192, 192), padding_mode='zeros')
        rimg = kornia.bgr_to_rgb(rimg)

        data = mx.nd.array(rimg)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()
        pred = pred.reshape((pred.shape[0], -1, 2))
        # Outputs lie in [-1, 1]; shift and scale them to pixel coordinates.
        pred[:, :, 0:2] += 1
        pred[:, :, 0:2] *= (self.image_size[0] // 2)

        pred = trans_points2d_batch(pred, IM.repeat(img.shape[0], 1, 1).numpy())

        return pred

    def get_without_detection_without_transform(self, img):
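        """Landmark inference on a single pre-cropped BGR face, using the
        module-level fixed affines M and IM instead of a detected bbox."""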
        input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)
        rimg = cv2.warpAffine(img, M, self.image_size, borderValue=0.0)
        rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
        rimg = np.transpose(rimg, (2, 0, 1))  # HWC -> CHW, RGB

        input_blob[0] = rimg
        data = mx.nd.array(input_blob)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()[0]
        pred = pred.reshape((-1, 2))
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.image_size[0] // 2)
        pred = trans_points2d(pred, IM)

        return pred

    def get_without_detection(self, img):
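        """Landmark inference treating the whole image as one face: builds
        a full-frame bbox, warps it to the model input size, and maps the
        predictions back through the inverted affine."""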
        # Full-frame bbox as (x1, y1, x2, y2); img.shape is (H, W, C), so
        # width comes from shape[1] and height from shape[0].
        bbox = [0, 0, img.shape[1], img.shape[0]]
        input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)

        w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
        center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
        rotate = 0
        _scale = self.image_size[0] * 2 / 3.0 / max(w, h)

        rimg, M = transform(img, center, self.image_size[0], _scale,
                            rotate)
        rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
        rimg = np.transpose(rimg, (2, 0, 1))  # HWC -> CHW, RGB

        input_blob[0] = rimg
        data = mx.nd.array(input_blob)
        db = mx.io.DataBatch(data=(data, ))
        self.model.forward(db, is_train=False)
        pred = self.model.get_outputs()[-1].asnumpy()[0]
        # A long output vector means 3D landmarks (x, y, z per point).
        if pred.shape[0] >= 3000:
            pred = pred.reshape((-1, 3))
        else:
            pred = pred.reshape((-1, 2))
        pred[:, 0:2] += 1
        pred[:, 0:2] *= (self.image_size[0] // 2)
        if pred.shape[1] == 3:
            pred[:, 2] *= (self.image_size[0] // 2)

        IM = cv2.invertAffineTransform(M)
        pred = trans_points(pred, IM)

        return pred

    def get(self, img, get_all=False):
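        """Detect faces in `img` and return a list of per-face landmark
        arrays (largest face only unless get_all=True)."""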
        out = []
        det_im, det_scale = square_crop(img, self.det_size)
        bboxes, _ = self.detector.detect(det_im)
        if bboxes.shape[0] == 0:
            return out
        bboxes /= det_scale
        if not get_all:
            # Keep only the largest detected face.
            areas = []
            for i in range(bboxes.shape[0]):
                x = bboxes[i]
                area = (x[2] - x[0]) * (x[3] - x[1])
                areas.append(area)
            m = np.argsort(areas)[-1]
            bboxes = bboxes[m:m + 1]
        for i in range(bboxes.shape[0]):
            bbox = bboxes[i]
            input_blob = np.zeros((1, 3) + self.image_size, dtype=np.float32)
            w, h = (bbox[2] - bbox[0]), (bbox[3] - bbox[1])
            center = (bbox[2] + bbox[0]) / 2, (bbox[3] + bbox[1]) / 2
            rotate = 0
            _scale = self.image_size[0] * 2 / 3.0 / max(w, h)
            rimg, M = transform(img, center, self.image_size[0], _scale,
                                rotate)
            rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
            rimg = np.transpose(rimg, (2, 0, 1))  # HWC -> CHW, RGB
            input_blob[0] = rimg
            data = mx.nd.array(input_blob)
            db = mx.io.DataBatch(data=(data, ))
            self.model.forward(db, is_train=False)
            pred = self.model.get_outputs()[-1].asnumpy()[0]
            if pred.shape[0] >= 3000:
                pred = pred.reshape((-1, 3))
            else:
                pred = pred.reshape((-1, 2))
            pred[:, 0:2] += 1
            pred[:, 0:2] *= (self.image_size[0] // 2)
            if pred.shape[1] == 3:
                pred[:, 2] *= (self.image_size[0] // 2)

            IM = cv2.invertAffineTransform(M)
            pred = trans_points(pred, IM)
            out.append(pred)
        return out
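

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The checkpoint prefix,
# epoch, and image path below are illustrative assumptions: `prefix` must point
# at existing `<prefix>-symbol.json` / `<prefix>-NNNN.params` files for
# mx.model.load_checkpoint to succeed.
#
# if __name__ == '__main__':
#     handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0)  # assumed path/epoch
#     image = cv2.imread('example_face.jpg')  # assumed input image
#     for pts in handler.get(image, get_all=False):  # list of (N, 2) landmark arrays
#         for x, y in pts.astype(int):
#             cv2.circle(image, (int(x), int(y)), 1, (0, 255, 0), -1)
#     cv2.imwrite('landmarks_vis.jpg', image)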