Vitis AI 迁移学习并部署在DPU中

慈云数据 6个月前 (05-28) 技术支持 71 0

目录

1. 本文目的

2. ResNet18介绍

3. 迁移学习

4. 量化配置文件

5. 模型编译:

6. 总结


1. 本文目的

使用迁移学习的方法,将预训练的resnet18模型从原来的1000类分类任务,改造为适应自定义的30类分类任务。

2. ResNet18介绍

ResNet18是一种基于深度残差网络(ResNet)的卷积神经网络模型,由何凯明等人于2015年提出。ResNet的核心思想是通过引入残差块(Residual Block),解决了深度网络中的梯度消失和退化问题,使得网络可以更深更有效地学习特征。

ResNet18是ResNet系列中最简单的一个模型,共有18层,其中包括:

  • 一个7×7的卷积层,输出通道数为64,步幅为2,后接批量归一化(Batch Normalization)和ReLU激活函数。
  • 一个3×3的最大池化层(Max Pooling),步幅为2。
  • 四个由残差块组成的模块,每个模块包含两个或三个残差块,每个残差块由两个3×3的卷积层、批量归一化和ReLU激活函数组成。每个模块的第一个残差块可能会改变输入输出的通道数和步幅,以适应下一个模块。这四个模块的输出通道数分别为64、128、256、512,步幅分别为1、2、2、2。
  • 一个全局平均池化层(Global Average Pooling),将最后一个模块的输出转换为一个一维向量。
  • 一个全连接层(Fully Connected),将一维向量映射到最终的类别数1000类上。

    3. 迁移学习

    根据不同的任务和数据集,迁移学习有以下几种常见的方法:

    • 微调网络:这种方法是在预训练模型的基础上,修改最后一层或几层,并且对整个网络进行微调训练。这种方法适用于新数据集和原数据集相似度较高,且新数据集规模较大的情况。
    • 特征提取:这种方法是将预训练模型看作一个特征提取器,冻结除了最后一层以外的所有层,只修改和训练最后一层。这种方法适用于新数据集和原数据集相似度较高,但新数据集规模较小的情况。
    • 模型蒸馏:这种方法是将预训练模型看作一个教师模型,用它来指导一个更小的学生模型,使学生模型能够学习到教师模型的知识。这种方法适用于新数据集和原数据集相似度较低,或者需要减少模型大小和计算量的情况。

      首先导入所需的模块:

      #!pip install -i  torchsummary
      # torchsummary是一个用于查看网络结构,非必须
      from torchsummary import summary
      import torch, torchvision, random
      from pytorch_nndct.apis import Inspector, torch_quantizer
      import torchvision.transforms as transforms
      from torchvision import models
      from tqdm import tqdm
      

      然后导入预训练模型,并查看网络结构:

      model = models.resnet18(pretrained=True) # 载入预训练模型
      summary(model, (3, 224, 224))
      ---以下为执行结果
      ----------------------------------------------------------------
              Layer (type)               Output Shape         Param #
      ================================================================
                  Conv2d-1         [-1, 64, 112, 112]           9,408
             BatchNorm2d-2         [-1, 64, 112, 112]             128
                    ReLU-3         [-1, 64, 112, 112]               0
               MaxPool2d-4           [-1, 64, 56, 56]               0
                  Conv2d-5           [-1, 64, 56, 56]          36,864
             BatchNorm2d-6           [-1, 64, 56, 56]             128
                    ReLU-7           [-1, 64, 56, 56]               0
                  Conv2d-8           [-1, 64, 56, 56]          36,864
             BatchNorm2d-9           [-1, 64, 56, 56]             128
                   ReLU-10           [-1, 64, 56, 56]               0
             BasicBlock-11           [-1, 64, 56, 56]               0
                 Conv2d-12           [-1, 64, 56, 56]          36,864
            BatchNorm2d-13           [-1, 64, 56, 56]             128
                   ReLU-14           [-1, 64, 56, 56]               0
                 Conv2d-15           [-1, 64, 56, 56]          36,864
            BatchNorm2d-16           [-1, 64, 56, 56]             128
                   ReLU-17           [-1, 64, 56, 56]               0
             BasicBlock-18           [-1, 64, 56, 56]               0
                 Conv2d-19          [-1, 128, 28, 28]          73,728
            BatchNorm2d-20          [-1, 128, 28, 28]             256
                   ReLU-21          [-1, 128, 28, 28]               0
                 Conv2d-22          [-1, 128, 28, 28]         147,456
            BatchNorm2d-23          [-1, 128, 28, 28]             256
                 Conv2d-24          [-1, 128, 28, 28]           8,192
            BatchNorm2d-25          [-1, 128, 28, 28]             256
                   ReLU-26          [-1, 128, 28, 28]               0
             BasicBlock-27          [-1, 128, 28, 28]               0
                 Conv2d-28          [-1, 128, 28, 28]         147,456
            BatchNorm2d-29          [-1, 128, 28, 28]             256
                   ReLU-30          [-1, 128, 28, 28]               0
                 Conv2d-31          [-1, 128, 28, 28]         147,456
            BatchNorm2d-32          [-1, 128, 28, 28]             256
                   ReLU-33          [-1, 128, 28, 28]               0
             BasicBlock-34          [-1, 128, 28, 28]               0
                 Conv2d-35          [-1, 256, 14, 14]         294,912
            BatchNorm2d-36          [-1, 256, 14, 14]             512
                   ReLU-37          [-1, 256, 14, 14]               0
                 Conv2d-38          [-1, 256, 14, 14]         589,824
            BatchNorm2d-39          [-1, 256, 14, 14]             512
                 Conv2d-40          [-1, 256, 14, 14]          32,768
            BatchNorm2d-41          [-1, 256, 14, 14]             512
                   ReLU-42          [-1, 256, 14, 14]               0
             BasicBlock-43          [-1, 256, 14, 14]               0
                 Conv2d-44          [-1, 256, 14, 14]         589,824
            BatchNorm2d-45          [-1, 256, 14, 14]             512
                   ReLU-46          [-1, 256, 14, 14]               0
                 Conv2d-47          [-1, 256, 14, 14]         589,824
            BatchNorm2d-48          [-1, 256, 14, 14]             512
                   ReLU-49          [-1, 256, 14, 14]               0
             BasicBlock-50          [-1, 256, 14, 14]               0
                 Conv2d-51            [-1, 512, 7, 7]       1,179,648
            BatchNorm2d-52            [-1, 512, 7, 7]           1,024
                   ReLU-53            [-1, 512, 7, 7]               0
                 Conv2d-54            [-1, 512, 7, 7]       2,359,296
            BatchNorm2d-55            [-1, 512, 7, 7]           1,024
                 Conv2d-56            [-1, 512, 7, 7]         131,072
            BatchNorm2d-57            [-1, 512, 7, 7]           1,024
                   ReLU-58            [-1, 512, 7, 7]               0
             BasicBlock-59            [-1, 512, 7, 7]               0
                 Conv2d-60            [-1, 512, 7, 7]       2,359,296
            BatchNorm2d-61            [-1, 512, 7, 7]           1,024
                   ReLU-62            [-1, 512, 7, 7]               0
                 Conv2d-63            [-1, 512, 7, 7]       2,359,296
            BatchNorm2d-64            [-1, 512, 7, 7]           1,024
                   ReLU-65            [-1, 512, 7, 7]               0
             BasicBlock-66            [-1, 512, 7, 7]               0
      AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
                 Linear-68                 [-1, 1000]         513,000
      ================================================================
      Total params: 11,689,512
      Trainable params: 11,689,512
      Non-trainable params: 0
      ----------------------------------------------------------------
      Input size (MB): 0.57
      Forward/backward pass size (MB): 62.79
      Params size (MB): 44.59
      Estimated Total Size (MB): 107.96
      ----------------------------------------------------------------
      

      可以原始看到最后一层有1000个特征输出,对应1000分类。我们要做的,就是使用特征提取方法,修改最后一层(FC),实现一个30分类的特征输出。

      修改全链接层,然后查看修改结果:

      print('Original output layer:')
      print(model.fc)
      #输入特征数(in_features)保持不变,输出特征数(out_features)设置为10
      model.fc = torch.nn.Linear(model.fc.in_features, 30)
      print('New output layer:')
      print(model.fc)
      ---以下为执行结果
      Original output layer:
      Linear(in_features=512, out_features=1000, bias=True)
      New output layer:
      Linear(in_features=512, out_features=30, bias=True)
      

      配置优化器,只微调输出层(FC),然后执行训练:

      # 只微调训练最后一层全连接层的参数,其它层冻结
      optimizer = optim.Adam(model.fc.parameters())
      # 遍历每个 EPOCH
      for epoch in tqdm(range(EPOCHS)):
          model.train()
          for images, labels in train_loader:  # 获取训练集的一个 batch,包含数据和标注
              images = images.to(device)
              labels = labels.to(device)
              outputs = model(images)           # 前向预测,获得当前 batch 的预测结果
              loss = criterion(outputs, labels) # 比较预测结果和标注,计算当前 batch 的交叉熵损失函数
              
              optimizer.zero_grad()
              loss.backward()                   # 损失函数对神经网络权重反向传播求梯度
              optimizer.step()                  # 优化更新神经网络权重
      ---以下为执行结果
      100%|██████████| 20/20 [03:04
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon