超分辨率第六章-RCAN
RCAN发表于2018年,引入了注意力机制:Channel Attention (CA)
论文地址:Image Super-Resolution Using Very Deep Residual Channel Attention Networks
参考博客
代码位置:F:\Github下载\RCAN-pytorch-master
一.模型介绍
1.模型特点
(1) 使用了RIR结构
仅仅通过叠加残差块来构建更深的网络很难获得更好的提升效果
残差嵌套(residual in residual,RIR)结构构造非常深的可训练网络,RIR中的长跳连接和短跳连接有助于绕过大量的低频信息,使主网络学习到更有效的信息。
- LSC和SSC可以直接将低层得到的低频的信息直接跨层和高层的特征信息融合,并迫使网络集中学习残差信息获取更丰富的高层特征信息
2.引入了通道注意力机制(CA)
- 低分辨率图像(LR)的输入和特征包含大量的低频信息,这些信息在通道间被平等对待,从而阻碍了模型的表征能力。
- 因此RCAN引入CA (通道注意力) 机制来解决此问题
- 不同的通道可能对不同的特征有不同的贡献,有些通道可能包含更多的关键信息,而其他通道则可能包含噪声或冗余信息。
- 一般来说,通道注意力机制通过对每层特征图全局信息的学习来为每个通道赋予不同的权重,达到加强有用的特征,抑制无用特征的效果。
- 通过对不同通道施加不同的权重来提供不同重要程度的特征信息,因为不同通道的特征信息是有好有坏的,有的对图像的超分具有提升作用,而有的会损坏重建的质量,因此我们通过这个权重来让网络更加注重那些有用的特征信息,从而更好的提升表现力。
2.模型架构(G = 10 , B = 20 )
分为四个部分:浅层特征提取、残差嵌套(RIR)深度特征提取、上采样模块、重建部分
整个网络组成:Conv(浅层特征提取层) + RIR(深度特征提取层)+ 亚像素卷积层(上采样模块) + Conv(重建层)
- RIR(residual in residual,RIR):G个RG(带长跳连接)
- 每个RG(residual groups):B个RCAB组成(带短跳连接)
- 每个RCAB(Residual Channel Attention Block):Conv + ReLU + Conv + CA (带短跳连接)
- CA(Channel Attention):Global pooling(全局池化) +
(下采样,特征转化) + (上采样,权重计算)(带元素相乘的门控)
整体架构
RCAB架构
CA架构
3.CA(注意力机制)
- 产生不同通道注意力需要一个反应不同通道重要程度的常数权重$w$,并将权重$w$和通道特征信息相结合的机制。
- 全局池化
- 通过全局平均池化来产生产生该常数,即将不同通道的feature map映射成一个常数,如果这个特征图高频部分多,那个全局平均池化得来的值
$y_i$值就越大,不同大小的$y_i$与x相乘,就赋予了x不同的权重,即实现了对C个通道不同的注意($y_i$值越大,注意力越高) - 设输入feature map为:,其中是第c个输入通道数,C为输入通道总数。池化的结果为,即一个一维张量,长度为C,数学表达式为:
- 通过全局平均池化来产生产生该常数,即将不同通道的feature map映射成一个常数,如果这个特征图高频部分多,那个全局平均池化得来的值
- 特征转换与权重计算
- 为了限制模型复杂度和辅助泛化,通过在非线性周围形成两个卷积层的瓶颈来参数化门机制。然后经过sigmoid为每个通道学习特定采样的激活,控制每个通道的激励。
- $W_D$是将全局平均池化的结果进行通道降维,通过卷积的方式使得输出通道变为原来的$ \frac { 1 } { r }$,在通过ReLU激活函数保留非线性关系。
- $W_U$是进行通道升维,通过卷积的方式将输出通道变为输入通道数的r倍,这里r不是SR缩放因子,实验中取r=16.
- 经过sigmoid为每个通道学习特定采样的激活,控制每个通道的激励。
- δ()、f()分别表示ReLU激活函数和sigmoid门。
- 相乘
- 最后将激活后的非线性函数(c*1*1)按不同通道将权值和feature map进行相乘结合得到(c*h*w),输入特征与注意力权重相乘,得到重加权后的特征表示。这样,重要的通道会被放大,而不重要的通道则会减弱,从而更好地聚焦于重要的特征信息
- 全局池化
- 结构
二.数据集
- 使用DIV2K作为训练集,set5作为测试集
三.模型搭建
- 除了CA使用1×1卷积外,其余均使用3×3卷积核
- 除了CA使用、r(r=16)的通道以外,其余均使用C=64。
1 | class ChannelAttention(nn.Module): |
2.RCAB (残差通道注意块)
1 | class RCAB(nn.Module): |
3.RG (残差组)
1 | class RG(nn.Module): |
4.RCAN主结构
1 | class RCAN(nn.Module): |
四.模型训练
1 | # 将RCAN模型实例化并移动到指定的设备(CPU或GPU) |
五.模型测试
1 | # 1. 加载预训练的RCAN模型 |