还记得这篇文章吗?迁移学习|代码实现
在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。
同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!
我们仍然按照这个步骤开始我们的模型的训练
准备一个可迭代的数据集
定义一个神经网络
将数据集输入到神经网络进行处理
计算损失
通过梯度下降算法更新参数
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import models
数据集准备
cifar10_train = torchvision.datasets.CIFAR10(
root = 'cifar10/',
train = True,
download = True
)
cifar10_test=torchvision.datasets.CIFAR10(
root = 'cifar10/',
train = False,
download = True
)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224,224))
])
cifar2_train=[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test=[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64,shuffle=True)
test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64,shuffle=True)
数据集使用CIFAR-10数据集中的猫和狗。
CIFAR-10数据集类别
种类 标签
plane 0
car 1
bird 2
cat 3
deer 4
dog 5
frog 6
horse 7
ship 8
truck 9
可以看到其中cat和dog的标签分别为3和5
借助:
[3,5].index(label)
我们可以将cat标签变为0,dog标签变为1,从而回到二分类问题。