sam之自动生成mask代码流程

发布时间 2023-07-31 14:30:49作者: 海_纳百川

本文不涉及sam的训练流程,只设计推理过程

最近接触这个sam,由于网络中关于sam的自动mask功能的介绍较少,所以本周对源码进行了解读

说到sam自动提取mask,包含三个部分,第一个部分是如何对原始图像进行分割成一个个小块,第二部分是送到sam中进行处理得出结果,第三个是如何对一个个小块的结果进行后处理,过滤掉一些,保留更重要的结果。

sam推理这一块是最基础的一个功能,一旦训练完以后它也变不出什么花样。

1.sam网络结构

首先说一下他的网络结构,包含三个部分,图像编码image-encoding,分割点编码point-encoding,分割解码mask-decode

具体的网络结构在这里不进行描述,想了解的可以自己去看一看,这方面的博客也是不少。

图像编码这一块,网络的输入尺寸是1024*1024,原始图像需要进行预处理,先等比例resize,最大边resize到1024,然后再减均值除以方差,最后再进行pad,pad的填充值为0

2.预处理部分

自动提取mask预处理部分涉及的东西还是挺多的,预处理部分主要涉及如何对一张图像进行均匀分割和采样点设置

首先就是对图像进行均匀的分割成几个小块。分割的方法,原始尺寸不分割+均匀分割成4块+均匀分割成16块。也就是说总共21个图像

这21个图像中,会设置不同数量的分割点。具体是这样设置的:1.原始尺寸上:32*32个采样点,2.均匀分割4块:每块16*16个采样点,3均匀分割16快:每块8*8个采样点

采样点和图片都设置好以后,就直接按照普通图片推理方式,进行推理即可

3.后处理部分

在把图片+采样点送去sam以后,就会得出结果,但是结果也是很多个,需要对这些结果进行过滤

1.根据iou_pred过滤掉第一批mask,此时的iou_pred就是模型输出的maks置信度

2.根据stability_score过滤掉第二批mask,stability_score的计算方式是:统计maks的值中所有大于(mask_threshold+offset)的数量作为交集,统计maks的值中所有大于(mask_threshold)的数量作为并集,然后用交集处理并集就是稳定性

3.保留mask边界框(靠近裁剪边缘,但是不靠近原始图像边缘)的maks,其他的过滤掉,这种方式我还是不怎么理解。

4.每个图像的输出结果的boxes,进行nms过滤,然后将box和point全部映射回原图

5.移除mask中的不连续区域和空洞