深度学习基础第三章-VGG网络
本模型存放于目录:
E:\python文件\deep-learning-for-image-processing-master\pytorch_classification\Test3_vggnet
一.模型介绍
特点:
- 通过堆叠多个3x3的卷积核来替代大尺度卷积核(减少所需参数)
- 论文中提到,可以通过堆叠两个3x3的卷积核替代5x5的卷积核,堆叠三个3x3的卷积核替代7x7的卷积核 (拥有相同的感受野)
二.数据集-花分类数据集
1.定义预处理函数
1 | data_transform = { #对训练集与测试集图片进行预处理 |
2.从磁盘中读取数据集
1 | data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path |
3.保存各类比的字典索引
1 | # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} |
4.加载训练集与测试集
1 | batch_size = 32 |
三.网络模型搭建
1.根据版本提供相应的网络结构
- 由于vgg网络有很多版本,因此通过字典保存相应不同的结构
1 | #字典文件保存各网络模型的配置文件 (特征提取部分) |
1 | def make_features(cfg: list): #根据字典的列表得到相应的网络模型结构 |
2.定义网络模型
1 | class VGG(nn.Module): |
3.实例化网络模型(使用vgg16)
1 | def vgg(model_name="vgg16", **kwargs): #实例化模型 |
四·训练模型
1 | model_name = "vgg16" #使用vgg16版本 |
五.测试模型效果
- 代码与AlnexNet部分一致
1 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |