超分辨率第六章-RCAN

超分辨率第六章-RCAN

RCAN发表于2018年,引入了注意力机制:Channel Attention (CA)

论文地址:Image Super-Resolution Using Very Deep Residual Channel Attention Networks

参考博客

代码位置:F:\Github下载\RCAN-pytorch-master

一.模型介绍

1.模型特点

(1) 使用了RIR结构

  • 仅仅通过叠加残差块来构建更深的网络很难获得更好的提升效果

  • 残差嵌套(residual in residual,RIR)结构构造非常深的可训练网络,RIR中的长跳连接和短跳连接有助于绕过大量的低频信息,使主网络学习到更有效的信息。

  • LSC和SSC可以直接将低层得到的低频的信息直接跨层和高层的特征信息融合,并迫使网络集中学习残差信息获取更丰富的高层特征信息

2.引入了通道注意力机制(CA)

  • 低分辨率图像(LR)的输入和特征包含大量的低频信息,这些信息在通道间被平等对待,从而阻碍了模型的表征能力。
  • 因此RCAN引入CA (通道注意力) 机制来解决此问题
    • 不同的通道可能对不同的特征有不同的贡献,有些通道可能包含更多的关键信息,而其他通道则可能包含噪声或冗余信息。
    • 一般来说,通道注意力机制通过对每层特征图全局信息的学习来为每个通道赋予不同的权重,达到加强有用的特征,抑制无用特征的效果
    • 通过对不同通道施加不同的权重来提供不同重要程度的特征信息,因为不同通道的特征信息是有好有坏的,有的对图像的超分具有提升作用,而有的会损坏重建的质量,因此我们通过这个权重来让网络更加注重那些有用的特征信息,从而更好的提升表现力。

2.模型架构(G = 10 , B = 20 )

  • 分为四个部分:浅层特征提取、残差嵌套(RIR)深度特征提取、上采样模块、重建部分

  • 整个网络组成:Conv(浅层特征提取层) + RIR(深度特征提取层)+ 亚像素卷积层(上采样模块) + Conv(重建层)

    • RIR(residual in residual,RIR):G个RG(带长跳连接)
  • 每个RG(residual groups):B个RCAB组成(带短跳连接)
  • 每个RCAB(Residual Channel Attention Block):Conv + ReLU + Conv + CA (带短跳连接)
  • CA(Channel Attention):Global pooling(全局池化) + (下采样,特征转化) + (上采样,权重计算)(带元素相乘的门控)
  • 整体架构

    • pAlfFd1.png
  • RCAB架构

    • pAlfQeA.png
  • CA架构

    • pAlfsYV.png
    • pAlhStP.md.png

3.CA(注意力机制)

  • 产生不同通道注意力需要一个反应不同通道重要程度的常数权重$w$,并将权重$w$和通道特征信息相结合的机制。
    • 全局池化
      • 通过全局平均池化来产生产生该常数,即将不同通道的feature map映射成一个常数,如果这个特征图高频部分多,那个全局平均池化得来的值
        $y_i$值就越大,不同大小的$y_i$与x相乘,就赋予了x不同的权重,即实现了对C个通道不同的注意($y_i$值越大,注意力越高)
      • 设输入feature map为:,其中是第c个输入通道数,C为输入通道总数。池化的结果为,即一个一维张量,长度为C,数学表达式为:
    • 特征转换与权重计算
      • 为了限制模型复杂度和辅助泛化,通过在非线性周围形成两个卷积层的瓶颈来参数化门机制。然后经过sigmoid为每个通道学习特定采样的激活,控制每个通道的激励。
      • $W_D$是将全局平均池化的结果进行通道降维,通过卷积的方式使得输出通道变为原来的$ \frac { 1 } { r }$,在通过ReLU激活函数保留非线性关系。
      • $W_U$是进行通道升维,通过卷积的方式将输出通道变为输入通道数的r倍,这里r不是SR缩放因子,实验中取r=16.
      • 经过sigmoid为每个通道学习特定采样的激活,控制每个通道的激励
        • δ()、f()分别表示ReLU激活函数和sigmoid门。
    • 相乘
      • 最后将激活后的非线性函数(c*1*1)按不同通道将权值和feature map进行相乘结合得到(c*h*w),输入特征与注意力权重相乘,得到重加权后的特征表示。这样,重要的通道会被放大,而不重要的通道则会减弱,从而更好地聚焦于重要的特征信息
  • 结构
    • pAlf4T1.png

二.数据集

  • 使用DIV2K作为训练集,set5作为测试集

三.模型搭建

  • 除了CA使用1×1卷积外,其余均使用3×3卷积核
  • 除了CA使用、r(r=16)的通道以外,其余均使用C=64。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class ChannelAttention(nn.Module):
def __init__(self, num_features, reduction):
super(ChannelAttention, self).__init__()
self.module = nn.Sequential(
nn.AdaptiveAvgPool2d(1), #将输入特征图进行自适应平均池化到 C*1*1 大小,用于提取全局信息
nn.Conv2d(num_features, num_features // reduction, kernel_size=1), #减少通道数
nn.ReLU(inplace=True),
nn.Conv2d(num_features // reduction, num_features, kernel_size=1), #复原通道数
nn.Sigmoid() #门控激活,得到注意力权重
)

def forward(self, x):
return x * self.module(x) #输入特征与注意力权重相乘,得到重加权后的特征表示。
#这样,重要的通道会被放大,而不重要的通道则会减弱,从而更好地聚焦于重要的特征信息

2.RCAB (残差通道注意块)

1
2
3
4
5
6
7
8
9
10
11
12
class RCAB(nn.Module):
def __init__(self, num_features, reduction):
super(RCAB, self).__init__()
self.module = nn.Sequential(
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
ChannelAttention(num_features, reduction)
)

def forward(self, x):
return x + self.module(x) #通过短跳连接实现残差学习

3.RG (残差组)

1
2
3
4
5
6
7
8
9
class RG(nn.Module):
def __init__(self, num_features, num_rcab, reduction):
super(RG, self).__init__()
self.module = [RCAB(num_features, reduction) for _ in range(num_rcab)] #一个残差组中共有20个残差块
self.module.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)) #在最后追加一个卷积层
self.module = nn.Sequential(*self.module)

def forward(self, x):
return x + self.module(x) #通过短跳连接实现残差学习

4.RCAN主结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class RCAN(nn.Module):
def __init__(self, args):
super(RCAN, self).__init__()
scale = args.scale #放大倍数
num_features = args.num_features #卷积核大小为3
num_rg = args.num_rg #残差组数量为10
num_rcab = args.num_rcab #每个残差组所含残差块数量为20
reduction = args.reduction #缩放维度倍数为16

self.sf = nn.Conv2d(3, num_features, kernel_size=3, padding=1) #低层特征提取层
self.rgs = nn.Sequential(*[RG(num_features, num_rcab, reduction) for _ in range(num_rg)]) #遍历10个残差组
self.conv1 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1) #追加一个卷积层
self.upscale = nn.Sequential( #通过上采样和亚像素卷积层实现上采样
nn.Conv2d(num_features, num_features * (scale ** 2), kernel_size=3, padding=1),
nn.PixelShuffle(scale)
)
self.conv2 = nn.Conv2d(num_features, 3, kernel_size=3, padding=1) #重建层

def forward(self, x):
x = self.sf(x)
residual = x
x = self.rgs(x)
x = self.conv1(x)
x += residual #通过长跳连接实现残差学习
x = self.upscale(x)
x = self.conv2(x)
return x

四.模型训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# 将RCAN模型实例化并移动到指定的设备(CPU或GPU)  
model = RCAN(opt).to(device)

# 定义损失函数为L1损失,常用于图像重建任务
criterion = nn.L1Loss()

# 使用Adam优化器来优化模型参数,学习率通过opt.lr获取
optimizer = optim.Adam(model.parameters(), lr=opt.lr)

# 加载数据集,参数从opt中获取
dataset = Dataset(opt.images_dir, opt.patch_size, opt.scale, opt.use_fast_loader)

# 使用DataLoader来封装数据集,以便于批量处理和并行加载
dataloader = DataLoader(dataset=dataset,
batch_size=opt.batch_size,
shuffle=True, # 在每个epoch开始时打乱数据
num_workers=opt.threads, # 使用多个进程来加载数据
pin_memory=True, # 如果在GPU上,则锁定内存页,减少CPU到GPU的传输时间
drop_last=True) # 如果最后一个batch小于batch_size,则丢弃

# 遍历指定的训练轮次
for epoch in range(opt.num_epochs):
epoch_losses = AverageMeter() # 使用AverageMeter来记录平均损失

# 使用tqdm来显示进度条
with tqdm(total=(len(dataset) - len(dataset) % opt.batch_size)) as _tqdm:
_tqdm.set_description('epoch: {}/{}'.format(epoch + 1, opt.num_epochs))

# 遍历数据加载器中的每个batch
for data in dataloader:
inputs, labels = data # 解包每个batch的数据和标签
inputs = inputs.to(device) # 将数据和标签移动到指定设备
labels = labels.to(device)

# 前向传播
preds = model(inputs)

# 计算损失
loss = criterion(preds, labels)

# 更新损失的平均值
epoch_losses.update(loss.item(), len(inputs))

# 反向传播和优化
optimizer.zero_grad() # 清除过往梯度
loss.backward() # 反向传播,计算当前梯度
optimizer.step() # 根据梯度更新网络参数

# 更新进度条上的信息
_tqdm.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
_tqdm.update(len(inputs))

# 每个epoch结束后,保存模型参数
torch.save(model.state_dict(), os.path.join(opt.outputs_dir, '{}_epoch_{}.pth'.format(opt.arch, epoch)))

五.模型测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# 1. 加载预训练的RCAN模型  
model = RCAN(opt) # 使用配置选项opt初始化RCAN模型

# 2. 加载模型的权重
state_dict = model.state_dict() # 获取模型当前的权重字典
# 加载预训练权重,并仅更新模型中存在的权重
for n, p in torch.load(opt.weights_path, map_location=lambda storage, loc: storage).items():
if n in state_dict.keys():
state_dict[n].copy_(p)
else:
raise KeyError(n) # 如果预训练权重中的某个键在模型权重字典中不存在,则抛出错误

# 3. 将模型移动到指定的设备上(如CPU或GPU)
model = model.to(device)
# 设置模型为评估模式
model.eval()

# 4. 处理输入图像
filename = os.path.basename(opt.image_path).split('.')[0] # 从文件路径中提取文件名(不含扩展名)

# 使用PIL库打开并转换图像为RGB格式
input = pil_image.open(opt.image_path).convert('RGB')

# 使用双三次插值降采样图像,作为模型的低分辨率输入
lr = input.resize((input.width // opt.scale, input.height // opt.scale), pil_image.BICUBIC)

# 使用双三次插值将降采样后的图像上采样回原始尺寸,作为双三次插值的结果
bicubic = lr.resize((input.width, input.height), pil_image.BICUBIC)
bicubic.save(os.path.join(opt.outputs_dir, '{}_x{}_bicubic.png'.format(filename, opt.scale))) # 保存双三次插值结果

# 5. 准备模型输入
# 将低分辨率图像转换为Tensor,并添加批次维度,然后移动到指定设备上
input = transforms.ToTensor()(lr).unsqueeze(0).to(device)

# 6. 使用模型进行超分辨率重建
with torch.no_grad(): # 不计算梯度,以节省内存和计算资源
pred = model(input) # 进行前向传播,得到超分辨率图像

# 7. 处理模型输出
# 将输出Tensor转换为numpy数组,并进行必要的转换以匹配PIL图像的格式
output = pred.mul_(255.0).clamp_(0.0, 255.0).squeeze(0).permute(1, 2, 0).byte().cpu().numpy()
# 将numpy数组转换为PIL图像
output = pil_image.fromarray(output, mode='RGB')
# 保存超分辨率图像
output.save(os.path.join(opt.outputs_dir, '{}_x{}_{}.png'.format(filename, opt.scale, opt.arch)))
-------------本文结束-------------