经过多方努力,终于timm可以深度集成进transformers生态:带来闪电般的推理速度,快速的模型量化,torch.compile提速,以及transformers Trainer API轻松微调任何timm模型。
Github: https://github.com/huggingface/pytorch-image-models
Example: https://github.com/ariG23498/timm-wrapper-examples
pip install -Uq transformers timm
timm是什么
PyTorch Image Models(timm)开源库,提供了极其丰富的计算机视觉预训练模型(Vision Backbone),以及各种有用的自定义Layer、优化器、数据增强以及实用的训练工具。在撰写本文时,它拥有超过32K的GitHub star和超过20万的每日下载量,是计算机视觉领域最受欢迎的开源库,被用于图像分类、目标检测、分割、图像检索和其他众多下游任务。 由于覆盖了超多流行的视觉预训练模型,timm极大简化了计算机视觉从业者的工作流程。笔者也是timm的Contributor之一,关于更多timm的信息,可以参考本号先前推送: 调包侠系列—— timm保姆级使用指南
为何transformers需要timm集成
虽然transformers支持了多种视觉模型,但timm提供了更广泛的模型结合,尤其是包含了很多端侧友好的小模型(Efficient Models)。
而对于timm来说,加入当前最流行的transformers生态,意味着:
✅ Pipeline API支持:轻松地将任何timm模型插入transformers的pipeline,以实现更简单的推理运行。
🧩 与Auto Class的兼容性:让timm模型与transformers的Auto Class API无缝连接。
⚡ 快速量化:只需约5行代码,您就可以量化任何timm模型以进行高效推理。
🎯 支持微调Trainer API:由于transformers Trainer API已经成为主流,使用Trainer微调timm模型将会对transformers使用者非常友好,并且能够无缝支持LoRA等先进微调方法,或应用上各种训练tricks或工具。
🔁 相互兼容:可以在timm中使用transformers微调后的模型。
🚀 Torch Compile:作为Pytorch2.x的重要特性,使用Torch.Compile工具来优化推理时间也将被很好地支持。
Pipeline API支持
timm集成允许利用 🤗transformers的 Pipeline API抽象。使得加载预先训练的模型、执行推理以及查看结果变得非常容易,仅需几行代码即可搞定一切!
比如载入MobileNetV4模型,仅需一行代码。对一张图片运行推理也只需一行代码。这就是transformers pipeline API的简单无脑,无需担心各种例如数据预处理等细节无法对齐。
from transformers import pipeline
# 载入分类器
image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")
# 图像url
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
# 运行推理
outputs = image_classifier(url)
# 展示结果
for output in outputs:
print(f"Label: {output['label'] :20} Score: {output['score'] :0.2f}")
Gradio集成
通过gradio,可以快速构建在线demo,可以提供给工程师方便地进行可视化debug或测试。
import gradio as gr
demo = gr.Interface(
fn=classify,
inputs=gr.Image(type="pil"),
outputs="text",
examples=[["./sushi.png", "sushi"]]
)
demo.launch()
Auto Class集成
通过🤗 transformers 库提供Auto Classes来抽象模型加载和预处理,已经成为了大家十分常用的操作。使用TimmWrapper,您可以使用 AutoImageProcessor 和 AutoModelForImageClassification 轻松加载任何timm模型及其对应的预处理,并用于训练或推理流程。
from transformers import (
AutoModelForImageClassification,
AutoImageProcessor,
)
from transformers.image_utils import load_image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"
image = load_image(image_url)
# Use Auto classes to load a timm model
checkpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
timm模型量化
量化是一种强大的技术,可以减小模型大小并加快推理速度,特别是在资源受限的设备上部署时。通过timm集成,您可以使用大模型训练中最常用的bitsandbytes库,使用其中的BitsAndBytesConfig,只需几行代码即可动态量化任何timm模型。对于常训大模型的读者来说,是不是非常熟悉?同理,timm模型的QLoRA微调似乎也可以通过bitsandbytes的量化来进行尝试(但笔者还未验证QLoRA是否正常支持,毕竟小模型一般也用不到)。
from transformers import TimmWrapperForImageClassification, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"
model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")
model_8bit = TimmWrapperForImageClassification.from_pretrained(
checkpoint,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
)
并且精度几乎相同
监督微调(SFT)
使用来自 🤗 transformers的Trainer API对timm模型进行微调,简单且高度灵活。您可以使用Trainer类在自定义数据集上微调模型,Trainer会自动处理训练循环、日志记录和评估。此外,您可以使用LoRA等多种工具或如cosine learning rate decay等tricks进行微调。以获得更高的训练精度、更快的训练速度或是更小的显存占用。transformers的Trainer API简化了训练的一切,算法研究者不再需要撰写复杂的训练工具库,简直是调包侠的最爱。
from transformers import TrainingArguments, Trainer
# 定义训练参数
training_args = TrainingArguments(
output_dir="my_model_output",
evaluation_strategy="epoch",
save_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
num_train_epochs=3,
load_best_model_at_end=True,
push_to_hub=True,
)
# 定义Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=data_collator,
compute_metrics=compute_metrics,
)
# 原码 启动!
trainer.train()
对于LoRA微调,也和LLM的LoRA微调使用方法保持完全一致
from peft import LoraConfig, get_peft_model
model = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)
lora_config = LoraConfig(
r=16,
lora_alpha=16,
target_modules=["qkv"],
lora_dropout=0.1,
bias="none",
modules_to_save=["head"],
)
# Wrap the model with PEFT
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()
支持互转
使用timm同样能载入transformers训练出来的timm模型,如下为一个样例:
checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"
config = AutoConfig.from_pretrained(checkpoint)
model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True) # Load the model with timm
model = model.eval()
image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(image).unsqueeze(0))
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):
print(f"Label: {config.id2label[idx.item()] :20} Score: {prob/100 :0.2f}%")
Torch.Compile
使用PyTorch 2.x中的torch.compile,只需一行代码即可编译模型,从而实现更快的推理。TimmWrapper与torch.compile完全兼容。
点击👇关注 “思源数据科学”
👇点个“赞”和“在看”吧