sam自动生成mask代码解析

发布时间 2023-07-21 14:20:52作者: 海_纳百川

要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
#自动生成采样点对图像进行分割
mask_generator = SamAutomaticMaskGenerator(sam)

masks = mask_generator.generate(image)

print(len(masks))
print(masks[0].keys())
print(masks[0])

plt.figure(figsize=(16,16))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

其中print masks这块输出为

42
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
{'segmentation': array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}

例如生成的图片

masks = mask_generator.generate(image)

Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:

  • segmentation : np的二维数组,为二值的mask图片
  • area : mask的像素面积
  • bbox : mask的外接矩形框,为XYWH格式
  • predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
  • point_coords : 用于生成该mask的point输入
  • stability_score : mask质量的附加指标
  • crop_box : 用于以XYWH格式生成此遮罩的图像裁剪