常用代码段-nms操作

发布时间 2023-08-25 09:19:00作者: 海_纳百川

非极大值抑制(Non-Maximum Suppression,NMS)是一种常用于目标检测和计算机视觉任务的技术,用于从重叠的检测框中选择最佳的候选框。以下是使用 PyTorch 实现标准的 NMS 算法的示例代码:

import torch

def nms(boxes, scores, iou_threshold):
    sorted_indices = scores.argsort(descending=True)
    selected_indices = []

    while sorted_indices.numel() > 0:
        current_index = sorted_indices[0]
        selected_indices.append(current_index.item())

        if sorted_indices.numel() == 1:
            break

        current_box = boxes[current_index]
        other_boxes = boxes[sorted_indices[1:]]
        ious = calculate_iou(current_box, other_boxes)

        valid_indices = (ious <= iou_threshold).nonzero().squeeze()
        if valid_indices.numel() == 0:
            break

        sorted_indices = sorted_indices[valid_indices + 1]

    return selected_indices

def calculate_iou(box, boxes):
    x1 = torch.max(box[0], boxes[:, 0])
    y1 = torch.max(box[1], boxes[:, 1])
    x2 = torch.min(box[2], boxes[:, 2])
    y2 = torch.min(box[3], boxes[:, 3])

    intersection_area = torch.clamp(x2 - x1 + 1, min=0) * torch.clamp(y2 - y1 + 1, min=0)
    box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
    boxes_area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)

    iou = intersection_area / (box_area + boxes_area - intersection_area)
    return iou

# 示例数据:框坐标和置信度得分
boxes = torch.tensor([[100, 100, 200, 200], [120, 120, 220, 220], [150, 150, 250, 250]])
scores = torch.tensor([0.9, 0.8, 0.7])

# NMS 参数
iou_threshold = 0.5

# 执行 NMS 算法
selected_indices = nms(boxes, scores, iou_threshold)
print("选择的索引:", selected_indices)

在此示例中,我们首先定义了 nms 函数来执行 NMS 算法。然后,我们实现了一个简单的 calculate_iou 函数来计算两个框的交并比(IoU)。最后,我们使用示例数据 boxes 和 scores 运行 NMS 算法,并打印出选定的索引。