超分辨率第五章-SRDenseNet
SRDenseNet发表于2017年
论文:Super-Resolution Using Dense Skip Connections
参考博客:超分之SRDenseNet
代码位置:F:\Github下载\SRDenseNet-pytorch-master
一.模型介绍
1.模型特点
- 将稠密块作为SRDenseNet的基本结构。
- Skip connection将各个level的特征直接与图像重建输入端相连,此外Dense skip connection可以缓解网络深度带来的梯度消失问题。
2.Dense块
Dense块各个层之间的skip connection是通过concat在一起的,而Resnet块是通过求和加在一起的。这样的好处在于可以缓解梯度消失的问题以及加强了信息在各个layer之间的流动
整个块分为8个layers,第 i 个layer产生16 × i 张feature map,最后一层产生128张特征图作为块的输出,一共有八个Dense块。
3.模型架构
- 前向传播过程
- 首先,输入的低分辨率图像经过第一层CNN提取低层的特征信息。
- 然后经过8个Dense块提取高层特征信息,通过skip connection的方式将各个level的特征信息相连,滤波器的大小统一设置成3×3。
- 再经过一层Bottleneck layer,降低前面特征图连接导致图像张数(通道数)太多而带来的高计算复杂度问题,通过1 × 1 卷积层进行缩减。
- 之后经过2个反卷积子网络,进行上采样。
- 最后使用3×3卷积核以及输出通道为1的卷积层进行重建。
- 图示
二.数据集
训练集使用coco2017,验证集使用set5
三.模型架构
1.定义特征提取层
1 | class ConvLayer(nn.Module): |
2.定义dense块结构
- 定义一个dense块中单独的层,每个层均为33\16结构
1 | class DenseLayer(nn.Module): |
- 定义dense块
1 | class DenseBlock(nn.Module): |
3.整体架构
1 | class SRDenseNet(nn.Module): |
四.训练模型
1 | model = SRDenseNet(growth_rate=args.growth_rate, num_blocks=args.num_blocks, num_layers=args.num_layers).to(device) |
五.测试模型
1 | cudnn.benchmark = True |