超分辨率第一章-SRCNN

超分辨率第一章-SRCNN

第一个超分辨率模型-SRCNN (SISR),2014年提出

参考网址:【超分辨率】【深度学习】SRCNN pytorch代码(附详细注释和数据集)_srcnn代码-CSDN博客

SRCNN超分辨率Pytorch实现,代码逐行讲解,附源码

模型位置:F:\Github下载\SRCNN_Pytorch_1.0-master

一.模型介绍

1
2
#运行程序
python train.py --train-file=data_set/train_set/91-image_x3.h5 --eval-file=data_set/eval_set/Set5_x3.h5 --outputs-dir=outputs

SRCNN(2014年Dong等人提出,前端上采样框架 )

  • 先将图片下采样预处理得到低分辨率图像
  • 再利用双三次插值法将图片放大到目标分辨率(基于插值的上采样方法)
  • 再用卷积核大小分别为 9×9、1×1、5×5的三个卷积层,分别进行特征提取,拟合 LR-HR 图像对之间的非线性映射以及将网络模型的输出结果进行重建,得到最后的高分辨率图像
  • 图示:pAZzAbQ.png

二.数据集

以img-91作为训练集,Set5作为测试集。

三.模型搭建

1
2
3
4
5
6
7
8
9
10
11
12
13
class SRCNN(nn.Module):
def __init__(self, num_channels=1):
super(SRCNN, self).__init__()
self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
x = self.relu(self.conv1(x))
x = self.relu(self.conv2(x))
x = self.conv3(x)
return x

四.模型训练

1.调用库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import argparse #argparse用于编写用户友好的命令行接口。程序通过定义它期望从命令行接收的参数,然后 argparse 会自动从 sys.argv 解析出那些参数。这允许你的程序更加灵活和可配置
import os
import copy
import numpy as np
from torch import Tensor
import torch
from torch import nn
import torch.optim as optim
#cudnn 是 NVIDIA 提供的深度神经网络加速库(cuDNN)的 PyTorch 接口。它可以提高深度学习模型的计算速度和效率,特别是在使用 NVIDIA GPU 时。
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
#进度条
from tqdm import tqdm
from model import SRCNN
#从 datasets 模块中导入了 TrainDataset 和 EvalDataset 两个类。这两个类很可能分别用于加载训练数据集和评估数据集。
from datasets import TrainDataset, EvalDataset
# utils 模块中导入了 AverageMeter 和 calc_psnr 两个工具或函数。AverageMeter 可能是一个用于计算平均值的工具类,而 calc_psnr 函数则用于计算峰值信噪比(Peak Signal-to-Noise Ratio),这是一种常用的图像质量评估指标。
from utils import AverageMeter, calc_psnr

2.命令行参数设定

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#初始参数设定,argparse是Python标准库中的一个模块,用于编写用户友好的命令行接口。程序定义了它期望从命令行接收的参数,然后argparse会自动从sys.argv解析出那些参数。
parser = argparse.ArgumentParser() #parser = argparse.ArgumentParser() 创建了一个ArgumentParser对象。这个对象将包含将命令行解析成Python数据类型所需的全部信息。
#通过调用parser.add_argument()方法,可以向解析器添加命令行参数。每个add_argument()调用都指定了一个命令行选项(如--train-file),并可能包含一些额外的参数(如type=str,required=True等),这些参数定义了命令行选项应该如何被解析。
#--train-file, --eval-file, --outputs-dir:这些参数被标记为required=True,意味着它们在命令行中必须被提供。其他参数有默认值。
parser.add_argument('--train-file', type=str, required=True)
parser.add_argument('--eval-file', type=str, required=True)
parser.add_argument('--outputs-dir', type=str, required=True)
parser.add_argument('--scale', type=int, default=3) #图片上采样放大尺寸倍数
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--num-workers', type=int, default=0)
parser.add_argument('--num-epochs', type=int, default=400)
parser.add_argument('--seed', type=int, default=123)
args = parser.parse_args() #解析命令行参数,并将结果存储在名为args的命名空间中。之后,你可以通过args.参数名的方式来访问这些参数的值。

#保存输出到相应目录下
args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))
if not os.path.exists(args.outputs_dir):
os.makedirs(args.outputs_dir)

3.加载数据集并进行预处理

1
2
3
4
5
6
7
8
9
10
11
12
#预处理训练集
train_dataset = TrainDataset(args.train_file)
train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=True,
drop_last=True)
#预处理验证集
eval_dataset = EvalDataset(args.eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

4.设置训练参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
cudnn.benchmark = True #开启cudnn的benchmark模式,用于加速计算。但请注意,这可能会导致每次运行程序时,前馈计算的结果有细微差异,因为cudnn会寻找最优的卷积算法。
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed) #随机数种子
model = SRCNN().to(device)
criterion = nn.MSELoss() #代价函数MSE

#定义优化器为Adam,并为模型的不同部分设置不同的学习率。这里,conv3层的学习率是conv1和conv2层学习率的十分之一。
optimizer = optim.Adam([
{'params': model.conv1.parameters()},
{'params': model.conv2.parameters()},
{'params': model.conv3.parameters(), 'lr': args.lr*0.1}
], lr=args.lr)

# 在训练开始前,复制当前模型的最佳权重(这里初始化为当前模型的权重)。这些权重将在验证过程中根据性能进行更新。
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

5.模型训练与验证

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
for epoch in range(args.num_epochs):
#训练模式
model.train()
epoch_losses = AverageMeter() #初始化一个AverageMeter对象来跟踪当前epoch的损失平均值
#创建一个进度条,并设置进度条描述(当前轮/总批次)
with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
t.set_description('epoch:{}/{}'.format(epoch, args.num_epochs))
for data in train_dataloader:
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
preds = model(inputs)
loss = criterion(preds, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

epoch_losses.update(loss.item(), len(inputs)) #更新当前的平均损失,并在进度条上显示
t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
t.update(len(inputs))

#验证模式
model.eval()
epoch_psnr = AverageMeter() #初始化一个AverageMeter对象来跟踪当前epoch的psnr平均值

for data in eval_dataloader:
inputs, labels = data
inputs = inputs.to(device)
labels = labels.to(device)
with torch.no_grad():
preds = model(inputs).clamp(0.0, 1.0) #对预测值进行裁剪
epoch_psnr.update(calc_psnr(preds, labels), len(inputs)) #计算PSNR值并更新当前epoch的PSNR平均值。calc_psnr函数用于计算预测值和真实值之间的PSNR。
print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

#保存最优epoch的权重文件
if epoch_psnr.avg > best_psnr:
best_epoch = epoch
best_psnr = epoch_psnr.avg
best_weights = copy.deepcopy(model.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

五.模型测试

1
2
#运行程序
python test.py --weights-file=outputs/x3/best.pth --image-file=data/car.bmp

1.命令行参数设定

1
2
3
4
5
parser = argparse.ArgumentParser()
parser.add_argument('--weights-file', type=str, required=True)
parser.add_argument('--image-file', type=str, required=True)
parser.add_argument('--scale', type=int, default=3)
args = parser.parse_args()

2.加载预训练权重

1
2
3
4
5
6
7
8
9
10
11
cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SRCNN().to(device)

state_dict = model.state_dict() #加载模型的参数状态
#加载预训练权重,并映射到GPU
for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
if n in state_dict.keys():# 如果预训练权重中的键在模型的参数状态中存在
state_dict[n].copy_(p)# 则用预训练权重替换模型参数
else:
raise KeyError(n)

3.双三次插值(BICUBIC)调整图片尺寸

1
2
3
4
5
6
7
8
9
10
11
12
13
model.eval()
image = pil_image.open(args.image_file).convert('RGB')# 加载图像,并转换为RGB格式
# 根据args.scale调整图像尺寸,使其为scale的整数倍
image_width = (image.width // args.scale) * args.scale
image_height = (image.height // args.scale) * args.scale
# 调整至scale的倍数
image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
# 将原始图片进行下采样,缩小到原始尺寸的1/scale,得到低分辨率图片
image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
# 将低分辨率图片进行上采样,放大到规定的尺寸
image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
# 保存处理后的图像
image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

4.调整色彩空间并进行超分辨率重建

  • 将调整后的图像从RGB色彩空间转换到YCbCr色彩空间,并仅对Y分量进行超分辨率重建(SRCNN等模型通常只处理亮度分量)
  • YCbCr是一种色彩空间,其中Y代表亮度分量(Luminance),Cb和Cr代表蓝色和红色的色度分量(Chrominance)。YCbCr色彩空间是YUV色彩空间的一种变种,广泛应用于视频压缩和图像处理中。
1
2
3
4
5
6
7
image = np.array(image).astype(np.float32)# 将图像转换为numpy数组,并转换为float32类型 
ycbcr = convert_rgb_to_ycbcr(image)# 将RGB图像转换为YCbCr色彩空间

y = ycbcr[..., 0] #提取Y分量
y /= 255. #归一化Y分量到[0, 1]
y = torch.from_numpy(y).to(device) #将numpy数组转换为torch张量,并移动到指定设备(如GPU)
y = y.unsqueeze(0).unsqueeze(0) #增加一个批次维度和一个通道维度

5.使用模型进行重建

1
2
3
4
5
6
7
8
9
10
11
12
13
14
with torch.no_grad():# 关闭梯度计算,进行前向传播
preds = model(y).clamp(0.0, 1.0) #使用模型进行预测,并限制输出值在[0, 1]之间

psnr = calc_psnr(y, preds)
print('PSNR: {:.2f}'.format(psnr))

preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)#将预测结果转换回uint8类型,并去除批次和通道维度

# 将预测的Y分量与原始的Cb、Cr分量合并,然后转换回RGB色彩空间
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)

output = pil_image.fromarray(output) #将numpy数组转换回PIL图像
output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale))) #保存最终的图像

6.结果

pAn4tSJ.png

六.以单幅低分辨率图像实现超分辨率

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#将单幅低分辨率图像以原尺寸规模进行超分辨率处理
import argparse
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image
from model import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--weights-file', type=str, required=True)
parser.add_argument('--image-file', type=str, required=True)
parser.add_argument('--scale', type=int, default=3)
args = parser.parse_args()

cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SRCNN().to(device)

state_dict = model.state_dict()
for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
if n in state_dict.keys():
state_dict[n].copy_(p)
else:
raise KeyError(n)

model.eval()
image = pil_image.open(args.image_file).convert('RGB')
image_width = (image.width // args.scale) * args.scale
image_height = (image.height // args.scale) * args.scale
image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
#不进行下采样得到低分辨率图像的操作,从而也不需要上采样恢复尺寸
# image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
# image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)
image.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

image = np.array(image).astype(np.float32)
ycbcr = convert_rgb_to_ycbcr(image)

y = ycbcr[..., 0]
y /= 255.
y = torch.from_numpy(y).to(device)
y = y.unsqueeze(0).unsqueeze(0)

with torch.no_grad():
preds = model(y).clamp(0.0, 1.0)

# psnr = calc_psnr(y, preds)
# print('PSNR: {:.2f}'.format(psnr))

preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
output = pil_image.fromarray(output)
output.save(args.image_file.replace('.', '_srcnn_x{}.'.format(args.scale)))

#python make.py --weights-file=outputs/x3/best.pth --image-file=data/test.bmp
  • 个人评价:该模型实现效果很差。若不提供原始高分辨率图像,几乎不能将低分辨率图像变为高分辨率图像。
-------------本文结束-------------