超分辨率第三章-SRGAN
SRGNN是2017年提出的模型,首次使用GAN在超分辨领域。
参考文献:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
一.模型介绍
先前超分辨率模型的局限性:虽然具有较高的峰值信噪比,但它们通常缺乏高频细节,并且在感知上不令人满意。
SRGAN中的生成网络就是SRResNet网络,其以ResNet块为基本结构,是一个具有深度的SR网络。生成网络使用感知损失进行训练,而不是传统的MSE方法,它使用预训练之后的VGG-16网络产生的feature map级进行计算,再加上本身生成网络带有的对抗损失。此外判别器也需要去训练,两个网络结合起来就是我们的SRGAN网络。
SRGNN提出了感知损失函数(Perceptual loss function),包括对抗损失与内容损失,在感知质量方面有了极大改进。MOS(平均意见得分)很高。
1.感知损失函数(Perceptual loss function)
- 由基于VGG-16的内容损失函数和GAN的对抗损失函数组成
- 内容损失函数
- 采用预训练好的VGG-16网络的特征向量,使得生成网络的结果通过VGG某一层之后产生的feature map和原始高分辨率图像通过VGG-16网络产生的feature map做loss,作者指出这种loss更能反应图片之间的感知相似度。
- 对抗损失函数
2.论文贡献
- 深度RESNet(SRRESNet)针对MSE进行了优化,通过PSNR和结构相似度(SSIM)来测量图像SR的高放大因子
- SRGAN,是一种基于GAN的网络,针对一种新的感知损失进行了优化。用在VGG网络的特征映射上计算的损失来代替基于MSE的内容损失,该特征映射对像素空间的变化更加不变,这样相较于原来像素损失超分的图像更具有纹理等高频细节.
- 对来自三个公共基准数据集的图像进行广泛的平均意见得分(MOS)测试,证实SRGAN在很大程度上是高放大因子(4×)的照片真实感SR图像估计的最新技术, 即超分后的图像更加接近自然图像.
3.模型结构
二.数据集
- 训练集使用:VOC2012(训练数据集包含16700张图片,验证数据集包含425张图片)
- 测试集使用:Set5 Set14 BSD100 Urban100 SunHays80
三.模型搭建
1.生成器结构
- 输入一张低分辨率图片,生成高分辨率图片
1 | class Generator(nn.Module): |
2.判别器结构
- 输入图片,判断真假
1 | class Discriminator(nn.Module): |
3.resnet结构
1 | class ResidualBlock(nn.Module): |
4.上采样结构
1 | class UpsampleBLock(nn.Module): |
四.损失函数
1 | import torch |
五.模型训练
1.载入数据集与初始化网络
1 | opt = parser.parse_args()#用户可以在命令行中指定一些参数(如裁剪大小、放大因子、训练轮数等),这些参数将被存储在opt对象中 |
2.训练阶段
- 生成器(Generator)和判别器(Discriminator)交替训练
1 | for epoch in range(1, NUM_EPOCHS + 1): |
3.验证阶段
1 | netG.eval() #生成器验证模式 |
4.保存文件
1 | #将判别器和生成器的参数保存到指定文件 |
五.模型测试
1 | import argparse |