A Gesture Recognizer Built on the MAX78000FTHR Development Board
A gesture recognizer built on the MAX78000FTHR development board, trained on the ASL Alphabet dataset. It recognizes rock, scissors, paper, and empty (no gesture).
Tags: AI, Gesture Recognition, MAX78000
Author: 冷月烟
Updated: 2023-01-31

1. Project Introduction

A gesture recognizer built on the MAX78000FTHR development board. It recognizes rock, scissors, paper, and empty (no gesture); the result is shown on an external screen, which also displays the live camera feed.

 

2. Design Approach

The project has three parts: the recognition model, camera capture, and the screen display.

The board's onboard camera captures images, the data goes to the main chip for inference, and the result is shown on the external screen.

 

The overall workflow:

1. Study the official example code and follow the tutorials to learn model training, quantization, and deployment.

2. Collect material: gather rock/scissors/paper images from online datasets and by capturing my own.

3. Write the hardware drivers: drive the LCD over SPI, capture images from the camera, and show them on the LCD.

4. Write tooling such as the data-loading script, then preprocess the images, train, and quantize.

5. Deploy the model and tune the code against real-world behavior.

 

Overall block diagram:

[Figure: overall system block diagram]

 

3. Collecting Training Material

1. Dataset images

For recognition training, more images generally give better results, but it is hard for one person to collect a large number of them, so I first looked for a ready-made dataset. Online I found ASL Alphabet (https://www.kaggle.com/datasets/grassknoted/asl-alphabet), an American Sign Language dataset. Its training set contains 87,000 images of 200x200 pixels in 29 classes: 26 for the letters A-Z and 3 for SPACE, DELETE, and NOTHING.
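For reference, the dataset can be pulled down with the Kaggle API. A minimal sketch, assuming the kaggle package is installed and an API token is configured (the target path is arbitrary):

# Sketch: download and unpack ASL Alphabet via the Kaggle API.
# Assumes `pip install kaggle` and a configured ~/.kaggle/kaggle.json.
import kaggle

kaggle.api.dataset_download_files(
    "grassknoted/asl-alphabet",  # dataset ID from the URL above
    path="data/asl_alphabet",    # hypothetical target directory
    unzip=True,
)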

The dataset's B, S, V, and nothing classes look like paper (flat hand), rock (fist), scissors (two fingers), and an empty frame, so they can be used directly.

[Figure: sample images from the B, nothing, S, and V classes]

2. Self-captured images

To make up for the training set's gaps and add more usable images, I captured some myself. Maxim provides an image-capture example that streams frames from the camera to a PC; I used a screenshot tool to grab the frames I needed as a supplement to the dataset.

[Figures: frames captured with the image-capture example]
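The screenshots grabbed this way do not match the dataset's 200x200 frames, so it helps to square-crop and resize them before mixing them in. A small sketch with Pillow; the folder names are hypothetical:

# Sketch: center-crop screenshots to a square and resize to 200x200
# so they match the ASL Alphabet images. Folder names are hypothetical.
import os
from PIL import Image

SRC, DST = "captures", "captures_200"
os.makedirs(DST, exist_ok=True)

for name in os.listdir(SRC):
    img = Image.open(os.path.join(SRC, name)).convert("RGB")
    side = min(img.size)                 # largest centered square
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    img = img.crop((left, top, left + side, top + side))
    img.resize((200, 200), Image.BILINEAR).save(os.path.join(DST, name))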

Together, these two sources provided more than enough images for this round of training.

4. Training Process

Writing the data-loading script

The script resizes every image to 64×64:

###################################################################################################
#
# Copyright (C) 2018-2020 Maxim Integrated Products, Inc. All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
RPS Datasets
"""
import os
import sys

import torchvision
from torchvision import transforms

import ai8x


def rps_get_datasets(data, load_train=True, load_test=True):
    """
    rps dataset
    """
    (data_dir, args) = data
    dataset_path = os.path.join(data_dir, "rps_big")
    if not os.path.isdir(dataset_path):
        print("******************************************")
        print("No data!!!")
        print("******************************************")
        sys.exit("Dataset not found..")
    training_data_path = os.path.join(dataset_path, "train")
    test_data_path = os.path.join(dataset_path, "test")
    # Loading and normalizing train dataset
    if load_train:
        train_transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ColorJitter(      # random brightness/contrast/saturation
                brightness=(0.3, .8),
                contrast=(.7, 1),
                saturation=0.2,
                ),
            transforms.RandomHorizontalFlip(p=0.5),  # gestures can appear mirrored
            transforms.RandomApply(([
                transforms.ColorJitter(),            # occasional extra color jitter
                ]), p=0.3),
            transforms.ToTensor(),
            ai8x.normalize(args=args)    # scale into the range the MAX78000 expects
        ])

        train_dataset = torchvision.datasets.ImageFolder(root=training_data_path,
                                                         transform=train_transform)
    else:
        train_dataset = None

    # Loading and normalizing test dataset
    if load_test:
        test_transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            ai8x.normalize(args=args)
        ])

        test_dataset = torchvision.datasets.ImageFolder(root=test_data_path,
                                                        transform=test_transform)

        if args.truncate_testset:
            # ImageFolder keeps its file list in .samples (it has no .data
            # attribute, so the original .data assignment would fail)
            test_dataset.samples = test_dataset.samples[:1]
    else:
        test_dataset = None

    return train_dataset, test_dataset


datasets = [
    {
        'name': 'rps_big',
        'input': (3, 64, 64),
        'output': ('b', 'nothing', 's', 'v'),
        'weight': (1, 1, 1, 1),
        'loader': rps_get_datasets,
    },
]
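The loader expects an ImageFolder layout under data_dir/rps_big: a train/ and a test/ directory, each holding one sub-folder per class (b, nothing, s, v). A sketch for building that layout from the ASL Alphabet class folders; the source path and the 10% hold-out are assumptions (my own run kept only 4 test images, as the log below shows):

# Sketch: copy the four ASL Alphabet classes into the rps_big layout
# that rps_get_datasets() expects. The source path and the split ratio
# are assumptions; adjust them to your local copy of the dataset.
import os
import random
import shutil

SRC = "data/asl_alphabet/asl_alphabet_train"  # hypothetical extract location
DST = "data/rps_big"
CLASSES = {"B": "b", "nothing": "nothing", "S": "s", "V": "v"}

random.seed(0)
for src_cls, dst_cls in CLASSES.items():
    files = sorted(os.listdir(os.path.join(SRC, src_cls)))
    random.shuffle(files)
    n_test = max(1, len(files) // 10)  # hold out ~10% for test
    for i, name in enumerate(files):
        split = "test" if i < n_test else "train"
        out_dir = os.path.join(DST, split, dst_cls)
        os.makedirs(out_dir, exist_ok=True)
        shutil.copy(os.path.join(SRC, src_cls, name),
                    os.path.join(out_dir, name))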

 

Training starts. The log from the first epoch:

2022-11-26 07:37:24,861 - Optimizer Type: <class 'torch.optim.adam.Adam'>
2022-11-26 07:37:24,862 - Optimizer Args: {'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.0001, 'amsgrad': False}
2022-11-26 07:37:25,041 - Dataset sizes:
	training=10800
	validation=1200
	test=4
2022-11-26 07:37:25,042 - Reading compression schedule from: schedule-asl.yaml
2022-11-26 07:37:25,046 - 

2022-11-26 07:37:25,047 - Training epoch: 10800 samples (256 per mini-batch)
2022-11-26 07:37:50,724 - Epoch: [0][   10/   43]    Overall Loss 1.368387    Objective Loss 1.368387                                        LR 0.001000    Time 2.567639    
2022-11-26 07:38:13,264 - Epoch: [0][   20/   43]    Overall Loss 1.313537    Objective Loss 1.313537                                        LR 0.001000    Time 2.410738    
2022-11-26 07:38:38,605 - Epoch: [0][   30/   43]    Overall Loss 1.245654    Objective Loss 1.245654                                        LR 0.001000    Time 2.451815    
2022-11-26 07:39:05,066 - Epoch: [0][   40/   43]    Overall Loss 1.188697    Objective Loss 1.188697                                        LR 0.001000    Time 2.500355    
2022-11-26 07:39:10,427 - Epoch: [0][   43/   43]    Overall Loss 1.173808    Objective Loss 1.173808    Top1 63.157895    LR 0.001000    Time 2.450569    
2022-11-26 07:39:10,739 - --- validate (epoch=0)-----------
2022-11-26 07:39:10,740 - 1200 samples (256 per mini-batch)
2022-11-26 07:39:18,305 - Epoch: [0][    5/    5]    Loss 0.968138    Top1 60.000000    
2022-11-26 07:39:18,584 - ==> Top1: 60.000    Loss: 0.968

2022-11-26 07:39:18,592 - ==> Confusion:
[[211  12  35  24]
 [ 57 160  14  65]
 [ 54 172  29  44]
 [  1   2   0 320]]

2022-11-26 07:39:18,604 - ==> Best [Top1: 60.000   Sparsity:0.00   Params: 60080 on epoch: 0]

 

The final epoch of training and validation:

2022-11-26 10:34:42,820 - Saving checkpoint to: logs/2022.11.26-073724/qat_checkpoint.pth.tar
2022-11-26 10:34:42,832 - 

2022-11-26 10:34:42,836 - Training epoch: 10800 samples (256 per mini-batch)
2022-11-26 10:35:09,183 - Epoch: [99][   10/   43]    Overall Loss 0.345190    Objective Loss 0.345190                                        LR 0.001000    Time 2.634476    
2022-11-26 10:35:33,291 - Epoch: [99][   20/   43]    Overall Loss 0.345397    Objective Loss 0.345397                                        LR 0.001000    Time 2.522606    
2022-11-26 10:35:57,030 - Epoch: [99][   30/   43]    Overall Loss 0.345680    Objective Loss 0.345680                                        LR 0.001000    Time 2.473005    
2022-11-26 10:36:19,786 - Epoch: [99][   40/   43]    Overall Loss 0.345482    Objective Loss 0.345482                                        LR 0.001000    Time 2.423636    
2022-11-26 10:36:24,455 - Epoch: [99][   43/   43]    Overall Loss 0.345504    Objective Loss 0.345504    Top1 100.000000    LR 0.001000    Time 2.363106    
2022-11-26 10:36:24,660 - --- validate (epoch=99)-----------
2022-11-26 10:36:24,663 - 1200 samples (256 per mini-batch)
2022-11-26 10:36:31,351 - Epoch: [99][    5/    5]    Loss 0.343993    Top1 100.000000    
2022-11-26 10:36:31,619 - ==> Top1: 100.000    Loss: 0.344

2022-11-26 10:36:31,625 - ==> Confusion:
[[282   0   0   0]
 [  0 296   0   0]
 [  0   0 299   0]
 [  0   0   0 323]]

2022-11-26 10:36:31,637 - ==> Best [Top1: 100.000   Sparsity:0.00   Params: 60080 on epoch: 99]
2022-11-26 10:36:31,638 - Saving checkpoint to: logs/2022.11.26-073724/qat_checkpoint.pth.tar
2022-11-26 10:36:31,654 - --- test ---------------------
2022-11-26 10:36:31,655 - 4 samples (256 per mini-batch)
2022-11-26 10:36:32,722 - Test: [    1/    1]    Loss 0.343015    Top1 100.000000    
2022-11-26 10:36:32,938 - ==> Top1: 100.000    Loss: 0.343

2022-11-26 10:36:32,940 - ==> Confusion:
[[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]]

 

Admittedly, 100% accuracy looks suspiciously high. The images are simple and there are only four classes; note also that the held-out test set contains just 4 images (see the dataset sizes above), so the final test score says very little.

5. Results

Recognizing empty

[photo]

Recognizing paper

[photo]

Recognizing rock

[photo]

Recognizing scissors

[photo]

 

6. Key Code

Image display

void lcd_show_sampledata(uint32_t* data0, uint32_t* data1, uint32_t* data2, int xcord, int ycord,
                         int length)
{
    int i;
    int j;
    int x;
    int y;
    int r;
    int g;
    int b;
    int scale = 1; // the original `int scale = 1.2;` silently truncated to 1; use a float type for fractional scaling

    uint32_t color;
    uint8_t* ptr0;
    uint8_t* ptr1;
    uint8_t* ptr2;

    // data0/1/2 hold the R/G/B planes, four 8-bit pixels packed per 32-bit word
    x = 0;
    y = 0;
    for (i = 0; i < length; i++)
    {
        ptr0 = (uint8_t*)&data0[i];
        ptr1 = (uint8_t*)&data1[i];
        ptr2 = (uint8_t*)&data2[i];
        for (j = 0; j < 4; j++)
        {
            r = ptr0[j];
            g = ptr1[j];
            b = ptr2[j];
            color = RGB(r, g, b); // convert to RGB565
            MXC_TFT_WritePixel(xcord * scale + 2 * x * scale, ycord * scale + 2 * y * scale, scale, scale, color);
            x += 1;
            if (x >= (IMAGE_SIZE_X)) // wrap to the next pixel row at the image width
            {
                x = 0;
                y += 1;
                if ((y + 6) >= (IMAGE_SIZE_Y)) // stop six rows before the bottom edge
                    return;
            }
        }
    }
}
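The RGB() macro comes from the TFT driver and, per the comment above, packs the three 8-bit channels into a 16-bit RGB565 word. The standard packing keeps the top 5/6/5 bits of each channel; a quick Python sketch of that math:

# The usual RGB888 -> RGB565 packing: keep the top 5/6/5 bits per channel.
def rgb565(r: int, g: int, b: int) -> int:
    return ((r & 0xF8) << 8) | ((g & 0xFC) << 3) | (b >> 3)

print(hex(rgb565(255, 255, 255)))  # 0xffff, white
print(hex(rgb565(255, 0, 0)))      # 0xf800, pure red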

 

Image capture and CNN inference

// Capture a single camera frame.
printf("\nCapture a camera frame %d\n", ++frame);
capture_camera_img();
// Copy the image data to the CNN input arrays.
printf("Copy camera frame to CNN input buffers.\n");
process_camera_img(input_0_camera, input_1_camera, input_2_camera);

convert_img_unsigned_to_signed(input_0_camera, input_1_camera, input_2_camera); // shift the 0..255 samples into the signed range the CNN expects

cnn_init();         // Bring state machine into consistent state
cnn_load_weights(); // Load kernels
cnn_load_bias();
cnn_configure(); // Configure state machine
cnn_start();     // Start CNN processing
load_input();    // Load data input via FIFO
MXC_TMR_SW_Start(MXC_TMR0);

while (cnn_time == 0)
	__WFI(); // Wait for CNN

softmax_layer();

printf("Time for CNN: %d us\n\n", cnn_time);

printf("Classification results:\n");
for (i = 0; i < CNN_NUM_OUTPUTS; i++)
{
	digs      = (1000 * ml_softmax[i] + 0x4000) >> 15; // Q15 (32768 = 100%) -> tenths of a percent, rounded
	tens      = digs % 10;  // fractional digit
	digs      = digs / 10;  // whole percent
	result[i] = digs;       // keep the whole-percent confidence for the display logic below
	printf("[%7d] -> Class %d %8s: %d.%d%%\r\n", ml_data[i], i, classes[i], digs, tens);
}

printf("\n");

 

Deciding and displaying the result

if (result[0] > 0) // paper threshold lowered to any nonzero score; adjust to your setup
{
	// paper
	TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Paper   "));
	printf("User choose: %s \r\n", classes[0]);
}
else if (result[1] > 60)
{
	// rock
	TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Rock    "));
	printf("User choose: %s \r\n", classes[1]);
}
else if (result[2] > 60)
{
	// scissors
	TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Scissors"));
	printf("User choose: %s \r\n", classes[2]);
}
else if (result[3] > 60)
{
	// empty
	TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Empty   "));
	printf("User choose: %s \r\n", classes[3]);
}
else
{
	TFT_Print(buff, 30, 55, font_2, sprintf(buff, "Unknown "));
}

TFT_Print(buff, 205, 55, font_2, sprintf(buff, "Paper:%d  ",result[0]));
TFT_Print(buff, 205, 75, font_2, sprintf(buff, "Rock:%d  ",result[1]));
TFT_Print(buff, 205, 95, font_2, sprintf(buff, "Scissors:%d  ",result[2]));
TFT_Print(buff, 205, 115, font_2, sprintf(buff, "Empty:%d  ",result[3]));

 

7. Problems and Next Steps

I originally planned to do full sign-language recognition, but once the model was deployed on the board the results fell far short of expectations. Looking through the official examples, they all tackle fairly simple recognition tasks, and none deploy a large number of classes on the board. So I changed course and kept only the simple gestures: rock, scissors, paper, and empty. Even then, two problems remain: paper is hard to recognize (I lowered its threshold so that any score above 0 counts as paper), and hands showing none of the three gestures still produce an output. The latter is probably because the "empty" class has too little variety; other gestures should be folded into it.

Next, I plan to continue studying AI, try deploying more gestures to the board, tune the training parameters so the code generalizes better, and replace more of the images with ones I captured myself; that should substantially improve recognition accuracy with the live camera.

Attachment: rps-demo.zip (project code)

Team: 冷月烟