在大佬的博客补充了一些小问题,按照如下修改,你的代码就能跑起来了
收费教程:YOLOv5更换骨干网络之 MobileViT-S / MobileViT-XS / MobileViT-XXS
知识储备
MobileViT模型简介
MobileViT、MobileViTv2、MobileViTv3学习笔记(自用)
MobileViTv1、MobileViTv2、MobileViTv3网络详解
准备工作:
我使用的是6.0 yolov5s
mobilevit
正式修改
- 将mobilevit.py放在yolov5/models

2. 修改models/yolo.py
加入所有的模块,或者只加入MV2Block, MobileViTBlock

加入MV2Block, MobileViTBlock

3.修改yaml文件
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license # Parameters nc: 1 # number of classes depth_multiple: 0.33 # model depth multiple width_multiple: 0.50 # layer channel multiple anchors: - [10,13, 16,30, 33,23] # P3/8 - [30,61, 62,45, 59,119] # P4/16 - [116,90, 156,198, 373,326] # P5/32 # YOLOv5 backbone backbone: # [from, number, module, args] 640 x 640 # [[-1, 1, Conv, [32, 6, 2, 2]], # 0-P1/2 320 x 320 [[-1, 1, Focus, [32, 3]], [-1, 1, MV2Block, [32, 1, 2]], # 1-P2/4 [-1, 1, MV2Block, [48, 2, 2]], # 160 x 160 [-1, 2, MV2Block, [48, 1, 2]], [-1, 1, MV2Block, [64, 2, 2]], # 80 x 80 [-1, 1, MobileViTBlock, [64,96, 2, 3, 2, 192]], # 5 out_dim,dim, depth, kernel_size, patch_size, mlp_dim [-1, 1, MV2Block, [80, 2, 2]], # 40 x 40 [-1, 1, MobileViTBlock, [80,120, 4, 3, 2, 480]], # 7 [-1, 1, MV2Block, [96, 2, 2]], # 20 x 20 [-1, 1, MobileViTBlock, [96,144, 3, 3, 2, 576]], # 11-P2/4 # 9 ] # YOLOv5 head head: [[-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 7], 1, Concat, [1]], # cat backbone P4 [-1, 3, C3, [256, False]], # 13 [-1, 1, Conv, [128, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 5], 1, Concat, [1]], # cat backbone P3 [-1, 3, C3, [128, False]], # 17 (P3/8-small) [-1, 1, Conv, [128, 3, 2]], [[-1, 14], 1, Concat, [1]], # cat head P4 [-1, 3, C3, [256, False]], # 20 (P4/16-medium) [-1, 1, Conv, [256, 3, 2]], [[-1, 10], 1, Concat, [1]], # cat head P5 [-1, 3, C3, [512, False]], # 23 (P5/32-large) [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ]
- 修改mobilevit.py
 
 
 可以愉快的跑起来了!!!
END
谢谢观看,有用的话点个赞吧!
ADD
einops.EinopsError: Error while processing rearrange-reduction pattern "b d (h ph) (w pw) -> b (ph pw) (h w) d".
Input tensor shape: torch.Size([1, 120, 42, 42]). Additional info: {'ph': 4, 'pw': 4}
- 是因为输入输出不匹配造成 
- 记得关掉rect哦!一个是在参数里,另一个在下图。如果要在test或者val中跑,同样要改 

特别感谢养乐多阿









