超分辨率第三章-SRGAN

超分辨率第三章-SRGAN

SRGNN是2017年提出的模型,首次使用GAN在超分辨领域。

参考文献:Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

参考博客:基于pytorch的SRGAN实现(全网最细!!!)

一.模型介绍

  • 先前超分辨率模型的局限性:虽然具有较高的峰值信噪比,但它们通常缺乏高频细节,并且在感知上不令人满意。

  • SRGAN中的生成网络就是SRResNet网络,其以ResNet块为基本结构,是一个具有深度的SR网络。生成网络使用感知损失进行训练,而不是传统的MSE方法,它使用预训练之后的VGG-16网络产生的feature map级进行计算,再加上本身生成网络带有的对抗损失。此外判别器也需要去训练,两个网络结合起来就是我们的SRGAN网络。

  • SRGNN提出了感知损失函数(Perceptual loss function),包括对抗损失与内容损失,在感知质量方面有了极大改进。MOS(平均意见得分)很高。

1.感知损失函数(Perceptual loss function)

    • 由基于VGG-16的内容损失函数和GAN的对抗损失函数组成
  • 内容损失函数
    • 采用预训练好的VGG-16网络的特征向量,使得生成网络的结果通过VGG某一层之后产生的feature map和原始高分辨率图像通过VGG-16网络产生的feature map做loss,作者指出这种loss更能反应图片之间的感知相似度。
  • 对抗损失函数

2.论文贡献

  • 深度RESNet(SRRESNet)针对MSE进行了优化,通过PSNR和结构相似度(SSIM)来测量图像SR的高放大因子
  • SRGAN,是一种基于GAN的网络,针对一种新的感知损失进行了优化。用在VGG网络的特征映射上计算的损失来代替基于MSE的内容损失,该特征映射对像素空间的变化更加不变,这样相较于原来像素损失超分的图像更具有纹理等高频细节.
  • 对来自三个公共基准数据集的图像进行广泛的平均意见得分(MOS)测试,证实SRGAN在很大程度上是高放大因子(4×)的照片真实感SR图像估计的最新技术, 即超分后的图像更加接近自然图像.

3.模型结构

pAKsjR1.png

二.数据集

  • 训练集使用:VOC2012(训练数据集包含16700张图片,验证数据集包含425张图片)
  • 测试集使用:Set5 Set14 BSD100 Urban100 SunHays80

三.模型搭建

1.生成器结构

  • 输入一张低分辨率图片,生成高分辨率图片
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class Generator(nn.Module):
def __init__(self, scale_factor):
upsample_block_num = int(math.log(scale_factor, 2)) #计算上采样块的数量,输入放大因子为4,则有两个上采样块

super(Generator, self).__init__()
self.block1 = nn.Sequential( #首先放大维度,特征提取
nn.Conv2d(3, 64, kernel_size=9, padding=4),
nn.PReLU()
)
self.block2 = ResidualBlock(64) #5个残差网络块,特征提取
self.block3 = ResidualBlock(64)
self.block4 = ResidualBlock(64)
self.block5 = ResidualBlock(64)
self.block6 = ResidualBlock(64)
self.block7 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64)
)
block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] #定义了一个列表,进行上采样两次提高分辨率,每次提高2倍,共提升4倍
block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))#向该列表添加一个卷积层
self.block8 = nn.Sequential(*block8)#将这些层组合成一个顺序模型

def forward(self, x):
block1 = self.block1(x)
block2 = self.block2(block1)
block3 = self.block3(block2)
block4 = self.block4(block3)
block5 = self.block5(block4)
block6 = self.block6(block5)
block7 = self.block7(block6)
block8 = self.block8(block1 + block7) #特征融合相加

return (torch.tanh(block8) + 1) / 2 #使用torch.tanh函数将输出值映射到[-1, 1]区间,并通过(torch.tanh(block8) + 1) / 2将其缩放到[0, 1]区间,这是图像数据常见的归一化范围。

2.判别器结构

  • 输入图片,判断真假
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),

nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),

nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),

nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),

nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),

nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),

nn.Conv2d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),

nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),

nn.AdaptiveAvgPool2d(1), #一个自适应平均池化层,它该层都会将其空间维度(高度和宽度)压缩到1x1,而保持通道数不变
nn.Conv2d(512, 1024, kernel_size=1),
nn.LeakyReLU(0.2),
nn.Conv2d(1024, 1, kernel_size=1)
)

def forward(self, x):
batch_size = x.size(0)
return torch.sigmoid(self.net(x).view(batch_size)) #显示输入图片为真实的概率,将最终的输出(原本是一个形状为(batch_size, 1, 1, 1)的四维张量)展平成一个一维张量,其长度为批次大小,其元素对应于批次中每个样本的判别结果

3.resnet结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)

def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)

return x + residual

4.上采样结构

1
2
3
4
5
6
7
8
9
10
11
12
class UpsampleBLock(nn.Module):
def __init__(self, in_channels, up_scale):
super(UpsampleBLock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)#up_scale ** 2是因为之后的像素重排(pixel shuffle)操作会将通道数重排成空间维度,以达到上采样的效果,此时增加维度可以保持整体维度不变。
self.pixel_shuffle = nn.PixelShuffle(up_scale) #这个层将输入特征图的通道数重新排列成空间维度,以实现上采样(新的通道数将是原始通道数除以up_scale^2,而高度和宽度将会乘以up_scale)
self.prelu = nn.PReLU()

def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x

四.损失函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch import nn
from torchvision.models.vgg import vgg16

class GeneratorLoss(nn.Module):
def __init__(self):
super(GeneratorLoss, self).__init__()
#使用预训练的 VGG16 模型来构建特征提取网络
vgg = vgg16(pretrained=True)
#选择 VGG16 模型的前 31 层作为损失网络,并将其设置为评估模式(不进行梯度更新)
loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
#冻结其参数,不进行梯度更新
for param in loss_network.parameters():
param.requires_grad = False
#定义VGG16网络
self.loss_network = loss_network
#定义均方误差损失函数,计算生成器生成图像与目标图像之间的均方误差损失
self.mse_loss = nn.MSELoss()
#定义总变差损失函数,计算生成器生成图像的总变差损失,用于平滑生成的图像
self.tv_loss = TVLoss()

def forward(self, out_labels, out_images, target_images): #分别传入判别器判定概率,伪高分辨率图像,真图像
# Adversarial Loss(对抗损失):使生成的图像更接近真实图像,目标是最小化生成器对图像的判别结果的平均值与 1(真实值)的差距
adversarial_loss = torch.mean(1 - out_labels)
# Perception Loss(感知损失):计算生成图像和目标图像在vgg-16网络中提取的特征之间的均方误差损失
perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
# Image Loss(图像损失):计算生成图像和目标图像之间的均方误差损失
image_loss = self.mse_loss(out_images, target_images)
# TV Loss(总变差损失):计算生成图像的总变差损失,用于平滑生成的图像
tv_loss = self.tv_loss(out_images)
# 返回生成器的总损失,四个损失项加权求和
return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


# 总变差损失衡量的是图像中相邻像素之间的差异程度
# 在模型训练过程中,将总变差损失作为损失函数的一部分,可以引导模型在优化过程中考虑图像的空间连续性,从而生成更加符合人类直觉的图像。
class TVLoss(nn.Module):
def __init__(self, tv_loss_weight=1):
super(TVLoss, self).__init__()
self.tv_loss_weight = tv_loss_weight

def forward(self, x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
count_h = self.tensor_size(x[:, :, 1:, :])
count_w = self.tensor_size(x[:, :, :, 1:])
# 计算水平方向上的总变差损失
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
# 计算垂直方向上的总变差损失
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
# 返回总变差损失
return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

@staticmethod
def tensor_size(t):
# 返回张量的尺寸大小,即通道数乘以高度乘以宽度
return t.size()[1] * t.size()[2] * t.size()[3]


if __name__ == "__main__":
g_loss = GeneratorLoss()
print(g_loss)

五.模型训练

1.载入数据集与初始化网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
opt = parser.parse_args()#用户可以在命令行中指定一些参数(如裁剪大小、放大因子、训练轮数等),这些参数将被存储在opt对象中

CROP_SIZE = opt.crop_size
UPSCALE_FACTOR = opt.upscale_factor
NUM_EPOCHS = opt.num_epochs

#创建训练和验证数据集
train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=UPSCALE_FACTOR)
#创建数据加载器
train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

#初始化生成器(netG)和判别器(netD)网络,并打印生成器和判别器的参数数量
netG = Generator(UPSCALE_FACTOR)
print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss()#定义生成器的内容损失函数,此处会引入VGG16网络进行计算

if torch.cuda.is_available():
netG.cuda()
netD.cuda()
generator_criterion.cuda()

#初始化优化器
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

#初始化一个字典results,用于存储训练过程中的各种指标(如判别器和生成器的损失、评分、PSNR、SSIM等)。这些指标将用于评估训练过程中的模型性能。
results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

2.训练阶段

  • 生成器(Generator)和判别器(Discriminator)交替训练
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
for epoch in range(1, NUM_EPOCHS + 1):
train_bar = tqdm(train_loader)
running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

netG.train()
netD.train()
for data, target in train_bar:
g_update_first = True
batch_size = data.size(0)
running_results['batch_sizes'] += batch_size

#先训练判别器,输入就是真图片、假图片和它们对应的标签。
# (1) Update D network: maximize D(x)-1-D(G(z))
real_img = target #真实图片
if torch.cuda.is_available():
real_img = real_img.float().cuda()
z = data #低分辨率图片
if torch.cuda.is_available():
z = z.float().cuda()

fake_img = netG(z) #通过生成器生成高分辨率伪图片
optimizerD.zero_grad() #清除判别器的梯度
real_out = netD(real_img).mean() #通过判别器对真实图像进行前向传播,并计算其输出的平均值
fake_out = netD(fake_img).mean() #通过判别器对伪图像进行前向传播,并计算其输出的平均值
d_loss = 1 - real_out + fake_out #计算判别器的损失
d_loss.backward(retain_graph=True) #反向传播,计算判别器的梯度,并保留计算图以进行后续优化步骤
optimizerD.step() #对判别器网络梯度进行更新

#再训练生成器,在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
optimizerG.zero_grad() #清除生成器的梯度
fake_img = netG(z) #通过生成器生成高分辨率伪图片
fake_out = netD(fake_img).mean() #通过判别器对伪图像进行前向传播,并计算其输出的平均值
g_loss = generator_criterion(fake_out, fake_img, real_img)# 计算生成器的损失,包括对抗损失、感知损失、图像损失和TV损失
g_loss.backward() #反向传播,计算生成器的梯度
optimizerG.step() #对生成器网络梯度进行更新

# loss for current batch before optimization
#累加当前批次生成器的损失值乘以批次大小,用于计算平均损失
running_results['g_loss'] += g_loss.item() * batch_size
#累加当前批次判别器的损失值乘以批次大小,用于计算平均损失
running_results['d_loss'] += d_loss.item() * batch_size
#累加当前批次真实图像在判别器的输出得分乘以批次大小,用于计算平均得分
running_results['d_score'] += real_out.item() * batch_size
#累加当前批次伪图像在判别器的输出得分乘以批次大小,用于计算平均得分
running_results['g_score'] += fake_out.item() * batch_size

#更新训练进度条的描述信息
train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
running_results['g_loss'] / running_results['batch_sizes'],
running_results['d_score'] / running_results['batch_sizes'],
running_results['g_score'] / running_results['batch_sizes']))

3.验证阶段

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
netG.eval() #生成器验证模式
out_path = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/' #创建用于保存训练结果的目录
if not os.path.exists(out_path):
os.makedirs(out_path)

with torch.no_grad():
val_bar = tqdm(val_loader) #验证集进度条
valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
val_images = []
for val_lr, val_hr_restore, val_hr in val_bar: #遍历验证数据集(低分辨率图 恢复的高分辨率图 高分辨率图)
batch_size = val_lr.size(0)
valing_results['batch_sizes'] += batch_size
lr = val_lr
hr = val_hr
if torch.cuda.is_available():
lr = lr.float().cuda()
hr = hr.float().cuda()
sr = netG(lr) #生成超分辨率图像

batch_mse = ((sr - hr) ** 2).data.mean() #计算批量图像的均方误差,这里应该使用.mean()而不是.data.mean(),后者在PyTorch新版本中已不推荐
valing_results['mse'] += batch_mse * batch_size #累加均方误差
batch_ssim = pytorch_ssim.ssim(sr, hr).item() #计算批量图像的结构相似度指数
valing_results['ssims'] += batch_ssim * batch_size #累加结构相似度指数
#计算平均峰值信噪比
valing_results['psnr'] = 10 * log10((hr.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))
#计算平均结构相似度指数
valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
#更新训练进度条的描述信息
val_bar.set_description(
desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
valing_results['psnr'], valing_results['ssim']))

#将验证图像添加到列表中,用于后续保存
val_images.extend(
[display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
display_transform()(sr.data.cpu().squeeze(0))])
val_images = torch.stack(val_images) #将验证图像列表堆叠为张量
val_images = torch.chunk(val_images, val_images.size(0) // 15) #将堆叠后的张量分割为多个小块,每个小块包含15张图像
val_save_bar = tqdm(val_images, desc='[saving training results]') #创建保存图像进度条,并设置描述为“[saving training results]”
index = 1
#遍历图像批次并保存
for image in val_save_bar:
image = utils.make_grid(image, nrow=3, padding=5) #将小块中的图像创建为一个网格,每行显示3张图像,图像之间有5个像素的间隔
utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
index += 1

4.保存文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#将判别器和生成器的参数保存到指定文件
torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
# save loss\scores\psnr\ssim
results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
results['psnr'].append(valing_results['psnr'])
results['ssim'].append(valing_results['ssim'])

if epoch % 10 == 0 and epoch != 0:
out_path = 'statistics/'
# 创建一个DataFrame对象,用于存储训练结果数据
data_frame = pd.DataFrame(
data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
index=range(1, epoch + 1))
# 将DataFrame对象保存为CSV文件
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')

五.模型测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
from math import log10

import numpy as np
import pandas as pd
import torch
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

import pytorch_ssim
from data_utils import TestDatasetFromFolder, display_transform
from model import Generator

parser = argparse.ArgumentParser(description='Test Benchmark Datasets')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--model_name', default='netG_epoch_4_150.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
MODEL_NAME = opt.model_name

# 保存每个测试数据集的结果
results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []},
'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}}

# 创建一个 Generator 对象
model = Generator(UPSCALE_FACTOR).eval()
if torch.cuda.is_available():
model = model.cuda()
# 加载训练好的模型参数
model.load_state_dict(torch.load('epochs/' + MODEL_NAME))

# 加载测试数据集
test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR)
test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False)
# 创建一个用于 test_loader 的 tqdm 进度条
test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')

# 测试结果输出路径
out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
if not os.path.exists(out_path):
os.makedirs(out_path)

for image_name, lr_image, hr_restore_img, hr_image in test_bar:
# 由于 image_name 是一个包含单个元素的列表,所以将其取出
image_name = image_name[0]
# 将 lr_image 转换为 Variable 对象,并设置 volatile=True
# volatile=True 表示不会计算梯度,这在推理阶段通常是需要的
lr_image = Variable(lr_image, volatile=True)
hr_image = Variable(hr_image, volatile=True)
if torch.cuda.is_available():
lr_image = lr_image.cuda()
hr_image = hr_image.cuda()

# 生成超分变率图像
sr_image = model(lr_image)

mse = ((hr_image - sr_image) ** 2).data.mean()
# 计算峰值信噪比(Peak Signal-to-Noise Ratio)
psnr = 10 * log10(255 ** 2 / mse)
# 计算结构相似性指数(Structural Similarity Index)
# 使用 pytorch_ssim 库中的 ssim 函数计算 SSIM
ssim = pytorch_ssim.ssim(sr_image, hr_image).data[0]

# 创建一个包含三张图像的张量,分别是原始恢复的高分辨率图像、原始高分辨率图像和生成的超分辨率图像
# 将每张图像应用 display_transform() 转换,并通过 squeeze(0) 去除批次维度
test_images = torch.stack(
[display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.data.cpu().squeeze(0)),
display_transform()(sr_image.data.cpu().squeeze(0))])

# 使用 make_grid 函数将三张图像拼接成一张大图像
# nrow=3 表示每行显示 3 张图像,padding=5 表示图像之间的间距为 5
image = utils.make_grid(test_images, nrow=3, padding=5)

# 使用 save_image 函数将合成的图像保存到指定路径
utils.save_image(image, out_path + image_name.split('.')[0] + '_psnr_%.4f_ssim_%.4f.' % (psnr, ssim) +
image_name.split('.')[-1], padding=5)

# 将对应数据集的PSNR和SSIM保存到对应的字典当中
results[image_name.split('_')[0]]['psnr'].append(psnr)
results[image_name.split('_')[0]]['ssim'].append(ssim)

# 最终结果保存路径
out_path = 'statistics/'
saved_results = {'psnr': [], 'ssim': []}

# 遍历 results 字典中的每个值
for item in results.values():
# 获取 PSNR 和 SSIM 的列表
psnr = np.array(item['psnr'])
ssim = np.array(item['ssim'])

# 如果列表为空,将 PSNR 和 SSIM 设置为 'No data'
if (len(psnr) == 0) or (len(ssim) == 0):
psnr = 'No data'
ssim = 'No data'
else:
# 如果列表不为空,计算 PSNR 和 SSIM 的均值
psnr = psnr.mean()
ssim = ssim.mean()

# 将计算得到的 PSNR 和 SSIM 添加到 saved_results 字典的相应列表中
saved_results['psnr'].append(psnr)
saved_results['ssim'].append(ssim)

# 创建一个 DataFrame 对象,使用 saved_results 字典作为数据,以 results.keys() 作为列标签
data_frame = pd.DataFrame(saved_results, results.keys())
# 将 DataFrame 对象保存为 CSV 文件
# 文件路径由 out_path、'srf_'、UPSCALE_FACTOR 值和 '_test_results.csv' 组成
# index_label='DataSet' 表示使用 'DataSet' 作为索引标签
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet')


-------------本文结束-------------