深度学习---单目标关键点检测网络Stacked Hourglass

发布时间 2023-11-17 15:32:28作者: 半夜打老虎

Stacked Hourglass Networks是2016年提出的一种用于单人人体姿态估计的网络,并取得了很好的效果。这里我们从网络结构以及一些实现细节简单分析下这个网络。
paper: https://arxiv.org/pdf/1603.06937.pdf
code: https://github.com/princeton-vl/pytorch_stacked_hourglass https://github.com/zhoujinhai/Stack_HourGlass

一、网络结构

1.1 网络整体结构

网络名字Stacked HourGlass(堆叠的沙漏)其实已经反应了网络的大致结构,由像沙漏一样的结构堆叠而成,每一个沙漏结构在文章称为Hourglass模块,hourglass模块之间还有一个中间监督层(Intermediate Supervision),用于衔接各个hourglass模块,通过这样的架构,不断重复进行自上而下,自下而上的推断机制,通过这种机制从而能够重新评估整张图像的初始估计和特征。

1.2 Hourglass模块

前面提到的hourglass模块如下图所示,其在论文中由4层组成,方格的大小表示特征的维度,每一个方格都表示一个残差块,大体逻辑是先降维然后通过残差块提取特征,再升维后进行特征融合。

由于其层级结构,所以在实现上采用了递归的方式。

残差块如下图所示。

1.3 中间监督层


上一个Hourglass的输出经过由残差块处理得到特征A,该特征经过两个分支,

  • 一个分支经过1*1的卷积得到相同的维度特征作为下一个hourglass模块的输入。
  • 另一个经过11的卷积输出得到中间层生成的heatmaps,可以和真实的标签计算loss,这些特征层再经过11的卷积输出和A相同的维度特征,然后也作为下一个hourglass模块的输入。

所以下一个hourglass模块的输入有三个,两个在上面提到,还有一个是输入到hourglass的特征。

为什么中间监督那么关键?
这是因为当通过每个Hourglass模块时,网络都将有机会在局部和全局上下文中处理特征,然后生成预测。 随后的Hourglass模块允许这些高级特征再次被处理,以进一步评估和重新评估更高阶空间关系。

二、实现细节

2.1 标签生成

假如有3个关键点,那真实的训练标签就有3张对应的HeatMap图。关键点由x,y坐标表示,那如何将其转换成训练用的HeatMap呢,采用高斯热力图,也就是越靠近关键点位置,其值越接近于1,越远越接近于0。

其中sigma用于控制高斯热力图的范围。越大形成的热力图范围越大。

2.2 结果解析

网络训练完成后,推理阶段输出的也是HeatMap图,那要得到关键点位置信息,就需要进行解析,其过程刚好和上面标签生成相反,即找出heatMap中最大值所在位置作为关键点位置。

三、模型部署

3.1 模型转换

将训练好的模型转换为ONNX格式

3.2 模型推理

推理可以采用前面提到过的OpenCV的DNN模块,或者采用NCNN进行推理。
大体思路是获取最后一个hourglass模块的输出,然后解析出每一个HeatMap最大点的位置,这些位置再转换到原图即检测到的关键点位置。