千家信息网

如何分析Pytorch中UNet网络结构以及代码编写

发表于:2025-12-01 作者:千家信息网编辑
千家信息网最后更新 2025年12月01日,这篇文章给大家介绍如何分析Pytorch中UNet网络结构以及代码编写,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。一、前言Windows环境开发,环境情况如下:开发环境:Win
千家信息网最后更新 2025年12月01日如何分析Pytorch中UNet网络结构以及代码编写

这篇文章给大家介绍如何分析Pytorch中UNet网络结构以及代码编写,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。

一、前言

Windows环境开发,环境情况如下:

开发环境:Windows

开发语言:Python3.7.4

框架版本:Pytorch2.3.0

CUDA:10.2

cuDNN:7.6.0

主要讲解UNet网络结构,以及相应代码的代码编写

PS:文中出现的所有代码,均可在我的github上下载,欢迎Follow、Star:点击查看

二、UNet网络结构

在语义分割领域,基于深度学习的语义分割算法开山之作是FCN(Fully Convolutional Networks for Semantic Segmentation),而UNet是遵循FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。

UNet论文地址:点击查看

研究一个深度学习算法,可以先看网络结构,看懂网络结构后,再Loss计算方法、训练方法等。本文主要针对UNet的网络结构进行讲解,其它内容会在后续章节进行说明。

1、网络结构原理

UNet最早发表在2015的MICCAI会议上,4年多的时间,论文引用量已经达到了9700多次。

UNet成为了大多做医疗影像语义分割任务的baseline,同时也启发了大量研究者对于U型网络结构的研究,发表了一批基于UNet网络结构的改进方法的论文。

UNet网络结构,最主要的两个特点是:U型网络结构和Skip Connection跳层连接。

这种"摞在一起"的操作,就是Concat。

同样道理,对于feature map,一个大小为256*256*64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64。和一个大小为256*256*32的feature map进行Concat融合,就会得到一个大小为256*256*96的feature map。

在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256*256*64的feature map和240*240*32的feature map进行Concat。

这种时候,就有两种办法:

第一种:将大256*256*64的feature map进行裁剪,裁剪为240*240*64的feature map,比如上下左右,各舍弃8 pixel,裁剪后再进行Concat,得到240*240*96的feature map。

第二种:将小240*240*32的feature map进行padding操作,padding为256*256*32的feature map,比如上下左右,各补8 pixel,padding后再进行Concat,得到256*256*96的feature map。

UNet采用的Concat方案就是第二种,将小的feature map进行padding,padding的方式是补0,一种常规的常量填充。

2、代码

有些朋友可能对Pytorch不太了解,推荐一个快速入门的官方教程。一个小时,你就可以掌握一些基本概念和Pytorch代码编写方法。

Pytorch官方基础:点击查看

我们将整个UNet网络拆分为多个模块进行讲解。

DoubleConv模块:

先看下连续两次的卷积操作。

从UNet网络中可以看出,不管是下采样过程还是上采样过程,每一层都会连续进行两次卷积操作,这种操作在UNet网络中重复很多次,可以单独写一个DoubleConv模块:

import torch.nn as nnclass DoubleConv(nn.Module):    """(convolution => [BN] => ReLU) * 2"""    def __init__(self, in_channels, out_channels):        super().__init__()        self.double_conv = nn.Sequential(            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0),            nn.BatchNorm2d(out_channels),            nn.ReLU(inplace=True),            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=0),            nn.BatchNorm2d(out_channels),            nn.ReLU(inplace=True)        )    def forward(self, x):        return self.double_conv(x)

解释下,上述的Pytorch代码:torch.nn.Sequential是一个时序容器,Modules 会以它们传入的顺序被添加到容器中。比如上述代码的操作顺序:卷积->BN->ReLU->卷积->BN->ReLU。

DoubleConv模块的in_channels和out_channels可以灵活设定,以便扩展使用。

如上图所示的网络,in_channels设为1,out_channels为64。

输入图片大小为572*572,经过步长为1,padding为0的3*3卷积,得到570*570的feature map,再经过一次卷积得到568*568的feature map。

计算公式:O=(H−F+2×P)/S+1

H为输入feature map的大小,O为输出feature map的大小,F为卷积核的大小,P为padding的大小,S为步长。

Down模块:

UNet网络一共有4次下采样过程,模块化代码如下:

class Down(nn.Module):    """Downscaling with maxpool then double conv"""    def __init__(self, in_channels, out_channels):        super().__init__()        self.maxpool_conv = nn.Sequential(            nn.MaxPool2d(2),            DoubleConv(in_channels, out_channels)        )    def forward(self, x):        return self.maxpool_conv(x)

这里的代码很简单,就是一个maxpool池化层,进行下采样,然后接一个DoubleConv模块。

至此,UNet网络的左半部分的下采样过程的代码都写好了,接下来是右半部分的上采样过程

Up模块:

上采样过程用到的最多的当然就是上采样了,除了常规的上采样操作,还有进行特征的融合。

这块的代码实现起来也稍复杂一些:

class Up(nn.Module):    """Upscaling then double conv"""    def __init__(self, in_channels, out_channels, bilinear=True):        super().__init__()        # if bilinear, use the normal convolutions to reduce the number of channels        if bilinear:            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        else:            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)        self.conv = DoubleConv(in_channels, out_channels)    def forward(self, x1, x2):        x1 = self.up(x1)        # input is CHW        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,                        diffY // 2, diffY - diffY // 2])        # if you have padding issues, see        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd        x = torch.cat([x2, x1], dim=1)        return self.conv(x)

代码复杂一些,我们可以分开来看,首先是__init__初始化函数里定义的上采样方法以及卷积采用DoubleConv。上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值反卷积

双线性插值很好理解,示意图:

熟悉双线性插值的朋友对于这幅图应该不陌生,简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。

对于一个feature map而言,其实就是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。

反卷积,顾名思义,就是反着卷积。卷积是让featuer map越来越小,反卷积就是让feature map越来越大,示意图:

下面蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。

这个示意图,就是一个从2*2的feature map->4*4的feature map过程。

在forward前向传播函数中,x1接收的是上采样的数据,x2接收的是特征融合的数据。特征融合方法就是,上文提到的,先对小的feature map进行padding,再进行concat。

OutConv模块:

用上述的DoubleConv模块、Down模块、Up模块就可以拼出UNet的主体网络结构了。UNet网络的输出需要根据分割数量,整合输出通道,结果如下图所示:

操作很简单,就是channel的变换,上图展示的是分类为2的情况(通道为2)。

虽然这个操作很简单,也就调用一次,为了美观整洁,也封装一下吧。

class OutConv(nn.Module):    def __init__(self, in_channels, out_channels):        super(OutConv, self).__init__()        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)    def forward(self, x):        return self.conv(x)

至此,UNet网络用到的模块都已经写好,我们可以将上述的模块代码都放到一个unet_parts.py文件里,然后再创建unet_model.py,根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序,编写如下代码:

""" Full assembly of the parts to form the complete network """"""Refer https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py"""import torch.nn.functional as Ffrom unet_parts import *class UNet(nn.Module):    def __init__(self, n_channels, n_classes, bilinear=False):        super(UNet, self).__init__()        self.n_channels = n_channels        self.n_classes = n_classes        self.bilinear = bilinear        self.inc = DoubleConv(n_channels, 64)        self.down1 = Down(64, 128)        self.down2 = Down(128, 256)        self.down3 = Down(256, 512)        self.down4 = Down(512, 1024)        self.up1 = Up(1024, 512, bilinear)        self.up2 = Up(512, 256, bilinear)        self.up3 = Up(256, 128, bilinear)        self.up4 = Up(128, 64, bilinear)        self.outc = OutConv(64, n_classes)    def forward(self, x):        x1 = self.inc(x)        x2 = self.down1(x1)        x3 = self.down2(x2)        x4 = self.down3(x3)        x5 = self.down4(x4)        x = self.up1(x5, x4)        x = self.up2(x, x3)        x = self.up3(x, x2)        x = self.up4(x, x1)        logits = self.outc(x)        return logits    if __name__ == '__main__':    net = UNet(n_channels=3, n_classes=1)    print(net)

使用命令python unet_model.py,如果没有错误,你会得到如下结果:

UNet(  (inc): DoubleConv(    (double_conv): Sequential(      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (2): ReLU(inplace=True)      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)      (5): ReLU(inplace=True)    )  )  (down1): Down(    (maxpool_conv): Sequential(      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)      (1): DoubleConv(        (double_conv): Sequential(          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (2): ReLU(inplace=True)          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (5): ReLU(inplace=True)        )      )    )  )  (down2): Down(    (maxpool_conv): Sequential(      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)      (1): DoubleConv(        (double_conv): Sequential(          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (2): ReLU(inplace=True)          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (5): ReLU(inplace=True)        )      )    )  )  (down3): Down(    (maxpool_conv): Sequential(      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)      (1): DoubleConv(        (double_conv): Sequential(          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (2): ReLU(inplace=True)          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (5): ReLU(inplace=True)        )      )    )  )  (down4): Down(    (maxpool_conv): Sequential(      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)      (1): DoubleConv(        (double_conv): Sequential(          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (2): ReLU(inplace=True)          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)          (5): ReLU(inplace=True)        )      )    )  )  (up1): Up(    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))    (conv): DoubleConv(      (double_conv): Sequential(        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (2): ReLU(inplace=True)        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (5): ReLU(inplace=True)      )    )  )  (up2): Up(    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))    (conv): DoubleConv(      (double_conv): Sequential(        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (2): ReLU(inplace=True)        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (5): ReLU(inplace=True)      )    )  )  (up3): Up(    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))    (conv): DoubleConv(      (double_conv): Sequential(        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (2): ReLU(inplace=True)        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (5): ReLU(inplace=True)      )    )  )  (up4): Up(    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))    (conv): DoubleConv(      (double_conv): Sequential(        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (2): ReLU(inplace=True)        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)        (5): ReLU(inplace=True)      )    )  )  (outc): OutConv(    (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))  ))

关于如何分析Pytorch中UNet网络结构以及代码编写就分享到这里了,希望以上内容可以对大家有一定的帮助,可以学到更多知识。如果觉得文章不错,可以把它分享出去让更多的人看到。

网络 代码 模块 网络结构 结构 卷积 就是 大小 过程 方法 通道 插值 输出 上下 内容 图片 特征 环境 示意图 结果 数据库的安全要保护哪些东西 数据库安全各自的含义是什么 生产安全数据库录入 数据库的安全性及管理 数据库安全策略包含哪些 海淀数据库安全审计系统 建立农村房屋安全信息数据库 易用的数据库客户端支持安全管理 连接数据库失败ssl安全错误 数据库的锁怎样保障安全 没有数据库怎么学习sql 软件开发工程师的专业技能怎么写 闵行区项目数据库销售价格 服务器内存台式机用不了吗 腾讯云服务器彻底删除还能恢复吗 广东省服务器托管规定云主机 line服务器地址怎么配置 嘉定区项目数据库服务销售 金山区市场软件开发咨询热线 数据库储存策略 苹果验证id连接不上服务器 镇江大型软件开发 国防科技大学网络安全技术怎样 最强族长与中国网络安全 打印机服务器如何远程打印 国内数据库市场 河南以比特科技存储服务器 无代码软件开发输入框 安徽苹果软件开发哪里好 北京crm软件开发中心 软件开发中心产品 计算机网络技术张海霞的笔记 有关网络安全的新闻报道小学 苹果验证id连接不上服务器 网络上传发票服务器异常 csgo显示韩国服务器脱机状态 dell服务器 冗余电源 网络安全与数据保护概念股 互联网网络安全防范措施 我国哪些网络技术是世界前沿的
0