卷积神经网络|迁移学习-猫狗分类完整代码实现

文摘   科技   2023-04-07 17:21   河南  

还记得这篇文章吗?迁移学习|代码实现

在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。

同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!

我们仍然按照这个步骤开始我们的模型的训练

  • 准备一个可迭代的数据集

  • 定义一个神经网络

  • 将数据集输入到神经网络进行处理

  • 计算损失

  • 通过梯度下降算法更新参数


import torch import torchvisionimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom torchvision import models

数据集准备


cifar10_train = torchvision.datasets.CIFAR10( root = 'cifar10/', train = True, download = True)cifar10_test=torchvision.datasets.CIFAR10( root = 'cifar10/', train = False, download = True)
transform = transforms.Compose([ transforms.ToTensor(), transforms.Resize((224,224)) ])

cifar2_train=[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test=[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64,shuffle=True)test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64,shuffle=True)

数据集使用CIFAR-10数据集中的猫和狗

CIFAR-10数据集类别

种类       标签

  • plane       0

  • car           1

  • bird         2

  • cat           3

  • deer         4

  • dog          5

  • frog         6

  • horse       7

  • ship         8

  • truck        9


可以看到其中cat和dog的标签分别为3和5

借助:

[3,5].index(label)

我们可以将cat标签变为0dog标签变为1,从而回到二分类问题。

math and code
计算机专业研究生在读,拥有深厚的计算机科学和数学背景,对编程、算法、数据结构、深度学习等领域都有着深入的了解和实践经验。对编程语言的掌握熟练而全面,无论是主流的Python、Java,还是强大的C++、Go,都能轻松驾驭。