超分辨率第五章-SRDenseNet

超分辨率第五章-SRDenseNet

SRDenseNet发表于2017年

论文:Super-Resolution Using Dense Skip Connections

参考博客:超分之SRDenseNet

代码位置:F:\Github下载\SRDenseNet-pytorch-master

一.模型介绍

1.模型特点

  • 将稠密块作为SRDenseNet的基本结构。
  • Skip connection将各个level的特征直接与图像重建输入端相连,此外Dense skip connection可以缓解网络深度带来的梯度消失问题。

2.Dense块

  • Dense块各个层之间的skip connection是通过concat在一起的,而Resnet块是通过求和加在一起的。这样的好处在于可以缓解梯度消失的问题以及加强了信息在各个layer之间的流动

  • 整个块分为8个layers,第 i 个layer产生16 × i 张feature map,最后一层产生128张特征图作为块的输出,一共有八个Dense块。

    pAQ6AAS.png


  • pAQ6F78.png

3.模型架构

  • 前向传播过程
    • 首先,输入的低分辨率图像经过第一层CNN提取低层的特征信息。
    • 然后经过8个Dense块提取高层特征信息,通过skip connection的方式将各个level的特征信息相连,滤波器的大小统一设置成3×3。
    • 再经过一层Bottleneck layer,降低前面特征图连接导致图像张数(通道数)太多而带来的高计算复杂度问题,通过1 × 1 卷积层进行缩减。
    • 之后经过2个反卷积子网络,进行上采样。
    • 最后使用3×3卷积核以及输出通道为1的卷积层进行重建。
  • 图示
    • pAQrhHU.png

二.数据集

训练集使用coco2017,验证集使用set5

三.模型架构

1.定义特征提取层

1
2
3
4
5
6
7
8
class ConvLayer(nn.Module):
    """Convolution followed by an in-place ReLU.

    Padding is set to kernel_size // 2, so for odd kernel sizes the spatial
    dimensions of the input are preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(ConvLayer, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size,
                              padding=kernel_size // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Convolve first, then apply the ReLU activation.
        features = self.conv(x)
        return self.relu(features)

2.定义dense块结构

  • 定义一个dense块中单独的层,每个层均为3×3卷积核、输出16张特征图的结构
1
2
3
4
5
6
7
8
class DenseLayer(nn.Module):
    """Single layer of a dense block: conv + ReLU whose output is
    concatenated with the layer input along the channel dimension.

    Concatenation (rather than the summation used by ResNet blocks) is
    what gives the dense connectivity pattern.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(DenseLayer, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size,
                              padding=kernel_size // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Stack the newly produced feature maps on top of the incoming ones.
        new_features = self.relu(self.conv(x))
        return torch.cat((x, new_features), dim=1)
  • 定义dense块
1
2
3
4
5
6
7
8
9
class DenseBlock(nn.Module):
    """Dense block: one ConvLayer followed by (num_layers - 1) DenseLayers.

    Because every DenseLayer concatenates its own growth_rate feature maps
    onto its input, layer i receives growth_rate * i channels and the block
    body emits growth_rate * num_layers channels. forward() additionally
    concatenates the block input itself (skip connection).
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        # BUG FIX: the original omitted this call; without it, assigning the
        # nn.Sequential below raises
        # "AttributeError: cannot assign module before Module.__init__() call".
        super(DenseBlock, self).__init__()
        self.block = [ConvLayer(in_channels, growth_rate, kernel_size=3)]
        for i in range(num_layers - 1):
            # Channel count accumulates: layer i+1 sees all growth_rate * (i+1)
            # maps produced so far inside this block.
            self.block.append(DenseLayer(growth_rate * (i + 1), growth_rate, kernel_size=3))
        self.block = nn.Sequential(*self.block)  # register all layers as submodules

    def forward(self, x):
        # Skip connection: pass the block input through alongside the new features.
        return torch.cat([x, self.block(x)], 1)

3.整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class SRDenseNet(nn.Module):
    """SRDenseNet: super-resolution with dense blocks and dense skip connections.

    Pipeline: low-level 3x3 conv -> num_blocks dense blocks (each concatenating
    its input) -> 1x1 bottleneck conv -> two stride-2 transposed convs
    (upsampling) -> 3x3 reconstruction conv back to num_channels.
    """

    def __init__(self, num_channels=1, growth_rate=16, num_blocks=8, num_layers=8):
        super(SRDenseNet, self).__init__()

        # Low-level feature extraction: growth_rate * num_layers output maps
        # (16 * 8 = 128 with the defaults).
        self.conv = ConvLayer(num_channels, growth_rate * num_layers, 3)

        # High-level features: because every dense block concatenates its own
        # input, block i receives growth_rate * num_layers * (i + 1) channels.
        self.dense_blocks = nn.Sequential(
            *[DenseBlock(growth_rate * num_layers * (i + 1), growth_rate, num_layers)
              for i in range(num_blocks)]
        )

        # 1x1 bottleneck squeezes the concatenation of the first conv's output
        # plus all dense-block outputs down to 256 channels, cutting the cost
        # of the following layers.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(growth_rate * num_layers + growth_rate * num_layers * num_blocks,
                      256, kernel_size=1),
            nn.ReLU(inplace=True)
        )

        # Two stride-2 transposed convolutions perform the upsampling.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2, padding=3 // 2, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2, padding=3 // 2, output_padding=1),
            nn.ReLU(inplace=True)
        )

        # Final reconstruction layer maps back to the image channel count.
        self.reconstruction = nn.Conv2d(256, num_channels, kernel_size=3, padding=3 // 2)

        self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming-normal initialization for every (transposed) convolution,
        # with biases zeroed.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight.data, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias.data)

    def forward(self, x):
        features = self.conv(x)
        features = self.dense_blocks(features)
        features = self.bottleneck(features)
        upsampled = self.deconv(features)
        return self.reconstruction(upsampled)

四.训练模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Build the model and move it to the selected device.
model = SRDenseNet(growth_rate=args.growth_rate, num_blocks=args.num_blocks, num_layers=args.num_layers).to(device)

if args.weights_file is not None:  # if a weights file was supplied, load the pretrained parameters
    # copy_ writes in-place into the model's own tensors, so no explicit
    # load_state_dict call is needed afterwards.
    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

train_dataset = TrainDataset(args.train_file, patch_size=args.patch_size, scale=args.scale)
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
eval_dataset = EvalDataset(args.eval_file)
eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

# Track the best checkpoint seen so far (by eval PSNR).
best_weights = copy.deepcopy(model.state_dict())
best_epoch = 0
best_psnr = 0.0

for epoch in range(args.num_epochs):
    # Step decay: the learning rate drops by 10x once 80% of the epochs have run.
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr * (0.1 ** (epoch // int(args.num_epochs * 0.8)))

    model.train()
    epoch_losses = AverageMeter()

    # Progress bar counts only full batches (the trailing partial batch is excluded).
    with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:
        t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

        for data in train_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = model(inputs)

            loss = criterion(preds, labels)

            epoch_losses.update(loss.item(), len(inputs))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
            t.update(len(inputs))

    # Save a checkpoint every epoch.
    torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

    model.eval()
    epoch_psnr = AverageMeter()

    for data in eval_dataloader:
        inputs, labels = data

        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            preds = model(inputs).clamp(0.0, 1.0)

        # Crop a scale-pixel border from prediction and label before PSNR,
        # so boundary artifacts do not affect the metric.
        preds = preds[:, :, args.scale:-args.scale, args.scale:-args.scale]
        labels = labels[:, :, args.scale:-args.scale, args.scale:-args.scale]

        epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

    print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

    # Keep the weights of the best-performing epoch.
    if epoch_psnr.avg > best_psnr:
        best_epoch = epoch
        best_psnr = epoch_psnr.avg
        best_weights = copy.deepcopy(model.state_dict())

print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

五.测试模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
cudnn.benchmark = True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = SRDenseNet().to(device)

# Load pretrained weights by copying each parameter in-place, name by name.
state_dict = model.state_dict()
for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
    if n in state_dict.keys():
        state_dict[n].copy_(p)
    else:
        raise KeyError(n)

model.eval()

image = pil_image.open(args.image_file).convert('RGB')

# Crop the image dimensions to an exact multiple of the scale factor.
image_width = (image.width // args.scale) * args.scale
image_height = (image.height // args.scale) * args.scale

# hr: ground-truth reference; lr: downscaled model input;
# bicubic: plain bicubic upscaling used as baseline and as Cb/Cr source.
hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))

lr, _ = preprocess(lr, device)  # extract the Y (luminance) channel for the model
hr, _ = preprocess(hr, device)
_, ycbcr = preprocess(bicubic, device)  # keep bicubic YCbCr to restore color later

with torch.no_grad():
    preds = model(lr).clamp(0.0, 1.0)

psnr = calc_psnr(hr, preds)
print('PSNR: {:.2f}'.format(psnr))

# Back to a 2-D uint8-range array: drop batch and channel dims.
preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

# Recombine: predicted Y + bicubic Cb/Cr, then convert YCbCr -> RGB.
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
output = pil_image.fromarray(output)
output.save(args.image_file.replace('.', '_srdensenet_x{}.'.format(args.scale)))
-------------本文结束-------------