How to add an attention mechanism to YOLO

Published 2023-10-03 15:41:26, author: Frommoon

This post adds an attention mechanism on top of the YOLO (YOLOv8 / ultralytics) code base.

1. Import the attention class

In ultralytics\nn\models\extra_modules\attention.py, paste (or import) the class of the attention module you want to add, as in the sketch below.
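A minimal sketch of what such a pasted class can look like, using the parameter-free SimAM module as an example (the code below follows the published SimAM formulation; for any other module, copy the class from its original paper or repository):

import torch.nn as nn

class SimAM(nn.Module):
    # Parameter-free SimAM attention; e_lambda is the small stabilising constant in the energy function
    def __init__(self, e_lambda=1e-4):
        super().__init__()
        self.e_lambda = e_lambda
        self.act = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        # squared deviation of every activation from its channel mean
        d = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # inverse of the neuron energy: lower energy -> higher attention weight
        v = d / (4 * (d.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.act(v)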

2. Add the class name to __all__ in attention.py


__all__ = ['EMA', 'SimAM', 'SpatialGroupEnhance', 'BiLevelRoutingAttention', 'BiLevelRoutingAttention_nchw', 'TripletAttention', 
           'CoordAtt', 'BAMBlock', 'EfficientAttention', 'LSKBlock', 'SEAttention', 'CPCA', 'MPCA']
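Adding the name to __all__ matters because the attention classes are usually pulled into tasks.py with a wildcard import, and __all__ controls what such an import exposes. The exact import line depends on your repo layout; for the path used in this post it would look roughly like this (assumption, adjust to your own tree):

# somewhere near the top of ultralytics/nn/tasks.py (exact path/line is an assumption)
from ultralytics.nn.models.extra_modules.attention import *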

3. Decide whether the module needs the input channel count (yes / no)

Register the class name in the parse_model function in ultralytics\nn\tasks.py (around line 768):

        # attention modules that need the input channel count go here
        elif m in {EMA, SpatialAttention, BiLevelRoutingAttention, BiLevelRoutingAttention_nchw,
                   TripletAttention, CoordAtt, CBAM, BAMBlock, LSKBlock, ScConv, LAWDS, EMSConv, EMSConvP,
                   SEAttention, CPCA, Partial_conv3, FocalModulation, EfficientAttention, MPCA}:
            c2 = ch[f]           # output channels equal the input channels
            args = [c2, *args]   # prepend the channel count to the module's arguments
            # print(args)
        # attention modules that do not need the channel count go here
        elif m in {SimAM, SpatialGroupEnhance}:
            c2 = ch[f]           # channels pass through unchanged; args stay as written in the yaml
        elif m is ContextGuidedBlock_Down:
            c2 = ch[f] * 2
            args = [ch[f], c2, *args]
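For context, parse_model later instantiates each module from m and the final args, roughly like the line below (paraphrased from tasks.py), so the two branches above only decide whether the input channel count ch[f] is injected into args automatically:

# paraphrased from parse_model in ultralytics/nn/tasks.py
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)

# Illustration (channel value is just an example): a channel-aware module declared as
# [-1, 1, SEAttention, []] with ch[f] == 1024 is built as SEAttention(1024), while the
# channel-free [-1, 1, SimAM, [1e-4]] used in the config below stays SimAM(1e-4).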

4. Modify the model config file

Location: ultralytics\models\v8\yolov8-attention.yaml
Add the attention layer, and mind that the indices of every layer after it shift accordingly (see the note after the config below).


# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
#  - [-1, 1, BiLevelRoutingAttention_nchw, [8, 7]] # 10
#  - [-1, 1, BiLevelRoutingAttention, [8, 7]] # 10
  - [-1, 1, SimAM, [1e-4]] # 10   check the class definition to see whether it takes arguments; to switch attention modules, only the name (and its args) changes
#  - [-1, 1, TripletAttention, []] # 10
# - [-1, 1, CPCA, []] # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
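Note on the index shift (assuming you start from the stock yolov8n.yaml): because the attention module occupies layer 10, every later layer moves down by one. Concretely, the cat head P4 Concat changes from [[-1, 12], 1, Concat, [1]] to [[-1, 13], 1, Concat, [1]], the cat head P5 Concat changes from [[-1, 9], 1, Concat, [1]] to [[-1, 10], 1, Concat, [1]] (so P5 now fuses the attention output instead of the raw SPPF output), and the Detect layer changes from [[15, 18, 21], 1, Detect, [nc]] to [[16, 19, 22], 1, Detect, [nc]], exactly as in the config above.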

5. Test

In the Terminal, run python train.py --yaml ultralytics/models/v8/yolov8-attention.yaml --info
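Here train.py and its --yaml / --info flags belong to this modified repo, not to the stock ultralytics CLI. If you only want to verify that the yaml parses and the attention layer is built, an equivalent check with the standard Python API is (sketch, assuming the modules were registered as in steps 1-3):

from ultralytics import YOLO

# Build the model from the modified config and print the per-layer summary;
# the attention module should appear as layer 10.
model = YOLO('ultralytics/models/v8/yolov8-attention.yaml')
model.info(verbose=True)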

If the attention layer appears in the printed model structure, it has been added successfully.