Detecting the Orientation of Industrial Drill Bits on a Conveyor Belt with the MAX78000
Based on the MAX78000 platform, this project detects in real time the placement orientation of different types of drill bits on a conveyor belt.
Tags: Embedded Systems
Author: 持心铭志 · Updated 2022-12-02

Introduction

Project goal: detect the orientation of multiple types of industrial drill bits on a conveyor belt, with a sufficient response rate and a recall above 99%, i.e. missing as few detections as possible.

 

Approach: focus on a specific, visually distinctive part of the drill bit. This shrinks the detection region, which reduces network size and computation while improving detection accuracy and speed. Here the focus is on the head of the bit: if the bit is oriented correctly, the head is detected; otherwise the tail is detected.


Dataset acquisition: photograph different types of drill bits, crop out the head and tail regions, then augment the data.

 

1. The raw data is photos of 4 different drill-bit models, 3 photos per model, in PNG format. The visually distinctive head region of each bit was cut out in Photoshop as class 1, and the tail region as class 2.

2. The head and tail cutouts are composited with background photos of the drill bits' working environment: each time, a 256x256 region is taken from the background and the drill-bit feature image is placed at a random position within it. The feature image is then rotated by some angle and the compositing is repeated, producing nearly 20,000 images.


 

The data generation code is shown below:

import os
import shutil
import cv2
import copy
import numpy as np
import random
 
def opencv_rotate(img, angle):
    """
    Rotate an image; positive angles rotate counter-clockwise (OpenCV convention).
    :param img: input image
    :param angle: rotation angle in degrees
    :return: rotated image, on a canvas enlarged so no content is clipped
    """
    h, w = img.shape[:2]  # (rows, cols); a third shape entry would be the channel count
    borderValue = (0, 0, 0, 0)

    # Ensure the image has an alpha channel (BGRA) so the rotated borders stay transparent
    if img.ndim == 3 and img.shape[-1] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    elif img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)

    center = (w / 2, h / 2)
    scale = 1.0
    # 2.1 Get the rotation matrix M:
    """
    M =
    [  cosA  sinA  (1-cosA)*centerX - sinA*centerY ]
    [ -sinA  cosA   sinA*centerX + (1-cosA)*centerY ]
    """
    M = cv2.getRotationMatrix2D(center, angle, scale)
    # 2.2 New width/height of the bounding canvas; np.radians() converts degrees to radians
    new_H = int(w * np.fabs(np.sin(np.radians(angle))) + h * np.fabs(np.cos(np.radians(angle))))
    new_W = int(h * np.fabs(np.sin(np.radians(angle))) + w * np.fabs(np.cos(np.radians(angle))))

    # 2.3 Shift so the rotated image is centered in the enlarged canvas
    M[0, 2] += (new_W - w) / 2
    M[1, 2] += (new_H - h) / 2

    # Apply the affine transform
    return cv2.warpAffine(img, M, (new_W, new_H), borderValue=borderValue)

def merge_png_jpg(img_png, img_jpg, merge_factor, x_offset=0, y_offset=0):
    """Alpha-blend the PNG foreground onto the JPG background at (x_offset, y_offset)."""
    height, width = img_png.shape[:2]

    img_jpg = copy.deepcopy(img_jpg)
    # Blend each BGR channel; merge_factor is the foreground's normalized alpha mask
    for i in range(0, 3):
        img_jpg[y_offset:y_offset + height, x_offset:x_offset + width, i] = ((1 - merge_factor) * img_jpg[y_offset:y_offset + height, x_offset:x_offset + width, i]
                                                                          + merge_factor * img_png[:height, :width, i])

    # cv2.imshow('img_new', img_jpg)
    # cv2.waitKey(0)

    return img_jpg

num_saves = 0  # running count of saved samples, shared across functions

def translation_transformation(img_png, img_jpg, angle, label_class, save_path, png_name):
    png_height, png_width = img_png.shape[:2]
    jpg_height, jpg_width = img_jpg.shape[:2]

    max_offset_y = jpg_height - 2*128
    max_offset_x = jpg_width  - 2*128

    merge_factor = cv2.split(img_png)[-1] / 255.0  # normalized alpha channel of the cutout

    for y_offset in range(2*128, max_offset_y, 100):
        for x_offset in range(2*128, max_offset_x, 100):
            img_new = merge_png_jpg(img_png, img_jpg, merge_factor, x_offset, y_offset)
            # Random jitter so the 256x256 crop is not always centered on the bit
            x_rand = random.randint(-int(128 - png_width / 2),  int(128 - png_width / 2))
            y_rand = random.randint(-int(128 - png_height / 2), int(128 - png_height / 2))

            # Crop a 256x256 window around the composited bit
            # (horizontal jitter on the x axis, vertical jitter on the y axis)
            y_top    = int(y_offset + png_height / 2 - 128 + y_rand)
            y_bottom = int(y_offset + png_height / 2 + 128 + y_rand)
            x_left   = int(x_offset + png_width / 2  - 128 + x_rand)
            x_right  = int(x_offset + png_width / 2  + 128 + x_rand)

            img_new = img_new[y_top:y_bottom, x_left:x_right, :]
            # cv2.imshow('img_new', img_new)
            # cv2.waitKey(0)

            global num_saves
            
            save_dir = "train"
            if  not (num_saves % 4):
                save_dir = "valid"
            elif 2 == (num_saves % 4):
                save_dir = "test"

            img_name = str(num_saves) + '_' + str(angle) + \
                           '_' + str(y_offset) + '_' + str(x_offset) + '_' + png_name
            cv2.imwrite(save_path + '/' + save_dir + "/images/" + img_name + ".jpg", img_new)

            # Write a YOLO-style label: class id, normalized center x/y, normalized width/height
            gt_file = open(save_path + '/' + save_dir + "/labels/" + img_name + ".txt", 'w')
            gt_data = [label_class, (x_offset + png_width / 2.0) / jpg_width, (y_offset + png_height / 2.0) / jpg_height, png_width / jpg_width, png_height / jpg_height]
            gt_file.writelines(" ".join(str(i) for i in gt_data))
            gt_file.close()

            num_saves = num_saves + 1
            print(num_saves)


def data_augment_ps(png_path, jpg_path, save_path):
    img_png = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)  # keep the cutout's alpha channel
    img_jpg = cv2.imread(jpg_path, cv2.IMREAD_UNCHANGED)
    # img_jpg = cv2.resize(img_jpg, (0,0), fx = 0.5, fy = 0.5)

    png_name = png_path.split('/')[-1].split('.')[0]

    # Angles near 0 degrees: correctly oriented bit -> class 0
    for angle in range(-20, 10, 4):
        translation_transformation(opencv_rotate(img_png, angle), img_jpg, angle, 0, save_path, png_name)

    # Angles near 180 degrees: reversed bit -> class 1
    for angle in range(160, 190, 4):
        translation_transformation(opencv_rotate(img_png, angle), img_jpg, angle, 1, save_path, png_name)

def main():
    png_path  = "/home/lz/DataDisk/datasets/jz_drill_recognize/for_max78000/foreground_img_tail"
    jpg_path  = "/home/lz/DataDisk/datasets/jz_drill_recognize/for_max78000/background_img"
    save_path = "/home/lz/RamDisk/jz_drill_ps_crop_aug_data"

    if os.path.exists(save_path):
        shutil.rmtree(save_path)

    os.makedirs(save_path + "/train/images")
    os.makedirs(save_path + "/train/labels")
    os.makedirs(save_path + "/valid/images")
    os.makedirs(save_path + "/valid/labels")
    os.makedirs(save_path + "/test/images")
    os.makedirs(save_path + "/test/labels")

    global num_saves
    num_saves = 0

    for png_name in os.listdir(png_path):
        for jpg_name in os.listdir(jpg_path):
            data_augment_ps(os.path.join(png_path, png_name), os.path.join(jpg_path, jpg_name), save_path)

if __name__ == "__main__":
    main()
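
For reference, a run of the script leaves the following layout under save_path; this is read directly from the os.makedirs calls in main() and the file naming in translation_transformation(), where each sample's name encodes its running index, rotation angle, and background offsets:

jz_drill_ps_crop_aug_data/
├── train/
│   ├── images/   <index>_<angle>_<y_offset>_<x_offset>_<png_name>.jpg
│   └── labels/   <index>_<angle>_<y_offset>_<x_offset>_<png_name>.txt
├── valid/
│   ├── images/
│   └── labels/
└── test/
    ├── images/
    └── labels/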

Model Training

Since this is a binary classification problem, the model from the official ai85net-cd.py was used. That model was originally built to classify cats and dogs; it was chosen here for a quick initial proof of concept. To minimize code changes, the dataset was arranged in the format that demo expects: images containing the drill-bit head region go into the cats directory, and images containing the tail region go into the dogs directory.
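
As a minimal sketch of that rearrangement, assuming the data was produced by the generator above (the paths and the cats/dogs mapping below are illustrative, not from the original project): the class id stored in each label file decides whether an image is copied to cats (head, class 0) or dogs (tail, class 1).

import os
import shutil

# Hypothetical paths; adjust to the actual dataset location
src_root = "jz_drill_ps_crop_aug_data"
dst_root = "cats_vs_dogs_layout"
class_names = {0: "cats", 1: "dogs"}  # class 0 = head region, class 1 = tail region

for split in ("train", "valid", "test"):
    for cname in class_names.values():
        os.makedirs(os.path.join(dst_root, split, cname), exist_ok=True)

    labels_dir = os.path.join(src_root, split, "labels")
    images_dir = os.path.join(src_root, split, "images")
    for label_file in os.listdir(labels_dir):
        with open(os.path.join(labels_dir, label_file)) as f:
            cls = int(float(f.read().split()[0]))  # first label field is the class id
        img_name = os.path.splitext(label_file)[0] + ".jpg"
        shutil.copy(os.path.join(images_dir, img_name),
                    os.path.join(dst_root, split, class_names[cls], img_name))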


Training: simply launch the official training script. After 100 epochs Top-1 reached 100% and training was stopped manually.

./scripts/train_catsdogs.sh

2022-11-26 17:58:22,263 - Training epoch: 17972 samples (256 per mini-batch)
2022-11-26 17:58:23,402 - Epoch: [100][   10/   71]    Overall Loss 0.000018    Objective Loss 0.000018                                        LR 0.000216    Time 0.113771    
2022-11-26 17:58:24,192 - Epoch: [100][   20/   71]    Overall Loss 0.000012    Objective Loss 0.000012                                        LR 0.000216    Time 0.096369    
2022-11-26 17:58:24,961 - Epoch: [100][   30/   71]    Overall Loss 0.000012    Objective Loss 0.000012                                        LR 0.000216    Time 0.089875    
2022-11-26 17:58:25,741 - Epoch: [100][   40/   71]    Overall Loss 0.000010    Objective Loss 0.000010                                        LR 0.000216    Time 0.086900    
2022-11-26 17:58:26,529 - Epoch: [100][   50/   71]    Overall Loss 0.000010    Objective Loss 0.000010                                        LR 0.000216    Time 0.085259    
2022-11-26 17:58:27,311 - Epoch: [100][   60/   71]    Overall Loss 0.000010    Objective Loss 0.000010                                        LR 0.000216    Time 0.084088    
2022-11-26 17:58:28,054 - Epoch: [100][   70/   71]    Overall Loss 0.000010    Objective Loss 0.000010    Top1 100.000000    LR 0.000216    Time 0.082687    
2022-11-26 17:58:28,108 - Epoch: [100][   71/   71]    Overall Loss 0.000010    Objective Loss 0.000010    Top1 100.000000    LR 0.000216    Time 0.082278    
2022-11-26 17:58:28,156 - --- validate (epoch=100)-----------
2022-11-26 17:58:28,157 - 1996 samples (256 per mini-batch)
2022-11-26 17:58:28,946 - Epoch: [100][    8/    8]    Loss 0.000024    Top1 100.000000    
2022-11-26 17:58:28,994 - ==> Top1: 100.000    Loss: 0.000

Quantize the model to int8:

python quantize.py ../ai8x-training/logs/2022.11.26-174709/qat_best.pth.tar ../jz_cat_dog_synthesis/q8.pth.tar --device MAX78000

Evaluate the model's performance. Here the test set consists of 624 photos of actual drill bits.



python train.py --model ai85cdnet --dataset cats_vs_dogs --confusion --evaluate --exp-load-weights-from ../jz_cat_dog_synthesis/q8.pth.tar -8 --device MAX78000 "$@"

+----------------------+-------------+-----------+
| Key                  | Type        | Value     |
|----------------------+-------------+-----------|
| arch                 | str         | ai85cdnet |
| compression_sched    | dict        |           |
| epoch                | int         | 109       |
| extras               | dict        |           |
| optimizer_state_dict | dict        |           |
| optimizer_type       | type        | Adam      |
| state_dict           | OrderedDict |           |
+----------------------+-------------+-----------+

2022-11-26 22:37:39,101 - => Checkpoint['extras'] contents:
+-----------------+--------+---------------+
| Key             | Type   | Value         |
|-----------------+--------+---------------|
| best_epoch      | int    | 109           |
| best_mAP        | int    | 0             |
| best_top1       | float  | 100.0         |
| clipping_method | str    | MAX_BIT_SHIFT |
| current_mAP     | int    | 0             |
| current_top1    | float  | 100.0         |
+-----------------+--------+---------------+

2022-11-26 22:37:39,102 - Loaded compression schedule from checkpoint (epoch 109)
2022-11-26 22:37:39,104 - => loaded 'state_dict' from checkpoint '../jz_cat_dog_synthesis/q8.pth.tar'
2022-11-26 22:37:39,110 - Optimizer Type: <class 'torch.optim.sgd.SGD'>
2022-11-26 22:37:39,110 - Optimizer Args: {'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False}
2022-11-26 22:37:39,299 - Dataset sizes:
	training=17972
	validation=1996
	test=624
2022-11-26 22:37:39,299 - --- test ---------------------
2022-11-26 22:37:39,300 - 624 samples (256 per mini-batch)
2022-11-26 22:37:40,133 - Test: [    3/    3]    Loss 0.437118    Top1 87.500000    
2022-11-26 22:37:40,189 - ==> Top1: 87.500    Loss: 0.437

2022-11-26 22:37:40,190 - ==> Confusion:
[[203  65]
 [ 13 343]]
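
The Top-1 number alone hides the per-class behavior that matters for the 99% recall target. As a quick sanity check, here is a sketch that derives per-class recall and precision from the confusion matrix above, assuming rows are ground truth and columns are predictions, in the order head, tail:

import numpy as np

# Confusion matrix from the evaluation log; rows = ground truth, cols = prediction (assumed)
cm = np.array([[203, 65],
               [13, 343]])

for c, name in enumerate(("head", "tail")):
    recall = cm[c, c] / cm[c].sum()
    precision = cm[c, c] / cm[:, c].sum()
    print(f"{name}: recall={recall:.3f}  precision={precision:.3f}")
# head: recall=0.757  precision=0.940
# tail: recall=0.963  precision=0.841

Under that reading, the head class misses roughly a quarter of its samples, which is where the gap to the 99% recall goal lies.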

Results Analysis

The initial test results are fairly encouraging: 87.5% Top-1 accuracy. Although this still falls short of the accuracy a real application requires, there is plenty of room for improvement in the design; under a YOLOv5 model, accuracy above 99% is achievable. Follow-up work will design a higher-performing network tailored to the MAX78000's hardware characteristics, and the diversity of the dataset also needs further improvement.

 

Team

Individual entry
