LeNet5实现-pytorch
完整代码:PyNet/pytorch/lenet5_test.py
加载数据
pytorch提供模块torchvision,用于数据的加载、预处理和批量化
torchvision.datasets内置类MNIST用于mnist数据集下载和加载torchvision.transforms对数据进行预处理torchvision.DataLoader对数据进行批量化
1 | def load_mnist_data(batch_size=128, shuffle=False): |
网络定义
LeNet-5模型定义参考卷积神经网络推导-单张图片矩阵计算
torch.nn模块实现了网络层类,包括卷积层(Conv2d)、最大池化层(MaxPool2d)、全连接层(Linear)和其他激活层
torch.nn模块提供functional类用于网络层类的实现
1 | class LeNet5(nn.Module): |
训练
训练参数如下
- 学习率
lr = 1e-3 - 批量大小
batch_size = 128 - 迭代次数
epochs = 500
训练结果
训练时间
| CPU | GPU | 单次迭代时间 |
|---|---|---|
| 8核 Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz | GeForce 940MX | 约13秒 |
迭代500次训练结果
| 训练集精度 | 测试集精度 |
|---|---|
| 99.40% | 98.63% |

