超分辨率第七章-SRFBN

超分辨率第七章-SRFBN

SRFBN发表于2019年,引入了反馈网络机制,不会增加额外的参数,并且多次回传相当于加深了网络。

论文地址:Feedback Network for Image Super-Resoluition

MRI论文:A trusted medical image super-resolution method based on feedback adaptive weighted dense network

参考博客:【CVPR2019】超分辨率文章,SRFBN: Feedback Network for Image Super-Resoluition

代码位置:F:\Github下载\SRFBN_CVPR19-master

一.模型介绍

之前基于RNN的模型,如DRCN、DRRN使用前馈方式,先前的层无法从之后的层得到有效信息。

由于跳跃连接是从浅层到深层的路径,因此底层特征只能接受先前层的信息,感受野较小,影响了重建的质量

1.前置模型:DRCN

  • DRCN是一种基于递归结构,使用递归监督和跳跃连接的SISR深度网络模型

  • 递归监督(每次递归后都与重建层直接连接),用每次递归的输出直接与重建层相连,最终得到D个重建SR图像,再对这D幅SR图像进行加权求和。

    • 对于每一次递归,都输出到重建层,作为总和Loss的一部分,也就是说每一次递归都通过监督学习来学习参数,故称之为递归监督。对于D次的递归,最终的Loss由D个小loss通过加权平均得到,当反向传播的时候,每一次递归都会获取属于自己的那一部分梯度,这样就算来自深层递归的梯度消失了,自己的那份也可以用来训练更新参数。
  • 跳层连接(共享低频信息):将原始图像直接连接至重建层,保留低分辨率图像原始特征。

    • pAa20Xt.png

2.关于RNN、LSTM

  • RNN:参考博客

    • pAasUEQ.png

    • s是隐藏层的值,W是每个时间点之间的权重矩阵,我们注意到,RNN之所以可以解决序列问题,是因为它可以记住每一时刻的信息(即隐藏层的值s)。

    • 每一时刻的隐藏层不仅由该时刻的输入层决定,还由上一时刻的隐藏层决定,公式如下,其中 $O_t$ 代表$t$时刻的输出, $S_t$ 代表$t$时刻的隐藏层的值。

    • $ S _ { t } = f ( U . X _ { t } + W . S _ { t - 1 } ) $

    • $ O _ { t } = g ( V . S _ { t } ) $

    • 在整个训练过程中,每一时刻所用的都是同样的W,每一时刻的输出结果都与上一时刻的输入有着非常大的关系,如果我们将输入序列换个顺序,那么我们得到的结果也将是截然不同,这就是RNN的特性,可以处理序列数据,同时对序列也很敏感。

    • 代码

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      # 超参数  
      input_size = 10 # 输入特征维度
      hidden_size = 20 # 隐藏层神经元数量
      num_layers = 2 # RNN隐藏层数
      num_classes = 3 # 输出类别数
      seq_len = 5 # 序列长度
      batch_size = 32 # 批次大小
      num_epochs = 10 # 训练轮数
      learning_rate = 0.001 # 学习率

      x_train = torch.randn(100, seq_len, input_size) # 训练数据
      y_train = torch.randint(0, num_classes, (100,)) # 类别标签

      # 创建数据加载器
      train_dataset = TensorDataset(x_train, y_train)
      train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

      # 定义RNN模型
      class RNNClassifier(nn.Module):
      def __init__(self, input_size, hidden_size, num_layers, num_classes):
      super(RNNClassifier, self).__init__()
      self.hidden_size = hidden_size
      self.num_layers = num_layers
      self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
      self.fc = nn.Linear(hidden_size, num_classes)

      def forward(self, x):
      # 初始化隐藏层状态
      h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

      # 前向传播RNN
      out, _ = self.rnn(x, h0)

      # 取最后一个时间步的输出
      out = out[:, -1, :] #从out中提取每个样本的最后一个时间步的所有特征,从而得到一个形状为 (batch_size, features) 的二维张量。

      # 通过全连接层得到最终的输出
      out = self.fc(out)
      return out

      # 实例化模型、损失函数和优化器
      model = RNNClassifier(input_size, hidden_size, num_layers, num_classes).to('cuda' if torch.cuda.is_available() else 'cpu')
      criterion = nn.CrossEntropyLoss()
      optimizer = optim.Adam(model.parameters(), lr=learning_rate)

      # 训练循环
      for epoch in range(num_epochs):
      for batch_x, batch_y in train_loader:
      batch_x, batch_y = batch_x.to('cuda' if torch.cuda.is_available() else 'cpu'), batch_y.to('cuda' if torch.cuda.is_available() else 'cpu')

      # 前向传播
      outputs = model(batch_x)
      loss = criterion(outputs, batch_y)

      # 反向传播和优化
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

      print("训练完成!")
  • LSTM(Long short-term memory,长短期记忆):参考博客

    • LSTM主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题

    • 普通RNN存储所有序列信息,LSTM通过门控装置选择性的存储信息(门是用来控制每一时刻信息的记忆与遗忘)

    • 相比于原始的RNN的隐层(hidden state), LSTM增加了一个状态Ct;sigmoid激活函数会将值变为【0,1】,tanh函数将值变为【-1,1】之间。

    • 整体结构

      • pAagZRA.png
    • 单独结构

      • pAagkIe.md.png
    • 门控系统

      • 遗忘门:
        • $ f _ { t } = \sigma ( W _ { f } \cdot \left[ h _ { t - 1 } , x _ { t } \right] + b _ { f } )$
        • 遗忘门作用于LSTM的状态向量C上,用于控制上一个时间戳的记忆对当前时间戳的影响。它通过对之前的记忆状态进行加权选择来控制遗忘程度,这个加权选择的过程是通过一个线性变换后经过sigmoid激活函数来实现的,因此遗忘门的输出是一个介于0和1之间的数值,表示遗忘的程度。
        • 遗忘门决定了记忆单元中哪些信息应该被遗忘。当遗忘门的输出为0时,表示完全遗忘之前的信息;当输出为1时,表示完全保留之前的信息;当输出为0到1之间的数值时,表示部分遗忘之前的信息。这有助于LSTM在处理序列数据时,选择性地遗忘那些不重要或冗余的信息。
      • 输入门:
        • $ i _ { t } = \sigma \left( W _ { i } \cdot \left[ h _ { t - 1 } , x _ { t } \right] + b _ { i } \right)$
        • $ \tilde { C } _ { t } = \tan h ( W _ { C } \cdot \left[ h _ { t - 1 } , x _ { t } \right] + b _ { C } )$
        • $ C _ { t } = f _ { t } \ast C _ { t - 1 } + i _ { t } \ast \tilde { C } _ { t }$
        • 输入门用于控制新的输入信息对记忆单元的影响。它同样是通过一个线性变换后经过激活函数(sigmoid和tanh)来实现的。sigmoid函数的输出决定了哪些新信息应该被加入到记忆中,而tanh函数的输出则是对新信息的候选值进行缩放。
        • 输入门的作用是决定多少新信息应该进入记忆单元。当输入门的输出为0时,表示不接受新的输入信息;当输出为1时,表示完全接受新的输入信息。输入门则更侧重于控制当前时刻的输入和前一时刻的隐藏状态如何结合来更新记忆单元的状态
      • 输出门:
        • $ o _ { t } = \sigma ( W _ { o } \left[ h _ { t - 1 } , x _ { t } \right] + b _ { o } )$
        • $ h _ { t } = o _ { t } \ast \tan h ( C _ { t } )$
        • 输出门控制着从记忆单元中输出信息的数量。它的输入包括当前时刻的输入、上一个时刻的隐藏状态以及当前时刻的记忆单元状态,输出则是一个介于0和1之间的数值,表示输出的信息量。这个输出过程同样是通过一个线性变换后经过sigmoid激活函数来实现的。
        • 输出门的作用是控制哪些记忆状态应该被输出到网络的其他部分。当输出门的输出为0时,表示不输出任何信息;当输出为1时,表示完全输出当前时刻的记忆单元状态。这有助于LSTM在不同情境下灵活地应用记忆,从而产生适当的输出。
    • 代码

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      class LSTM(nn.Module):  
      def __init__(self, input_size=1, hidden_size=50, num_layers=1):
      super(LSTM, self).__init__()
      self.hidden_size = hidden_size
      self.num_layers = num_layers
      self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
      self.fc = nn.Linear(hidden_size, 1)

      def forward(self, x):
      h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
      c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

      out, _ = self.lstm(x, (h0, c0))
      out = self.fc(out[:, -1, :])
      return out

3.模型特点

  • 采取了具有反馈块(FB)的RNN结构,反馈块由多组上采样层和下采样层构成,具有密集的跳跃连接,通过上采样和下采样,网络能够在不同的尺度上捕捉图像的特征。上采样有助于恢复图像的细节信息,而下采样则有助于提取更高层次的语义信息。这种多尺度处理可以帮助网络更好地理解图像内容

  • 反馈块(FB)接收输入$F_{IN}$和上一次迭代$F^{t−1}_{out}$的隐藏层的信息,然后将其隐藏状态$F^t_{out}$传递到下一次迭代和输出。

  • 该模型可以使高层的信息自上而下流经反馈连接,以使用更多上下文信息来纠正低级特征。

  • 在SRFBN中,有三个不可缺少的部分:

    • 每次的迭代都会计算loss,迫使网络每次迭代都重建图像,将高层特征信息传入。
    • 使用recurrent结构,从而达到迭代的目的
    • 在每次迭代中都提供LR图像的输入(和上一轮的输出做一个concat)

    • pAaoiZR.md.png

4.模型架构

  • 每一次迭代由浅层特征提取层、反馈层、重建层以及一个将原始图像进行上采样的连接组成。经过t次迭代之后得到$ ( I _ { S R } ^ { 1 } , I _ { S R } ^ { 2 } , \ldots , I _ { S R } ^ { T } )$的高分辨率图片集合
  • pAaI0PK.png

  • 浅层特征提取层

    • 经过Conv(3, 4m) and Conv(1, m)
    • $ F _ { i n } ^ { t } = f _ { L R F B } ( I _ { L R } )$
  • 反馈层:$ F _ { o u t } ^ { t } = f _ { F B } ( F _ { o u t } ^ { t - 1 } , F _ { i n } ^ { t } )$
    • t次迭代时,通过反馈连接接受$t-1$次的隐藏层信息,将其与本次迭代的输入信息在特征维度上进行拼接。
    • $F^t_{in}$ and $F^{t−1}_{out}$通过Conv(1, m)连接并压缩,减少输入特征的维度
      • $ L _ { 0 } ^ { t } = C _ { 0 } ( \left[ F _ { o u t } ^ { t - 1 } , F _ { i n } ^ { t } \right] ) ,$
    • 之后交替式地进行上采样和下采样,并通过稠密连接进行连接。(高尺度连接到高尺度、低尺度连接到低尺度)
      • 这样做的目的提取跨尺度的特征,提取图片内部特征的先验信息。
      • 使用反卷积Deconv(k, m)进行上采样,$ H _ { g } ^ { t } = C _ { g } ^ { t } ( \left[ L _ { 0 } ^ { t } , L _ { 1 } ^ { t } , . . . , L _ { g - 1 } ^ { t } \right] ) ,$
      • 使用Conv(k, m)进行下采样,$ L _ { g } ^ { t } = C _ { g } ^ { \downarrow } ( \left[ H _ { 1 } ^ { t } , H _ { 2 } ^ { t } , \ldots , H _ { g } ^ { t } \right] ) ,$
      • 除了第一个组,在上采样层和下采样层之前使用Conv(1, m),提高计算效率
    • pAa7k8K.png
  • 重建层
    • 经过反卷积Deconv(k, m)放大,和Deconv(k, m)重建,并与上采样后的原始图像结合
    • $ I _ { R e s } ^ { t } = f _ { R B } ( F _ { o u t } ^ { t } ) $

二.数据集

三.模型搭建

1.基本结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class SRFBN(nn.Module):  #num_steps:迭代次数、num_groups:反馈块中的上下采样层组的个数
# 初始化方法,用于设置SRFBN模型的基本参数和构建模型结构,参数设置部分省略
def __init__(self, in_channels, out_channels, num_features, num_steps, num_groups, upscale_factor, act_type = 'prelu', norm_type = None):
super(SRFBN, self).__init__() # 调用父类的初始化方法
#浅层特征提取块
self.conv_in = ConvBlock(in_channels, 4*num_features, kernel_size=3, act_type=act_type, norm_type=norm_type)
self.feat_in = ConvBlock(4*num_features, num_features, kernel_size=1, act_type=act_type, norm_type=norm_type)
# 反馈块
self.block = FeedbackBlock(num_features, num_groups, upscale_factor, act_type, norm_type)
# 重建块
self.out = DeconvBlock(num_features, num_features, kernel_size=kernel_size, stride=stride, padding=padding, act_type='prelu', norm_type=norm_type)
self.conv_out = ConvBlock(num_features, out_channels, kernel_size=3, act_type=None, norm_type=norm_type)

# 前向传播
def forward(self, x):
self._reset_state() #重置状态
# 对输入数据进行预处理,减去均值
x = self.sub_mean(x)
# 对输入数据进行上采样
inter_res = nn.functional.interpolate(x, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False)
# 通过低分辨率特征提取块
x = self.conv_in(x)
x = self.feat_in(x)
# 存储每一步的输出
outs = []
for _ in range(self.num_steps): #开始迭代过程
# 通过反馈块,得到当前隐藏层的状态
h = self.block(x)
# 通过重建块,并将当前结果与上采样得出的结果进行结合
h = torch.add(inter_res, self.conv_out(self.out(h)))
# 对结果添加均值
h = self.add_mean(h)
# 将结果添加到输出列表中
outs.append(h)

# 返回每一步的输出结果
return outs

def _reset_state(self): # 重置状态的方法,通常用于FeedbackBlock中重置隐藏状态或计数器
self.block.reset_state()

2.反馈层结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class FeedbackBlock(nn.Module):  
# 初始化函数,设置反馈块的各种参数
def __init__(self, num_features, num_groups, upscale_factor, act_type, norm_type):
super(FeedbackBlock, self).__init__() # 调用父类的初始化函数
# 存储分组数量
self.num_groups = num_groups

# 定义压缩卷积块,用于减少输入特征的维度
self.compress_in = ConvBlock(2*num_features, num_features, kernel_size=1, act_type=act_type, norm_type=norm_type)

# 初始化存储上采样、下采样、上采样转换、下采样转换块的列表
self.upBlocks = nn.ModuleList()
self.downBlocks = nn.ModuleList()
self.uptranBlocks = nn.ModuleList()
self.downtranBlocks = nn.ModuleList()

# 循环创建指定数量的上采样、下采样、上采样转换、下采样转换块
for idx in range(self.num_groups):
self.upBlocks.append(DeconvBlock(num_features, num_features, kernel_size=kernel_size, stride=stride, padding=padding, act_type=act_type, norm_type=norm_type))
self.downBlocks.append(ConvBlock(num_features, num_features, kernel_size=kernel_size, stride=stride, padding=padding, act_type=act_type, norm_type=norm_type, valid_padding=False))
if idx > 0: # 如果不是第一个分组,则需要转换块
self.uptranBlocks.append(ConvBlock(num_features*(idx+1), num_features, kernel_size=1, stride=1, act_type=act_type, norm_type=norm_type))
self.downtranBlocks.append(ConvBlock(num_features*(idx+1), num_features, kernel_size=1, stride=1, act_type=act_type, norm_type=norm_type))

# 输出压缩卷积块,用于将所有分组的输出合并并减少维度
self.compress_out = ConvBlock(num_groups*num_features, num_features, kernel_size=1, act_type=act_type, norm_type=norm_type)

# 状态重置标志和上一次隐藏状态
self.should_reset = True
self.last_hidden = None

# 前向传播函数
def forward(self, x):
# 如果是第一次运行,则重置隐藏状态并初始化为输入
if self.should_reset:
self.last_hidden = torch.zeros(x.size()).cuda()
self.last_hidden.copy_(x)
self.should_reset = False

# 将当前输入和上一次隐藏状态在特征维度上拼接!
x = torch.cat((x, self.last_hidden), dim=1)
# 通过输入压缩卷积块
x = self.compress_in(x)

# 初始化低分辨率和高分辨率特征列表
lr_features = []
hr_features = []
# 将输入特征添加到低分辨率特征列表
lr_features.append(x)

# 循环处理每个分组
for idx in range(self.num_groups):
# 将当前低分辨率特征列表中的所有特征在特征维度上拼接
LD_L = torch.cat(tuple(lr_features), 1)
# 如果不是第一个分组,则通过上采样转换块
if idx > 0:
LD_L = self.uptranBlocks[idx-1](LD_L)
# 通过上采样块
LD_H = self.upBlocks[idx](LD_L)
# 将上采样后的特征添加到高分辨率特征列表
hr_features.append(LD_H)

# 将当前高分辨率特征列表中的所有特征在特征维度上拼接
LD_H = torch.cat(tuple(hr_features), 1)
# 如果不是第一个分组,则通过下采样转换块
if idx > 0:
LD_H = self.downtranBlocks[idx-1](LD_H)
# 通过下采样块
LD_L = self.downBlocks[idx](LD_H)
# 将下采样后的特征添加到低分辨率特征列表
lr_features.append(LD_L)

# 删除高分辨率特征列表以节省内存
del hr_features
# 将除输入外的所有低分辨率特征在特征维度上拼接,并通过输出压缩卷积块
output = torch.cat(tuple(lr_features[1:]), 1)
output = self.compress_out(output)

# 更新隐藏状态
self.last_hidden = output

# 返回输出
return output

# 重置状态函数
def reset_state(self):
self.should_reset = True
-------------本文结束-------------