YOLOv7 Improvement -- Adding the CBAM Attention Mechanism


  • CBAM Attention Mechanism
  • Code
    • Add the CBAM module in common.py
    • Register the CBAM module name in yolo.py
    • Add CBAM entries to the cfg file

      For a project I needed to try adding the CBAM attention mechanism to YOLOv7 and see whether it would improve performance. Since I had already added CBAM to YOLOv5 before, I simply ported the YOLOv5 implementation over. Without further ado, let's get straight to the code.

      CBAM Attention Mechanism

      First, a brief introduction to the CBAM attention mechanism.

      Paper: https://arxiv.org/pdf/1807.06521.pdf


      The Convolutional Block Attention Module (CBAM) consists of two sub-modules: a channel attention module (CAM) and a spatial attention module (SAM). CAM makes the network attend to the image foreground, i.e. the meaningful ground-truth regions, while SAM makes it attend to the locations in the image that carry rich contextual information. Both modules are plug-and-play. The paper recommends connecting them in series (in the paper's experiments, serial beat parallel; on my dataset the difference between the two was negligible, so the best arrangement depends on your data and is worth experimenting with). The code below uses the serial arrangement.
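
      To make the serial-versus-parallel distinction concrete, here is a minimal sketch (ca and sa stand for the ChannelAttention and SpatialAttention modules defined below; the parallel wiring shown is one common way to combine them, not taken from the paper):

      # Serial wiring, as used in the code below
      def cbam_serial(x, ca, sa):
          x = ca(x) * x  # reweight channels first
          x = sa(x) * x  # then reweight spatial locations of the channel-refined map
          return x

      # Parallel wiring: both attention maps are computed from the same input
      def cbam_parallel(x, ca, sa):
          return ca(x) * sa(x) * x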

      Code

      Add the CBAM module in common.py

      This code is the same as in YOLOv5 and can be dropped in directly:

      import torch
      import torch.nn as nn


      class ChannelAttention(nn.Module):
          # Channel attention (CAM): global avg- and max-pooled descriptors pass
          # through a shared two-layer MLP (1x1 convs) and are summed into a
          # per-channel weight in (0, 1)
          def __init__(self, in_planes, ratio=16):
              super(ChannelAttention, self).__init__()
              self.avg_pool = nn.AdaptiveAvgPool2d(1)
              self.max_pool = nn.AdaptiveMaxPool2d(1)
              self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
              self.relu = nn.ReLU()
              self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
              self.sigmoid = nn.Sigmoid()

          def forward(self, x):
              avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
              max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
              out = self.sigmoid(avg_out + max_out)
              return out


      class SpatialAttention(nn.Module):
          # Spatial attention (SAM): channel-wise mean and max maps are concatenated
          # and convolved into a single-channel spatial weight map
          def __init__(self, kernel_size=7):
              super(SpatialAttention, self).__init__()
              assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
              padding = 3 if kernel_size == 7 else 1
              self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
              self.sigmoid = nn.Sigmoid()

          def forward(self, x):
              avg_out = torch.mean(x, dim=1, keepdim=True)
              max_out, _ = torch.max(x, dim=1, keepdim=True)
              x = torch.cat([avg_out, max_out], dim=1)
              x = self.conv(x)
              return self.sigmoid(x)


      class CBAM(nn.Module):
          # Standard convolution followed by channel and spatial attention in series
          def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
              super(CBAM, self).__init__()
              self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)  # autopad() already lives in common.py
              self.bn = nn.BatchNorm2d(c2)
              # accept an activation module passed from the cfg (e.g. nn.LeakyReLU(0.1))
              self.act = nn.Hardswish() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
              self.ca = ChannelAttention(c2)
              self.sa = SpatialAttention()

          def forward(self, x):
              x = self.act(self.bn(self.conv(x)))
              x = self.ca(x) * x  # reweight channels
              x = self.sa(x) * x  # then reweight spatial locations
              return x

          def fuseforward(self, x):  # forward pass after Conv+BN fusion at inference time
              return self.act(self.conv(x))
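
      As a quick sanity check, the standalone snippet below (assuming the classes above, together with the existing autopad helper, live in models/common.py) verifies the output shape:

      import torch
      from models.common import CBAM

      m = CBAM(64, 128, k=3, s=2)     # 64 -> 128 channels, 3x3 conv, stride 2
      x = torch.randn(2, 64, 32, 32)  # dummy NCHW feature map
      y = m(x)
      print(y.shape)                  # expected: torch.Size([2, 128, 16, 16])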
      

      Register the CBAM module name in yolo.py

      Go to line 459 of yolo.py (inside the parse_model function) and add CBAM to the module list:

      if m in [nn.Conv2d, Conv, RobustConv, RobustConv2, DWConv, GhostConv, RepConv, RepConv_OREPA, DownC, 
                       SPP, SPPF, SPPCSPC, GhostSPPCSPC, MixConv2d, Focus, Stem, GhostStem, CrossConv, 
                       Bottleneck, BottleneckCSPA, BottleneckCSPB, BottleneckCSPC, 
                       RepBottleneck, RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC,  
                       Res, ResCSPA, ResCSPB, ResCSPC, 
                       RepRes, RepResCSPA, RepResCSPB, RepResCSPC, 
                       ResX, ResXCSPA, ResXCSPB, ResXCSPC, 
                       RepResX, RepResXCSPA, RepResXCSPB, RepResXCSPC, 
                       Ghost, GhostCSPA, GhostCSPB, GhostCSPC,
                       SwinTransformerBlock, STCSPA, STCSPB, STCSPC,
                       SwinTransformer2Block, ST2CSPA, ST2CSPB, ST2CSPC, CBAM]:
          c1, c2 = ch[f], args[0]
          if c2 != no:  # if not output
              c2 = make_divisible(c2 * gw, 8)
          args = [c1, c2, *args[1:]]
          if m in [DownC, SPPCSPC, GhostSPPCSPC, 
                           BottleneckCSPA, BottleneckCSPB, BottleneckCSPC, 
                           RepBottleneckCSPA, RepBottleneckCSPB, RepBottleneckCSPC, 
                           ResCSPA, ResCSPB, ResCSPC, 
                           RepResCSPA, RepResCSPB, RepResCSPC, 
                           ResXCSPA, ResXCSPB, ResXCSPC, 
                           RepResXCSPA, RepResXCSPB, RepResXCSPC,
                           GhostCSPA, GhostCSPB, GhostCSPC,
                           STCSPA, STCSPB, STCSPC,
                           ST2CSPA, ST2CSPB, ST2CSPC]:
              args.insert(2, n)  # number of repeats
              n = 1
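
      To see what this registration accomplishes, here is a hypothetical walk-through of how parse_model expands the first CBAM line of the cfg below (variable names as in yolo.py; gw is width_multiple):

      # cfg line: [-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]]
      # f = -1 (from previous layer), n = 1, m = CBAM
      c1 = ch[-1]                      # input channels, 3 for the raw image
      c2 = make_divisible(32 * gw, 8)  # 32 when width_multiple == 1.0
      args = [c1, c2, 3, 2, None, 1, nn.LeakyReLU(0.1)]
      # -> CBAM(3, 32, k=3, s=2, p=None, g=1, act=nn.LeakyReLU(0.1))

      Note that CBAM is absent from the second list, so the repeat count n is never inserted into its arguments.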
      

      Add CBAM entries to the cfg file

      Here we take the backbone as the example: simply replace Conv with CBAM. The same substitution works in the FPN (head) as well.

      # parameters
      nc: 80  # number of classes
      depth_multiple: 1.0  # model depth multiple
      width_multiple: 1.0  # layer channel multiple
      # anchors
      anchors:
        - [10,13, 16,30, 33,23]  # P3/8
        - [30,61, 62,45, 59,119]  # P4/16
        - [116,90, 156,198, 373,326]  # P5/32
      backbone:
        # [from, number, module, args]  (Conv/CBAM args: c2, k=1, s=1, p=None, g=1, act=True)
        # [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2 
        [[-1, 1, CBAM, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 0-P1/2  
        
        #  [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4  
         [-1, 1, CBAM, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]],  # 1-P2/4    
         
         [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 7
         
         [-1, 1, MP, []],  # 8-P3/8
         [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 14
         
         [-1, 1, MP, []],  # 15-P4/16
         [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 21
         
         [-1, 1, MP, []],  # 22-P5/32
         [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 28
        ]
      # yolov7-tiny head
      head:
        [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, SP, [5]],
         [-2, 1, SP, [9]],
         [-3, 1, SP, [13]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -7], 1, Concat, [1]],
         [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 37
        
         [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, nn.Upsample, [None, 2, 'nearest']],
         [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4
         [[-1, -2], 1, Concat, [1]],
         
         [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 47
        
         [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, nn.Upsample, [None, 2, 'nearest']],
         [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3
         [[-1, -2], 1, Concat, [1]],
         
         [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 57
         
         [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, 47], 1, Concat, [1]],
         
         [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 65
         
         [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, 37], 1, Concat, [1]],
         
         [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[-1, -2, -3, -4], 1, Concat, [1]],
         [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 73
            
         [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [73, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
         [[74,75,76], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
        ]
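
      With all three changes in place, a quick way to confirm that the modified cfg parses and builds is a dummy forward pass (the filename yolov7-tiny-cbam.yaml is an assumed name for your edited copy):

      import torch
      from models.yolo import Model

      model = Model('cfg/training/yolov7-tiny-cbam.yaml', ch=3, nc=80)
      y = model(torch.randn(1, 3, 640, 640))  # dummy image through the CBAM backbone

      Training then runs exactly as before, with train.py pointed at this cfg via --cfg.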
      