# Path: blob/master/coco models/tflite mobnetv1 ssd/seg_tflite.py
# 455 views
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 18 16:23:16 2020

@author: hp

Run a TFLite SSD MobileNet detector ('coco_ssd_mobilenet/detect.tflite') on
frames read from camera 0, optionally apply class-wise non-maximum
suppression, and display the annotated frames until 'q' is pressed.
"""

import numpy as np
import tensorflow as tf
import cv2
import visualization_utils as vis_util


def create_category_index(label_path='coco_ssd_mobilenet/labelmap.txt'):
    """
    Create a category-index dictionary from a label map file.

    The first line of the file is skipped. Every following line ``i`` becomes
    class id ``i - 1``. Lines whose text is exactly ``'???'`` are treated as
    placeholders and omitted.

    Parameters
    ----------
    label_path : str, optional
        Path to labelmap.txt. The default is 'coco_ssd_mobilenet/labelmap.txt'.

    Returns
    -------
    category_index : dict
        Nested dictionary ``{id: {'id': id, 'name': name}}``.
    """
    category_index = {}
    # `with` guarantees the file is closed even if reading fails.
    with open(label_path, encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i == 0:
                continue
            # rstrip instead of line[:-1]: a last line without a trailing
            # newline must not lose its final character.
            name = line.rstrip('\r\n')
            if name != '???':
                category_index[i - 1] = {'id': i - 1, 'name': name}
    return category_index


def get_output_dict(image, interpreter, output_details, nms=True, iou_thresh=0.5, score_thresh=0.6):
    """
    Read the interpreter's output tensors into a dictionary.

    ``interpreter.invoke()`` must already have been called on the input.

    Parameters
    ----------
    image : array of uint8
        Unused. Kept only for backward compatibility with existing callers.
    interpreter : tensorflow.lite.python.interpreter.Interpreter
        TFLite model interpreter that has already been invoked.
    output_details : list
        Output details of the interpreter. Assumes the SSD post-process
        order boxes, classes, scores, count (TODO confirm for other models).
    nms : bool, optional
        Whether to perform class-wise non-maximum suppression. The default
        is True.
    iou_thresh : float, optional
        Intersection-over-union threshold for NMS. The default is 0.5.
    score_thresh : float, optional
        Minimum score kept by NMS. The default is 0.6.

    Returns
    -------
    output_dict : dict
        Keys 'detection_boxes', 'detection_classes' (int64) and
        'detection_scores'. Also contains 'num_detections' when ``nms`` is
        False.
    """
    def _first(idx):
        # Drop the leading batch dimension of output tensor `idx`.
        return interpreter.get_tensor(output_details[idx]['index'])[0]

    output_dict = {
        'detection_boxes': _first(0),
        'detection_classes': _first(1).astype(np.int64),
        'detection_scores': _first(2),
        'num_detections': _first(3),
    }
    if nms:
        output_dict = apply_nms(output_dict, iou_thresh, score_thresh)
    return output_dict


def apply_nms(output_dict, iou_thresh=0.5, score_thresh=0.6, num_classes=90):
    """
    Apply non-maximum suppression separately for each class.

    Parameters
    ----------
    output_dict : dict
        Dictionary containing:
        'detection_boxes' : bounding-box coordinates, shape (N, 4)
        'detection_classes' : class indices, shape (N,)
        'detection_scores' : scores, shape (N,)
        'num_detections' : number of valid rows N (read with ``int()``)
    iou_thresh : float, optional
        Intersection-over-union threshold. The default is 0.5.
    score_thresh : float, optional
        Scores below this value are discarded. The default is 0.6.
    num_classes : int, optional
        Minimum size of the per-class axis. It is enlarged automatically if
        a larger class index appears. The default (90) matches the
        previously hard-coded value.

    Returns
    -------
    output_dict : dict
        'detection_boxes' : shape (N2, 4)
        'detection_classes' : int64, shape (N2,)
        'detection_scores' : shape (N2,)
        N2 is the number of detections that survive both thresholds.
    """
    num = int(output_dict['num_detections'])
    if num <= 0:
        # Nothing to suppress. Combined NMS cannot run with a zero output size.
        return {
            'detection_boxes': np.zeros((0, 4), dtype=np.float32),
            'detection_classes': np.zeros((0,), dtype=np.int64),
            'detection_scores': np.zeros((0,), dtype=np.float32),
        }

    classes = np.asarray(output_dict['detection_classes'][:num], dtype=np.int64)
    q = max(int(num_classes), int(classes.max()) + 1)

    # Scatter each detection into its own class column. All other class slots
    # keep score 0, so combined NMS only compares boxes of the same class.
    boxes = np.zeros([1, num, q, 4], dtype=np.float32)
    scores = np.zeros([1, num, q], dtype=np.float32)
    rows = np.arange(num)
    boxes[0, rows, classes, :] = output_dict['detection_boxes'][:num]
    scores[0, rows, classes] = output_dict['detection_scores'][:num]

    nmsd = tf.image.combined_non_max_suppression(boxes=boxes,
                                                 scores=scores,
                                                 max_output_size_per_class=num,
                                                 max_total_size=num,
                                                 iou_threshold=iou_thresh,
                                                 score_threshold=score_thresh,
                                                 pad_per_class=False,
                                                 clip_boxes=False)
    valid = int(nmsd.valid_detections[0].numpy())
    return {
        'detection_boxes': nmsd.nmsed_boxes[0].numpy()[:valid],
        'detection_classes': nmsd.nmsed_classes[0].numpy().astype(np.int64)[:valid],
        'detection_scores': nmsd.nmsed_scores[0].numpy()[:valid],
    }


def make_and_show_inference(img, interpreter, input_details, output_details, category_index, nms=True, score_thresh=0.6, iou_thresh=0.5):
    """
    Run the detector on one BGR frame and draw the results onto it.

    Parameters
    ----------
    img : array of uint8
        Original BGR image. It is passed to
        ``vis_util.visualize_boxes_and_labels_on_image_array`` for drawing;
        presumably it is modified in place, since the caller displays it
        afterwards (confirm against visualization_utils).
    interpreter : tensorflow.lite.python.interpreter.Interpreter
        TFLite model interpreter with allocated tensors.
    input_details : list
        Input details of the interpreter. Assumes an NHWC input of shape
        (1, H, W, 3) (TODO confirm for other models).
    output_details : list
        Output details of the interpreter.
    category_index : dict
        Dictionary of labels.
    nms : bool, optional
        Whether to perform non-maximum suppression. The default is True.
    score_thresh : float, optional
        Minimum score for a detection to be drawn. The default is 0.6.
    iou_thresh : float, optional
        Intersection-over-union threshold. The default is 0.5.

    Returns
    -------
    None
    """
    # Use the model's own input size instead of a hard-coded 300x300.
    _, height, width, _ = input_details[0]['shape']
    height, width = int(height), int(width)

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # The interpolation flag must be passed by keyword: cv2.resize's third
    # positional parameter is `dst`, not `interpolation`.
    img_rgb = cv2.resize(img_rgb, (width, height), interpolation=cv2.INTER_AREA)
    img_rgb = img_rgb.reshape([1, height, width, 3])

    interpreter.set_tensor(input_details[0]['index'], img_rgb)
    interpreter.invoke()

    output_dict = get_output_dict(img_rgb, interpreter, output_details, nms, iou_thresh, score_thresh)
    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        img,
        output_dict['detection_boxes'],
        output_dict['detection_classes'],
        output_dict['detection_scores'],
        category_index,
        use_normalized_coordinates=True,
        min_score_thresh=score_thresh,
        line_thickness=3)


# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="coco_ssd_mobilenet/detect.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

category_index = create_category_index()
input_shape = input_details[0]['shape']
cap = cv2.VideoCapture(0)

# try/finally: release the camera and close windows even if inference raises.
try:
    while True:
        ret, img = cap.read()
        if not ret:
            break
        make_and_show_inference(img, interpreter, input_details, output_details, category_index)
        cv2.imshow("image", img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()