物体分类tflite之垃圾分类
2020-05-01
物体分类tflite之垃圾分类
云中有鹿 2020/4/30
import time
import json
import numpy as np
import tensorflow as tf
from PIL import Image
# Build the TFLite interpreter from the converted model file and allocate
# its tensors so it is ready for inference.
interpreter = tf.lite.Interpreter(model_path="./converted_model.tflite")
interpreter.allocate_tensors()

# Fetch the input/output tensor metadata and print both lists for inspection.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for tensor_details in (input_details, output_details):
    print(tensor_details)
[{'name': 'input_1', 'index': 1, 'shape': array([ 1, 224, 224, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
[{'name': 'Identity', 'index': 0, 'shape': array([ 1, 40], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0)}]
# Load the mapping from predicted class index (as a string key) to the
# human-readable category name. The labels contain Chinese text, so the
# encoding is given explicitly instead of relying on the platform default
# (assumes the JSON file is UTF-8 encoded -- confirm if it fails to decode).
with open("./garbage_classify_rule.json", 'r', encoding="utf-8") as load_f:
    load_dict = json.load(load_f)

# Read the test image, force 3 RGB channels and resize it to 224x224.
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter.
image = Image.open('./test4.png').convert('RGB').resize(
    (224, 224), Image.LANCZOS)
# Convert to float32 and reshape to the shape reported by the interpreter's
# first input tensor.
image = np.array(image, dtype=np.float32).reshape(input_details[0]['shape'])

# Time the inference. time.clock() was removed in Python 3.8;
# time.perf_counter() is the replacement for measuring short durations.
start = time.perf_counter()
interpreter.set_tensor(input_details[0]['index'], image)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
# Index of the highest score in the first (only) row of the output.
pred_label = np.argmax(output_data[0])
# perf_counter() returns seconds; convert so the value matches the "ms" label.
elapsed = (time.perf_counter() - start) * 1000
print("Time used:", elapsed, "ms")
print(pred_label)
print(load_dict[str(pred_label)])
Time used: 0.8423990000000003 ms
36
可回收物/饮料瓶
# Bare expression: open the test image and resize it to 224x224.
# NOTE(review): presumably a notebook cell whose value is rendered inline;
# as a plain script statement the result is discarded -- confirm.
Image.open('./test4.png').resize((224, 224))