Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
vardanagarwal
GitHub Repository: vardanagarwal/Proctoring-AI
Path: blob/master/coco models/tflite mobnetv1 ssd/seg_tflite.py
455 views
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Tue Aug 18 16:23:16 2020
4
5
@author: hp
6
"""
7
8
import numpy as np
9
import tensorflow as tf
10
import cv2
11
import visualization_utils as vis_util
12
13
def create_category_index(label_path='coco_ssd_mobilenet/labelmap.txt'):
    """
    Build a category-index dictionary from a label map file.

    The first line of the file is always skipped (it is the background
    placeholder in COCO SSD label maps), and any later '???' lines are
    ignored as well.

    Parameters
    ----------
    label_path : str, optional
        Path to labelmap.txt. The default is 'coco_ssd_mobilenet/labelmap.txt'.

    Returns
    -------
    category_index : dict
        Maps zero-based class id -> {'id': class_id, 'name': label}.
    """
    category_index = {}
    # Context manager guarantees the file is closed even if parsing raises.
    with open(label_path) as f:
        for i, line in enumerate(f):
            if i == 0:
                # First line is the background placeholder; always skipped.
                continue
            # rstrip('\n') instead of the original line[:-1]: it does not
            # chop the final character when the file has no trailing newline.
            label = line.rstrip('\n')
            if label != '???':
                category_index[i - 1] = {'id': i - 1, 'name': label}
    return category_index
def get_output_dict(image, interpreter, output_details, nms=True, iou_thresh=0.5, score_thresh=0.6):
    """
    Collect the interpreter's detection outputs into a dictionary.

    Parameters
    ----------
    image : Array of uint8
        Preprocessed image the prediction was made on (kept for API
        compatibility; not read inside this function).
    interpreter : tensorflow.lite.python.interpreter.Interpreter
        tflite model interpreter; invoke() must already have been called.
    output_details : list
        Output details of the interpreter, in the model's standard order
        (boxes, classes, scores, num_detections).
    nms : bool, optional
        Whether to apply non-maximum suppression. The default is True.
    iou_thresh : float, optional
        Intersection-over-union threshold. The default is 0.5.
    score_thresh : float, optional
        Minimum score for a detection to be kept. The default is 0.6.

    Returns
    -------
    output_dict : dict
        Dictionary containing bounding boxes, classes and scores.
    """
    keys = ('detection_boxes', 'detection_classes',
            'detection_scores', 'num_detections')
    # The model exposes its four outputs in a fixed order; pair them up.
    output_dict = {
        key: interpreter.get_tensor(detail['index'])[0]
        for key, detail in zip(keys, output_details)
    }
    # Class ids arrive as floats from the tflite model; cast to integers.
    output_dict['detection_classes'] = output_dict['detection_classes'].astype(np.int64)
    if nms:
        return apply_nms(output_dict, iou_thresh, score_thresh)
    return output_dict
def apply_nms(output_dict, iou_thresh=0.5, score_thresh=0.6, num_classes=90):
    """
    Apply per-class non-maximum suppression to raw detections.

    Parameters
    ----------
    output_dict : dict
        Dictionary containing:
            'detection_boxes' : Bounding box coordinates. Shape (N, 4)
            'detection_classes' : Class indices detected. Shape (N)
            'detection_scores' : Shape (N)
            'num_detections' : Total number of detections i.e. N. Shape (1)
    iou_thresh : float, optional
        Intersection-over-union threshold. The default is 0.5.
    score_thresh : float, optional
        Score threshold below which detections are dropped. The default is 0.6.
    num_classes : int, optional
        Size of the model's class space. The default is 90 (COCO).

    Returns
    -------
    output_dict : dict
        Dictionary with only the detections surviving both thresholds:
            'detection_boxes' : Shape (N2, 4)
            'detection_classes' : Shape (N2)
            'detection_scores' : Shape (N2)
        where N2 is the number of valid predictions after suppression.
    """
    num = int(output_dict['num_detections'])
    # Scatter each detection into its class slot: combined_non_max_suppression
    # expects boxes of shape [batch, num_boxes, num_classes, 4] and scores of
    # shape [batch, num_boxes, num_classes].
    boxes = np.zeros([1, num, num_classes, 4])
    scores = np.zeros([1, num, num_classes])
    rows = np.arange(num)
    cls = output_dict['detection_classes'][:num]
    boxes[0, rows, cls] = output_dict['detection_boxes'][:num]
    scores[0, rows, cls] = output_dict['detection_scores'][:num]
    nmsd = tf.image.combined_non_max_suppression(
        boxes=boxes,
        scores=scores,
        max_output_size_per_class=num,
        max_total_size=num,
        iou_threshold=iou_thresh,
        score_threshold=score_thresh,
        pad_per_class=False,
        clip_boxes=False)
    valid = nmsd.valid_detections[0].numpy()
    return {
        'detection_boxes': nmsd.nmsed_boxes[0].numpy()[:valid],
        'detection_classes': nmsd.nmsed_classes[0].numpy().astype(np.int64)[:valid],
        'detection_scores': nmsd.nmsed_scores[0].numpy()[:valid],
    }
def make_and_show_inference(img, interpreter, input_details, output_details, category_index, nms=True, score_thresh=0.6, iou_thresh=0.5):
    """
    Generate and draw inference on an image, in place.

    Parameters
    ----------
    img : Array of uint8
        Original BGR image (OpenCV convention) to find predictions on;
        boxes and labels are drawn onto it in place.
    interpreter : tensorflow.lite.python.interpreter.Interpreter
        tflite model interpreter (tensors already allocated).
    input_details : list
        Input details of interpreter; the model's input shape is read
        from here instead of being hard-coded.
    output_details : list
        Output details of interpreter.
    category_index : dict
        Dictionary of labels.
    nms : bool, optional
        Whether to perform non-maximum suppression. The default is True.
    score_thresh : float, optional
        Minimum score for a prediction to be drawn. The default is 0.6.
    iou_thresh : float, optional
        Intersection-over-union threshold. The default is 0.5.

    Returns
    -------
    None
    """
    # Derive the expected input size from the model itself (previously a
    # hard-coded 300x300, which only worked for that one model).
    _, height, width, _ = input_details[0]['shape']
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # BUG FIX: cv2.resize's third positional parameter is `dst`, not the
    # interpolation flag — the flag must be passed by keyword.
    img_rgb = cv2.resize(img_rgb, (width, height), interpolation=cv2.INTER_AREA)
    img_rgb = img_rgb.reshape([1, height, width, 3])

    interpreter.set_tensor(input_details[0]['index'], img_rgb)
    interpreter.invoke()

    output_dict = get_output_dict(img_rgb, interpreter, output_details, nms, iou_thresh, score_thresh)
    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        img,
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        use_normalized_coordinates=True,
        min_score_thresh=score_thresh,
        line_thickness=3)
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="coco_ssd_mobilenet/detect.tflite")
interpreter.allocate_tensors()

# Get input and output tensor metadata.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

category_index = create_category_index()
cap = cv2.VideoCapture(0)

# try/finally guarantees the camera and windows are released even if
# inference or drawing raises mid-loop.
try:
    while True:
        ret, img = cap.read()
        if not ret:
            # Camera returned no frame; stop capturing.
            break
        make_and_show_inference(img, interpreter, input_details, output_details, category_index)
        cv2.imshow("image", img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()