物体检测之垃圾分类

物体检测-垃圾分类

云中有鹿 2020/4/30

import os
import torch
from tqdm import tqdm
import math
import torchvision
from PIL import Image,ImageDraw,ImageFont
from torch import autograd
import torchvision.transforms as T

# Paths used when the model is rebuilt from separate weight files.
# Checkpoint that get_model() loads into the freshly constructed Faster R-CNN.
CKP_PATH = './fasterrcnn_resnet50_fpn_coco-258fb6c6.pth'
# Weights that Load_model() loads on top of that, when this file exists.
Weight_PATH = "./TrainedNet1.pt" 
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# Define the Faster R-CNN network; the main change is the number of predicted classes.
def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a new box-predictor head.

    The network is constructed without downloading any weights, the checkpoint
    at ``CKP_PATH`` is loaded into it, and the box predictor is then replaced
    by a freshly initialised ``FastRCNNPredictor``.

    Args:
        num_classes: Number of output classes for the new prediction head.
            NOTE(review): torchvision detection heads normally count the
            background as one class -- confirm the value passed by callers.

    Returns:
        The detection model. Its new head carries untrained weights until a
        matching state dict is loaded by the caller.
    """
    # pretrained=False on both flags: weights come from the local checkpoint
    # below instead of being downloaded.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=False, pretrained_backbone=False
    )
    # map_location keeps the load working on CPU-only machines even when the
    # checkpoint was saved from a GPU, matching the other loads in this file.
    model.load_state_dict(torch.load(CKP_PATH, map_location='cpu'))
    # The new head must accept the same feature size the old head consumed.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Swap the checkpoint's head for one sized to num_classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

# Load the model weights.
def Load_model():
    """Build the 205-class detector and load its weights if they are available.

    The architecture comes from ``get_model(num_classes=205)``. When the file
    at ``Weight_PATH`` exists its state dict is loaded onto the CPU; otherwise
    a warning is emitted and the model is returned as built by ``get_model``.

    Returns:
        The model, switched to evaluation mode.
    """
    model = get_model(num_classes=205)
    if os.path.exists(Weight_PATH):
        model.load_state_dict(torch.load(Weight_PATH, map_location='cpu'))
    else:
        # Stay best-effort (no exception), but make the missing file visible:
        # without it the replaced prediction head has no trained weights.
        import warnings

        warnings.warn(
            "Weight file %r not found; returning the model without the "
            "fine-tuned weights." % (Weight_PATH,),
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()

    return model
# Option 1: rebuild the architecture and load the weight files.
model = Load_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Option 2: load a fully pickled model object directly. This rebinds `model`,
# so the instance built by Load_model() above is discarded.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# 'my_model.pth' from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled models; weights_only=False
# may be required there -- confirm against the installed version.
model = torch.load('my_model.pth',map_location='cpu')
# A pickled module keeps the train/eval flag it was saved with. The model is
# later called without targets, which torchvision detectors only accept in
# evaluation mode, so force it here.
model.eval()
# Image preprocessing.
def image_process(Image_url):
    """Read an image file and convert it into a tensor for the detector.

    Args:
        Image_url: Path of the image, passed straight to ``PIL.Image.open``.

    Returns:
        The tensor produced by ``torchvision.transforms.ToTensor`` from the
        RGB version of the image.
    """
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(Image_url) as img:
        # Force three channels so RGBA, palette or grayscale files do not
        # reach the detector with an unexpected channel count.
        rgb_img = img.convert("RGB")
    return T.ToTensor()(rgb_img)
# Build result records from the predictions so they can be inspected or saved.
def boxes_to_lines(preds, json_dict, score_threshold=0.4):
    """Convert the first image's detections into a list of result dicts.

    Args:
        preds: Sequence of per-image prediction dicts holding ``"boxes"``,
            ``"labels"`` and ``"scores"`` tensors. Only ``preds[0]`` is read.
        json_dict: Mapping from the label id rendered as a string (``str(label)``)
            to the category name stored under ``"category_id"``.
        score_threshold: Detections are kept only when their score is strictly
            greater than this value. Defaults to 0.4.

    Returns:
        A list of dicts with keys ``"bbox"`` (``[x, y, width, height]``),
        ``"category_id"`` and ``"score"``. Numeric values are built-in floats,
        so the list can be passed to ``json.dumps``. An empty ``preds`` yields
        an empty list.

    Raises:
        KeyError: If a kept label has no entry in ``json_dict``.
    """
    if not preds:
        return []

    first = preds[0]
    boxes = first["boxes"].cpu().detach().numpy()
    labels = first["labels"].cpu().detach().numpy()
    scores = first["scores"].cpu().detach().numpy()

    r = []
    for bbox, label, score in zip(boxes, labels, scores):
        if score > score_threshold:
            # torchvision emits boxes as (x1, y1, x2, y2); convert them to
            # (x, y, width, height).
            x_min, y_min, x_max, y_max = bbox
            r.append(
                {
                    "bbox": [
                        float(x_min),
                        float(y_min),
                        float(x_max - x_min),
                        float(y_max - y_min),
                    ],
                    "category_id": json_dict[str(label)],
                    # Built-in float instead of a numpy scalar keeps the
                    # record JSON-serialisable.
                    "score": float(score),
                }
            )
    return r

# Load the file that maps label ids to category names.
import json
with open('./gabage-class.json', "r", encoding='utf-8') as fp:
    json_dict = json.load(fp)
# Run the detector on one sample image and draw the detections on it.
test_show_img_url ='../train/2020-gaebage/50345826a11f.JPG'
img_tensor = image_process(test_show_img_url)
# Inference only: no_grad avoids building an autograd graph, and calling the
# module (rather than .forward) goes through the regular nn.Module call path.
with torch.no_grad():
    preds = model([img_tensor])
result = boxes_to_lines(preds, json_dict)
# NOTE(review): 'simsun.ttc' must be resolvable by Pillow on this machine; the
# category names are Chinese, so a CJK-capable font is presumably required.
font = ImageFont.truetype('simsun.ttc', 29)
# Close the source file once the RGBA copy has been made.
with Image.open(test_show_img_url) as src_img:
    base = src_img.convert('RGBA')
d = ImageDraw.Draw(base)
for item in result:
    # bbox is [x, y, width, height]; the drawing calls need integer pixels.
    x, y, w, h = (int(v) for v in item['bbox'])
    d.text((x, y), item['category_id'], (255, 255, 0), font=font)  # category label
    # Passing fill="red" as well would paint the inside of the box.
    d.rectangle([x, y, x + w, y + h], outline='RED')
# Bare expression so a notebook cell displays the annotated image.
base

png

result
[{'bbox': [650.15063, 204.76646, 197.27887, 238.59322],
  'category_id': '西瓜皮_湿垃圾',
  'score': 0.9993536},
 {'bbox': [984.9458, 506.62936, 99.37793, 102.61508],
  'category_id': '橡皮泥_干垃圾',
  'score': 0.99911386},
 {'bbox': [1169.229, 486.35214, 153.3247, 142.97177],
  'category_id': '粉笔_干垃圾',
  'score': 0.9983192},
 {'bbox': [1361.24, 487.92105, 33.58069, 30.368927],
  'category_id': '药片_有害垃圾',
  'score': 0.9971042},
 {'bbox': [813.12195, 431.14136, 57.010986, 183.35211],
  'category_id': '鸡骨头_湿垃圾',
  'score': 0.99564517},
 {'bbox': [969.70544, 180.80382, 125.880005, 297.3694],
  'category_id': '玉米棒_湿垃圾',
  'score': 0.99445397},
 {'bbox': [923.8991, 548.4976, 89.91388, 92.52167],
  'category_id': '动物内脏_湿垃圾',
  'score': 0.9941591},
 {'bbox': [1093.0338, 208.35583, 198.45679, 227.7938],
  'category_id': '粽子_湿垃圾',
  'score': 0.99379784},
 {'bbox': [815.61865, 312.36667, 157.42194, 92.74896],
  'category_id': '金属工具_可回收垃圾',
  'score': 0.9934223},
 {'bbox': [1307.6252, 366.43896, 101.814575, 120.547455],
  'category_id': '金属工具_可回收垃圾',
  'score': 0.9883195},
 {'bbox': [452.42352, 393.40768, 335.7898, 580.0679],
  'category_id': '毛发_干垃圾',
  'score': 0.967124},
 {'bbox': [1219.7949, 581.2517, 227.08667, 204.16669],
  'category_id': '农药瓶_有害垃圾',
  'score': 0.944504},
 {'bbox': [476.8133, 471.62048, 361.78912, 73.64337],
  'category_id': '口红_干垃圾',
  'score': 0.81554884},
 {'bbox': [935.1674, 183.82547, 79.83429, 97.20517],
  'category_id': '榴莲壳_干垃圾',
  'score': 0.69249374},
 {'bbox': [479.96027, 454.2975, 369.42932, 88.35443],
  'category_id': '笔_干垃圾',
  'score': 0.50012475}]