【语义分割】ST

慈云数据 8个月前 (03-13) 技术支持 48 0

【语义分割】ST_Unet论文 逐步代码解读

文章目录

  • 【语义分割】ST_Unet论文 逐步代码解读
    • 一、代码整体解读
    • 二、辅助Decode代码框架
      • 2.1 混合transformer和cnn的模型
      • 2.2 Swin transformer 部分
      • 2.3 FCM 部分
      • 三、主Decode代码框架
        • 3.1 基本卷积模块
        • 3.2 RAM
        • 3.3 输出参数
        • 四、Encode代码
          • 4.1 block函数解析
          • 4.2 上采样还原

            一、代码整体解读

            [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UAEkMEUl-1678889964762)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310143528765.png)]

            主要工程文件为这5个

            分别作用为:

            • 构造相应的deform 卷积
            • DCNN的残差网络
            • 编写相应的配置文件,可以改变相应参数
            • 模型的主函数和主框架
            • 模型的连接部分



              二、辅助Decode代码框架

              [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xB9iqcKa-1678889964763)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310145258457.png)]

              代码框架由3部分组成,encode,decode和decode中将图像还原成语义分割预测图

              • Transformer(config, img_size) 组成编码部分,包含主编码的DCNN和辅助的transformer
              • DecoderCup(config)组成解码部分,图像还原为[N,64,128,128]
              • SegmentationHead将图像变成6分类的[N,6,256,256]的图像

                2.1 混合transformer和cnn的模型

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-C1JDIQD6-1678889964763)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310150618037.png)]

                整体思路是这样的,decode一共分为4个阶段

                主要用空的数组来保存每一个阶段的输出值,与DCNN在每一个阶段通过RAM进行连接

                在class TransResNetV2(nn.Module)函数中进行相应的具体编写

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fFrLblTw-1678889964763)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310152428267.png)]

                相应的RAM操作示意图

                An和Sn分别表示第n阶段主编码器和辅助编码器的输出

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ua68cdLE-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310152957811.png)]

                一共组成分为4部分,将每一层都进行相应的整合,最后放在数组里面

                2.2 Swin transformer 部分

                将读入的数据进行打平操作

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OYQ7SY3i-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310154540561.png)]

                embeddings(trans_x)

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Y41JX67e-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310161441181.png)]

                  这部分操作的一般情况下的Swin transformer一样,同样满足(2,2,6,2)的层数结构,只不过是,加入了相应的残差结构,经过了扁平化操作后的数据类型为[12, 4096, 96]

                具体的transformer操作在这部分进行

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3uO7glYk-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310161944441.png)]

                在这个函数中主要是transformer块和SIM的残差组合

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VWJFuFdd-1678889964765)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310163215108.png)]

                  在这步之后,x就可以就是组成的tranform块输出的格式,其中是由shift_size来进行判断是W-MSA,还是SW-MSA,来进行的窗口移动,还是就单纯的结构的划分

                shift_size=0 if (i % 2 == 0) else window_size // 2,  # 判断是不是SW_MSA
                

                起了决定性的判断作用

                这段代码进行执行4次,将每次执行的结果进行保存

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-X8GyKSQg-1678889964765)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310170340544.png)]

                2.3 FCM 部分

                if (i_layer  
                

                在每一个Swin transformer阶段都进行了下采样,除了最后一个阶段

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4UiGLHwU-1678889964765)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310172543159.png)]

                结构示意图

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hp318ekr-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310172916320.png)]

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5qY2UKUs-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310173458842.png)]

                整体的代码逻辑都是按照这个思路来的,来进行的整合和结合



                三、主Decode代码框架

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-p157EjKs-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310174011667.png)]

                首先进行的是root函数,主打的是一个对图片进行预先处理

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JhTJnNgQ-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310190150355.png)]

                  对图片进行相应的变形,主要还是三步走的对策,卷积,归一,relu。进行DCNN卷积网络时基本都是这样进行的

                3.1 基本卷积模块

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bNtANlad-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310190743797.png)]

                  将这两个归为一个操作,body里面是几个卷积的模块config.resnet.num_layers = (3, 4, 6, 3)组成的,重复的次数由设定好的值来进行重复

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DgRpAOxk-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310191345796.png)]

                   PreActBottleneck(nn.Module) 里面的值就是很单纯的DCNN的卷积网络的堆积

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Zc9qXqRd-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310191640918.png)]

                  在这里DCN加了一个, DeformConv2d,这个函数是自己编写的,一个可变形的卷积操作,其实他本质上和普通的卷积操作一样

                后面也是相同的操作,通过RAM模块将相应的结果组合在一起

                3.2 RAM

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pvpyR3hG-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310211839414.png)]

                  输入分为了主编码器和辅助编码器,总共的结合组成为3种,将不同变化的进行拼接

                [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OPLAKkjh-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310212655800.png)]

                相应参数:

                • x 原始参数
                • short 经过了注意力通道
                • s1 tranformer辅助通道过来的数据

                  3.3 输出参数

                  输出参数值主要分为两类:

                  • 结合所有参数的X [N, 32, 256, 256]
                  • 每个阶段提取出来的特征数据features

                    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9rD0FkJk-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310213311694.png)]

                      将这两个数据进行带入Encode,进行解码,可以逐步还原成原始图像



                    四、Encode代码

                    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nVjVVQLc-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310213517652.png)]

                    代码主要分为两步来实现:

                    • x 的卷积上采样
                    • x与skip的融合后,进行相应的卷积操作

                        skip是每个特征层的进过RAM后的保存数据,所有的融合卷积操作在block中完成

                      4.1 block函数解析

                      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ksA8GBvs-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310233943010.png)]

                        在连接阶段主要是conv1和conv2,这两个函数, 进行上采样来保存维度一致,使他可以cat在一起 conv3和conv4在连接完成后,进行相应的上采样环节来使图像还原成原来的[n, 6, 256, 256]

                      4.2 上采样还原

                      [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vdcEgMW9-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310235005427.png)]

                      这部分代码将这里独立出去了

                      在这里x的输入参数应该是(N,16,256,256)

                      在进行了一次卷积和上采用后,就恢复成了原始图像

微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon