nnUnetV2:使用自定义网络

慈云数据 2024-04-05 技术支持 92 0
前言

2023年3月17日,nnUnet迎来重大更新。紧接着不久,Facebook推出大一统多模态分割模型Segment Anything。喜忧参半,喜的是一直关注的医学图像分割仓库更新了,忧的是以后分割的赛道变了,小打小闹的堆模块水文章估计不行了,各种微雕大模型的工作会逐渐应用到医学图像分割领域。

nnUnetV2:使用自定义网络
(图片来源网络,侵删)

闲话少说,回到本文的主题:怎么在新版nnUnetV2使用自定义网络。

基本知识

nnUnetV2默认使用深监督,意味着自定义网络输出应为一个列表形式。然而,在网络推理时,我们只需要最高分辨率的输出,不需要多层次输出。在nnUnetV2官方实现中,使用deep_supervision参数控制是否多层次输出。综上所述,自定义网络需要满足两个条件:

nnUnetV2:使用自定义网络
(图片来源网络,侵删)
  • 支持多层次输出

  • 使用变量deep_supervision控制输出类型

    实战

    这里提供一种对已有网络包装的方法,仅供参考

    Python
    import torch.nn as nn
    class custom_net(nn.Module):
        def __init__(self,):
            super(custom_net, self).__init__()
            self.deep_supervision = True
            # 使用你自己的网络
            self.model = None
        def forward(self, x):
            output = self.model(x)
            if self.deep_supervision:
                return [output, ]
            else:
                return output
    

    将自定义网络嵌套进主框架。打开文件 nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py

    替换函数 build_network_architecture

    Python
        def build_network_architecture(self, plans_manager: PlansManager,
                                       dataset_json,
                                       configuration_manager: ConfigurationManager,
                                       num_input_channels,
                                       enable_deep_supervision: bool = True) -> nn.Module:
            from dynamic_network_architectures.initialization.weight_init import InitWeights_He
            model = custom_net()
            model.apply(InitWeights_He(1e-2))
            return model
    
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon