使用PyTorch进行小样本学习的图像分类

文摘   2024-10-28 08:01   浙江  
近期文章回顾(更多热门文章请关注公众号与知乎Rocky Ding哦)

写在前面

WeThinkIn最新福利放送:大家只需关注WeThinkIn公众号,后台回复“简历资源”,即可获取包含Rocky独家简历模版在内的60套精选的简历模板资源,希望能给大家在AIGC时代带来帮助。

AIGC时代的《三年面试五年模拟》算法工程师求职面试秘籍独家资源:https://github.com/WeThinkIn/Interview-for-Algorithm-Engineer/tree/main

Rocky最新发布Stable Diffusion 3和FLUX.1系列模型的深入浅出全维度解析文章,点击链接直达干货知识:https://zhuanlan.zhihu.com/p/684068402


以下章来源于:DeepHub IMBA

本文仅用于学术分享,如有侵权,请联系台作删文处理

WeThinkIn导读

 

本文简要总结了四种小样本学习图像分类算法的方法,并使用pytorch实现了一个简单的分类模型,附有操作代码。

近年来,基于深度学习的模型在目标检测和图像识别等任务中表现出色。像ImageNet这样具有挑战性的图像分类数据集,包含1000种不同的对象分类,现在一些模型已经超过了人类水平上。但是这些模型依赖于监督训练流程,标记训练数据的可用性对它们有重大影响,并且模型能够检测到的类别也仅限于它们接受训练的类。

由于在训练过程中没有足够的标记图像用于所有类,这些模型在现实环境中可能不太有用。并且我们希望的模型能够识别它在训练期间没有见到过的类,因为几乎不可能在所有潜在对象的图像上进行训练。我们将从几个样本中学习的问题被称为“少样本学习 Few-Shot learning”。

什么是小样本学习?

少样本学习是机器学习的一个子领域。它涉及到在只有少数训练样本和监督数据的情况下对新数据进行分类。只需少量的训练样本,我们创建的模型就可以相当好地执行。

考虑以下场景:在医疗领域,对于一些不常见的疾病,可能没有足够的x光图像用于训练。对于这样的场景,构建一个小样本学习分类器是完美的解决方案。

小样本的变化

一般来说,研究人员确定了四种类型:

  1. N-Shot Learning (NSL)

  2. Few-Shot Learning ( FSL )

  3. One-Shot Learning (OSL)

  4. Zero-Shot Learning (ZSL)

当我们谈论 FSL 时,我们通常指的是 N-way-K-Shot 分类。N 代表类别数,K 代表每个类中要训练的样本数。所以N-Shot Learning 被视为比所有其他概念更广泛的概念。可以说 Few-Shot、One-Shot 和 Zero-Shot是 NSL 的子领域。而零样本学习旨在在没有任何训练示例的情况下对看不见的类进行分类。

在 One-Shot Learning 中,每个类只有一个样本。Few-Shot 每个类有 2 到 5 个样本,也就是说 Few-Shot 是更灵活的 One-Shot Learning 版本。

小样本学习方法

通常,在解决 Few Shot Learning 问题时应考虑两种方法:

数据级方法 (DLA)

这个策略非常简单,如果没有足够的数据来创建实体模型并防止欠拟合和过拟合,那么就应该添加更多数据。正因为如此,许多 FSL 问题都可以通过利用来更大大的基础数据集的更多数据来解决。基本数据集的显着特征是它缺少构成我们对 Few-Shot 挑战的支持集的类。例如,如果我们想要对某种鸟类进行分类,则基础数据集可能包含许多其他鸟类的图片。

参数级方法 (PLA)

从参数级别的角度来看,Few-Shot Learning 样本相对容易过拟合,因为它们通常具有大的高维空间。限制参数空间、使用正则化和使用适当的损失函数将有助于解决这个问题。少量的训练样本将被模型泛化。

通过将模型引导到广阔的参数空间可以提高性能。由于缺乏训练数据,正常的优化方法可能无法产生准确的结果。

因为上面的原因,训练我们的模型以发现通过参数空间的最佳路径,产生最佳的预测结果。这种方法被称为元学习。

小样本学习图像分类算法

有4种比较常见的小样本学习的方法:

与模型无关的元学习 Model-Agnostic Meta-Learning

基于梯度的元学习 (GBML) 原则是 MAML 的基础。在 GBML 中,元学习者通过基础模型训练和学习所有任务表示的共享特征来获得先前的经验。每次有新任务要学习时,元学习器都会利用其现有经验和新任务提供的最少量的新训练数据进行微调训练。

一般情况下,如果我们随机初始化参数经过几次更新算法将不会收敛到良好的性能。MAML 试图解决这个问题。MAML 只需几个梯度步骤并且保证没有过度拟合的前提下,为元参数学习器提供了可靠的初始化,这样可以对新任务进行最佳快速学习。

步骤如下:

元学习者在每个分集(episode)开始时创建自己的副本C,

C 在这一分集上进行训练(在 base-model 的帮助下),

C 对查询集进行预测,

从这些预测中计算出的损失用于更新 C,

这种情况一直持续到完成所有分集的训练。

  1. 元学习者在每个分集(episode)开始时创建自己的副本C,

  2. C 在这一分集上进行训练(在 base-model 的帮助下),

  3. C 对查询集进行预测,

  4. 从这些预测中计算出的损失用于更新 C,

  5. 这种情况一直持续到完成所有分集的训练。

这种技术的最大优势在于,它被认为与元学习算法的选择无关。因此MAML 方法被广泛用于许多需要快速适应的机器学习算法,尤其是深度神经网。

匹配网络 Matching Networks

为解决 FSL 问题而创建的第一个度量学习方法是匹配网络 (MN)。

当使用匹配网络方法解决 Few-Shot Learning 问题时需要一个大的基础数据集。。

将该数据集分为几个分集之后,对于每一分集,匹配网络进行以下操作:

  • 来自支持集和查询集的每个图像都被馈送到一个 CNN,该 CNN 为它们输出特征的嵌入

  • 查询图像使用支持集训练的模型得到嵌入特征的余弦距离,通过 softmax 进行分类

  • 分类结果的交叉熵损失通过 CNN 反向传播更新特征嵌入模型

匹配网络可以通过这种方式学习构建图像嵌入。MN 能够使用这种方法对照片进行分类,并且无需任何特殊的类别先验知识。他只要简单地比较类的几个实例就可以了。

由于类别因分集而异,因此匹配网络会计算对类别区分很重要的图片属性(特征)。而当使用标准分类时,算法会选择每个类别独有的特征。

原型网络 Prototypical Networks

与匹配网络类似的是原型网络(PN)。它通过一些细微的变化来提高算法的性能。PN 比 MN 取得了更好的结果,但它们训练过程本质上是相同的,只是比较了来自支持集的一些查询图片嵌入,但是 原型网络提供了不同的策略。

我们需要在 PN 中创建类的原型:通过对类中图像的嵌入进行平均而创建的类的嵌入。然后仅使用这些类原型来比较查询图像嵌入。当用于单样本学习问题时,它可与匹配网络相媲美。

关系网络 Relation Network

关系网络可以说继承了所有上面提到方法的研究的结果。RN是基于PN思想的但包含了显著的算法改进。

该方法使用的距离函数是可学习的,而不是像以前研究的事先定义它。关系模块位于嵌入模块之上,嵌入模块是从输入图像计算嵌入和类原型的部分。

可训练的关系模块(距离函数)输入是查询图像的嵌入与每个类的原型,输出为每个分类匹配的关系分数。关系分数通过 Softmax 得到一个预测。

使用 Open-AI Clip 进行零样本学习

CLIP(Contrastive Language-Image Pre-Training)是一个在各种(图像、文本)对上训练的神经网络。它无需直接针对任务进行优化,就可以为给定的图像来预测最相关的文本片段(类似于 GPT-2 和 3 的零样本的功能)。

CLIP 在 ImageNet“零样本”上可以达到原始 ResNet50 的性能,而且需要不使用任何标记示例,它克服了计算机视觉中的几个主要挑战,下面我们使用Pytorch来实现一个简单的分类模型。

引入包

! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.gitimport numpy as np
import torch
from pkg_resources import packaging
 
print("Torch version:", torch.__version__)

加载模型

 import clipclip.available\_models\(\) # it will list the names of available CLIP modelsmodel, preprocess = clip.load\("ViT-B/32"\)  
 model.cuda\(\).eval\(\)  
 input\_resolution = model.visual.input\_resolution  
 context\_length = model.context\_length  
 vocab\_size = model.vocab\_size  

 print\("Model parameters:", f"\{np.sum\(\[int\(np.prod\(p.shape\)\) for p in model.parameters\(\)\]\):,\}"\)  
 print\("Input resolution:", input\_resolution\)  
 print\("Context length:", context\_length\)  
 print\("Vocab size:", vocab\_size\)

图像预处理

我们将向模型输入8个示例图像及其文本描述,并比较对应特征之间的相似性。

分词器不区分大小写,我们可以自由地给出任何合适的文本描述。

import os  
 import skimage  
 import IPython.display  
 import matplotlib.pyplot as plt  
 from PIL import Image  
 import numpy as np  

 from collections import OrderedDict  
 import torch  

 \%matplotlib inline  
 \%config InlineBackend.figure\_format = 'retina'  

 \# images in skimage to use and their textual descriptions  
 descriptions = \{  
    "page""a page of text about segmentation",  
    "chelsea""a facial photo of a tabby cat",  
    "astronaut""a portrait of an astronaut with the American flag",  
    "rocket""a rocket standing on a launchpad",  
    "motorcycle\_right""a red motorcycle standing in a garage",  
    "camera""a person looking at a camera on a tripod",  
    "horse""a black-and-white silhouette of a horse",  
    "coffee""a cup of coffee on a saucer"  
 \}original\_images = \[\]  
 images = \[\]  
 texts = \[\]  
 plt.figure\(figsize=\(16, 5\)\)  

 for filename in \[filename for filename in os.listdir\(skimage.data\_dir\) if filename.endswith\(".png"\) or filename.endswith\(".jpg"\)\]:  
    name = os.path.splitext\(filename\)\[0\]  
    if name not in descriptions:  
        continue  

    image = Image.open\(os.path.join\(skimage.data\_dir, filename\)\).convert\("RGB"\)  
       
    plt.subplot\(2, 4, len\(images\) + 1\)  
    plt.imshow\(image\)  
    plt.title\(f"\{filename\}\\n\{descriptions\[name\]\}"\)  
    plt.xticks\(\[\]\)  
    plt.yticks\(\[\]\)  
       
    original\_images.append\(image\)  
    images.append\(preprocess\(image\)\)  
    texts.append\(descriptions\[name\]\)  

 plt.tight\_layout\(\)

结果的可视化如下:

我们对图像进行规范化,对每个文本输入进行标记,并运行模型的正传播获得图像和文本的特征。

 image\_input = torch.tensor\(np.stack\(images\)\).cuda\(\)  
 text\_tokens = clip.tokenize\(\["This is " + desc for desc in texts\]\).cuda\(\)  

 with torch.no\_grad\(\):  
    image\_features = model.encode\_image\(image\_input\).float\(\)  
    text\_features = model.encode\_text\(text\_tokens\).float\(\)

我们将特征归一化,并计算每一对的点积,进行余弦相似度计算

 image\_features /= image\_features.norm\(dim=-1, keepdim=True\)  
 text\_features /= text\_features.norm\(dim=-1, keepdim=True\)  
 similarity = text\_features.cpu\(\).numpy\(\) \@ image\_features.cpu\(\).numpy\(\).T  

 count = len\(descriptions\)  

 plt.figure\(figsize=\(20, 14\)\)  
 plt.imshow\(similarity, vmin=0.1, vmax=0.3\)  
 \# plt.colorbar\(\)  
 plt.yticks\(range\(count\), texts, fontsize=18\)  
 plt.xticks\(\[\]\)  
 for i, image in enumerate\(original\_images\):  
    plt.imshow\(image, extent=\(i - 0.5, i + 0.5, -1.6, -0.6\), origin="lower"\)  
 for x in range\(similarity.shape\[1\]\):  
    for y in range\(similarity.shape\[0\]\):  
        plt.text\(x, y, f"\{similarity\[y, x\]:.2f\}", ha="center", va="center", size=12\)  

 for side in \["left""top""right""bottom"\]:  
  plt.gca\(\).spines\[side\].set\_visible\(False\)  

 plt.xlim\(\[-0.5, count - 0.5\]\)  
 plt.ylim\(\[count + 0.5, -2\]\)  

 plt.title\("Cosine similarity between text and image features", size=20\)

零样本的图像分类

from torchvision.datasets import CIFAR100  
 cifar100 = CIFAR100\(os.path.expanduser\("\~/.cache"\), transform=preprocess, download=True\)  
 text\_descriptions = \[f"This is a photo of a \{label\}" for label in cifar100.classes\]  
 text\_tokens = clip.tokenize\(text\_descriptions\).cuda\(\)  
 with torch.no\_grad\(\):  
    text\_features = model.encode\_text\(text\_tokens\).float\(\)  
    text\_features /= text\_features.norm\(dim=-1, keepdim=True\)  

 text\_probs = \(100.0 \* image\_features \@ text\_features.T\).softmax\(dim=-1\)  
 top\_probs, top\_labels = text\_probs.cpu\(\).topk\(5, dim=-1\)  
 plt.figure\(figsize=\(16, 16\)\)  
 for i, image in enumerate\(original\_images\):  
    plt.subplot\(4, 4, 2 \* i + 1\)  
    plt.imshow\(image\)  
    plt.axis\("off"\)  

    plt.subplot\(4, 4, 2 \* i + 2\)  
    y = np.arange\(top\_probs.shape\[-1\]\)  
    plt.grid\(\)  
    plt.barh\(y, top\_probs\[i\]\)  
    plt.gca\(\).invert\_yaxis\(\)  
    plt.gca\(\).set\_axisbelow\(True\)  
    plt.yticks\(y, \[cifar100.classes\[index\] for index in top\_labels\[i\].numpy\(\)\]\)  
    plt.xlabel\("probability"\)  

 plt.subplots\_adjust\(wspace=0.5\)  
 plt.show\(\)

可以看到,分类的效果还是非常好的

推荐阅读

1、加入AIGCmagic社区知识星球

AIGCmagic社区知识星球不同于市面上其他的AI知识星球,AIGCmagic社区知识星球是国内首个以AIGC全栈技术与商业变现为主线的学习交流平台,涉及AI绘画、AI视频、大模型、AI多模态、数字人、全行业AIGC赋能等50+应用方向,内部包含海量学习资源、专业问答、前沿资讯、内推招聘、AI课程、AIGC模型、AIGC数据集和源码等

那该如何加入星球呢?很简单,我们只需要扫下方的二维码即可。知识星球原价:299元/年,前200名限量活动价,终身优惠只需199元/年。大家只需要扫描下面的星球优惠卷即可享受初始居民的最大优惠:

2、《三年面试五年模拟》算法工程师面试秘籍

《三年面试五年模拟》面试秘籍旨在整理&挖掘AI算法工程师在实习/校招/社招时所需的干货知识点与面试方法,力求让读者在获得心仪offer的同时,增强技术基本面。

Rocky已经将《三年面试五年模拟》面试秘籍的完整版构建在Github上:https://github.com/WeThinkIn/Interview-for-Algorithm-Engineer/tree/main,欢迎大家star!

《三年面试五年模拟》面试秘籍的内容框架

想要一起进行项目共建的朋友,欢迎点击链接加入项目团队:《三年面试五年模拟》版本更新白皮书,迎接AIGC时代

3、Sora等AI视频大模型的核心原理,核心基础知识,网络结构,经典应用场景,从0到1搭建使用AI视频大模型,从0到1训练自己的AI视频大模型,AI视频大模型性能测评,AI视频领域未来发展等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

Sora等AI视频大模型文章地址:https://zhuanlan.zhihu.com/p/706722494

4、Stable Diffusion 3和FLUX.1核心原理,核心基础知识,网络结构,从0到1搭建使用Stable Diffusion 3和FLUX.1进行AI绘画,从0到1上手使用Stable Diffusion 3和FLUX.1训练自己的AI绘画模型,Stable Diffusion 3和FLUX.1性能优化等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

Stable Diffusion 3和FLUX.1文章地址:https://zhuanlan.zhihu.com/p/684068402

5、Stable Diffusion XL核心基础知识,网络结构,从0到1搭建使用Stable Diffusion XL进行AI绘画,从0到1上手使用Stable Diffusion XL训练自己的AI绘画模型,AI绘画领域的未来发展等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

Stable Diffusion XL文章地址:https://zhuanlan.zhihu.com/p/643420260

6、Stable Diffusion 1.x-2.x核心原理,核心基础知识,网络结构,经典应用场景,从0到1搭建使用Stable Diffusion进行AI绘画,从0到1上手使用Stable Diffusion训练自己的AI绘画模型,Stable Diffusion性能优化等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

Stable Diffusion文章地址:https://zhuanlan.zhihu.com/p/632809634

7、ControlNet核心基础知识,核心网络结构,从0到1使用ControlNet进行AI绘画,从0到1训练自己的ControlNet模型,从0到1上手构建ControlNet商业变现应用等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

ControlNet文章地址:https://zhuanlan.zhihu.com/p/660924126

8、LoRA系列模型核心原理,核心基础知识,从0到1使用LoRA模型进行AI绘画,从0到1上手训练自己的LoRA模型,LoRA变体模型介绍,优质LoRA推荐等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

LoRA文章地址:https://zhuanlan.zhihu.com/p/639229126

9、Transformer核心基础知识,核心网络结构,AIGC时代的Transformer新内涵,各AI领域Transformer的应用落地,Transformer未来发展趋势等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

Transformer文章地址:https://zhuanlan.zhihu.com/p/709874399

10、最全面的AIGC面经《手把手教你成为AIGC算法工程师,斩获AIGC算法offer!(2024年版)》文章正式发布!

码字不易,欢迎大家多多点赞:

AIGC面经文章地址:https://zhuanlan.zhihu.com/p/651076114

11、50万字大汇总《“三年面试五年模拟”之算法工程师的求职面试“独孤九剑”秘籍》文章正式发布!

码字不易,欢迎大家多多点赞:

算法工程师三年面试五年模拟文章地址:https://zhuanlan.zhihu.com/p/545374303

《三年面试五年模拟》github项目地址(希望大家能多多star):https://github.com/WeThinkIn/Interview-for-Algorithm-Engineer

12、Stable Diffusion WebUI、ComfyUI、Fooocus三大主流AI绘画框架核心知识,从0到1搭建AI绘画框架,从0到1使用AI绘画框架的保姆级教程,深入浅出介绍AI绘画框架的各模块功能,深入浅出介绍AI绘画框架的高阶用法等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

AI绘画框架文章地址:https://zhuanlan.zhihu.com/p/673439761

13、GAN网络核心基础知识,网络架构,GAN经典变体模型,经典应用场景,GAN在AIGC时代的商业应用等全维度解析文章正式发布!

码字不易,欢迎大家多多点赞:

GAN网络文章地址:https://zhuanlan.zhihu.com/p/663157306

14、其他

Rocky将YOLOv1-v7全系列大解析文章也制作成相应的pdf版本,大家可以关注公众号WeThinkIn,并在后台 【精华干货】菜单或者回复关键词“YOLO” 进行取用。

WeThinkIn
Rocky相信人工智能,数据科学,商业逻辑,金融工具,终身成长,以及顺应时代的潮流会赋予我们超能力。
 最新文章