PyTorch实战(一)-- 从零构建CNN模型,精准识别MNIST手写数字

张开发
2026/6/8 0:25:25 15 分钟阅读
PyTorch实战(一)-- 从零构建CNN模型,精准识别MNIST手写数字
1. 认识MNIST数据集与CNN模型MNIST数据集堪称深度学习界的Hello World它包含6万张训练图片和1万张测试图片每张都是28x28像素的灰度手写数字。我第一次接触这个数据集时发现它特别适合练手——图片尺寸小、数据干净、分类明确。你可以把每张图片想象成28行28列的Excel表格每个单元格填着0-255的数字表示灰度值。卷积神经网络(CNN)天生适合处理这类图像数据。它通过卷积核自动提取特征就像我们用手指在图片上滑动感受纹理。举个例子识别数字7时CNN会先捕捉斜线特征再组合成完整数字。我建议新手从LeNet-5这种经典结构入手它只有5层网络却能实现99%以上的识别准确率。2. 搭建开发环境工欲善其事必先利其器推荐使用Anaconda创建Python3.8环境。我习惯用这个命令创建专属环境conda create -n pytorch_cnn python3.8 conda activate pytorch_cnn安装PyTorch时要注意版本匹配最近CUDA 11.7比较稳定pip install torch1.13.1cu117 torchvision0.14.1cu117 --extra-index-url https://download.pytorch.org/whl/cu117验证安装是否成功时我常犯的错误是忘记检查GPU驱动。建议运行以下测试代码import torch print(torch.__version__) print(torch.cuda.is_available()) # 期待看到True3. 数据预处理实战原始数据需要经过精心处理才能喂给模型。PyTorch的transforms模块提供了强大的处理管道transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])这里有两个关键点容易踩坑ToTensor()会自动将像素值从0-255缩放到0-1Normalize的参数是MNIST特有的均值0.1307和标准差0.3081加载数据时建议设置num_workers参数加速train_loader DataLoader(datasettrain_dataset, batch_size64, shuffleTrue, num_workers4)4. 构建CNN模型详解让我们拆解这个包含两个卷积层的网络结构class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 nn.Conv2d(1, 10, kernel_size5) self.conv2 nn.Conv2d(10, 20, kernel_size5) self.pool nn.MaxPool2d(2) self.fc nn.Linear(320, 10)这里有个计算技巧经过两次卷积和池化后特征图尺寸会从28x28变为4x4。具体计算过程是第一层(28-5)/1 1 24 - 池化后24/212第二层(12-5)/1 1 8 - 池化后8/24前向传播时要注意维度变换def forward(self, x): x F.relu(self.pool(self.conv1(x))) x F.relu(self.pool(self.conv2(x))) x x.view(-1, 320) # 关键展平操作 return self.fc(x)5. 模型训练技巧选择优化器时SGD配合momentum效果不错optimizer optim.SGD(model.parameters(), lr0.01, momentum0.9)训练循环中我习惯添加这几项改进学习率衰减scheduler optim.lr_scheduler.StepLR(optimizer, step_size5, gamma0.1)早停机制当验证集准确率连续3轮不提升时停止模型检查点保存最佳权重监控训练过程有个小技巧if batch_idx % 100 99: print(fEpoch: {epoch1}, Loss: {running_loss/100:.3f}) running_loss 0.06. 模型评估与调优测试时一定要设置torch.no_grad()with torch.no_grad(): for data in test_loader: images, labels data outputs model(images) _, predicted torch.max(outputs.data, 1)如果准确率卡在98%左右可以尝试增加卷积层通道数如20-32添加Dropout层防止过拟合使用更复杂的结构如ResNet块我常用的可视化工具import matplotlib.pyplot as plt plt.plot(train_losses, labeltrain) plt.plot(val_losses, labelval) plt.legend()7. 常见问题排查遇到CUDA out of memory错误时可以减小batch_size从64降到32使用梯度累积每4个batch更新一次参数如果训练loss不下降检查数据是否正常显示梯度是否正常回传打印参数梯度学习率是否合适尝试1e-3到1e-5有个容易忽略的细节DataLoader的shuffle参数在训练集要设为True测试集设为False。8. 项目扩展建议掌握基础CNN后可以尝试这些进阶操作数据增强添加随机旋转、平移使用预训练模型在MNIST上微调ResNet部署模型用Flask搭建Web应用我最近尝试将模型导出为ONNX格式dummy_input torch.randn(1, 1, 28, 28) torch.onnx.export(model, dummy_input, mnist_cnn.onnx)记住调参就像烹饪需要耐心尝试。我最初训练MNIST时反复调整了二十多次才突破99%准确率。现在每次看到这个简单的28x28网格都能想起初学时的兴奋感——原来机器真的能看懂我们写的数字。

更多文章