拥抱新时代:transformers深度集成timm

文摘   科技   2025-01-21 16:00   北京  



思源Source报道
翻译&编辑:seefun
经过多方努力,终于timm可以深度集成进transformers生态:带来闪电般的推理速度,快速的模型量化,torch.compile提速,以及transformers Trainer API轻松微调任何timm模型。

Github: https://github.com/huggingface/pytorch-image-models

Example: https://github.com/ariG23498/timm-wrapper-examples


pip install -Uq transformers timm


timm是什么

PyTorch Image Models(timm)开源库,提供了极其丰富的计算机视觉预训练模型(Vision Backbone),以及各种有用的自定义Layer、优化器、数据增强以及实用的训练工具。在撰写本文时,它拥有超过32K的GitHub star和超过20万的每日下载量,是计算机视觉领域最受欢迎的开源库,被用于图像分类、目标检测、分割、图像检索和其他众多下游任务。 由于覆盖了超多流行的视觉预训练模型,timm极大简化了计算机视觉从业者的工作流程。笔者也是timm的Contributor之一,关于更多timm的信息,可以参考本号先前推送:  调包侠系列—— timm保姆级使用指南


为何transformers需要timm集成

虽然transformers支持了多种视觉模型,但timm提供了更广泛的模型结合,尤其是包含了很多端侧友好的小模型(Efficient Models)。

而对于timm来说,加入当前最流行的transformers生态,意味着:

✅ Pipeline API支持:轻松地将任何timm模型插入transformers的pipeline,以实现更简单的推理运行。

🧩 与Auto Class的兼容性:让timm模型与transformers的Auto Class API无缝连接。

⚡ 快速量化:只需约5行代码,您就可以量化任何timm模型以进行高效推理。

🎯 支持微调Trainer API:由于transformers Trainer API已经成为主流,使用Trainer微调timm模型将会对transformers使用者非常友好,并且能够无缝支持LoRA等先进微调方法,或应用上各种训练tricks或工具。

🔁 相互兼容:可以在timm中使用transformers微调后的模型。

🚀 Torch Compile:作为Pytorch2.x的重要特性,使用Torch.Compile工具来优化推理时间也将被很好地支持。


Pipeline API支持

timm集成允许利用 🤗transformers的 Pipeline API抽象。使得加载预先训练的模型、执行推理以及查看结果变得非常容易,仅需几行代码即可搞定一切!

比如载入MobileNetV4模型,仅需一行代码。对一张图片运行推理也只需一行代码。这就是transformers pipeline API的简单无脑,无需担心各种例如数据预处理等细节无法对齐。

from transformers import pipeline
# 载入分类器image_classifier = pipeline(model="timm/mobilenetv4_conv_medium.e500_r256_in1k")# 图像urlurl = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"# 运行推理outputs = image_classifier(url)# 展示结果for output in outputs:    print(f"Label: {output['label'] :20} Score: {output['score'] :0.2f}")

Gradio集成

通过gradio,可以快速构建在线demo,可以提供给工程师方便地进行可视化debug或测试。

import gradio as gr
demo = gr.Interface(    fn=classify,    inputs=gr.Image(type="pil"),    outputs="text",    examples=[["./sushi.png""sushi"]])
demo.launch()


Auto Class集成

通过🤗 transformers 库提供Auto Classes来抽象模型加载和预处理,已经成为了大家十分常用的操作。使用TimmWrapper,您可以使用 AutoImageProcessor 和 AutoModelForImageClassification 轻松加载任何timm模型及其对应的预处理,并用于训练或推理流程。

from transformers import (    AutoModelForImageClassification,    AutoImageProcessor,)from transformers.image_utils import load_image
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/timm/cat.jpg"image = load_image(image_url)
# Use Auto classes to load a timm modelcheckpoint = "timm/mobilenetv4_conv_medium.e500_r256_in1k"image_processor = AutoImageProcessor.from_pretrained(checkpoint)model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()


timm模型量化

量化是一种强大的技术,可以减小模型大小并加快推理速度,特别是在资源受限的设备上部署时。通过timm集成,您可以使用大模型训练中最常用的bitsandbytes库,使用其中的BitsAndBytesConfig,只需几行代码即可动态量化任何timm模型。对于常训大模型的读者来说,是不是非常熟悉?同理,timm模型的QLoRA微调似乎也可以通过bitsandbytes的量化来进行尝试(但笔者还未验证QLoRA是否正常支持,毕竟小模型一般也用不到)。

from transformers import TimmWrapperForImageClassification, BitsAndBytesConfigquantization_config = BitsAndBytesConfig(load_in_8bit=True)checkpoint = "timm/vit_base_patch16_224.augreg2_in21k_ft_in1k"model = TimmWrapperForImageClassification.from_pretrained(checkpoint).to("cuda")model_8bit = TimmWrapperForImageClassification.from_pretrained(    checkpoint,    quantization_config=quantization_config,    low_cpu_mem_usage=True,)
经过8bit量化后,模型体积明显小了很多:

并且精度几乎相同


监督微调(SFT)

使用来自 🤗 transformers的Trainer API对timm模型进行微调,简单且高度灵活。您可以使用Trainer类在自定义数据集上微调模型,Trainer会自动处理训练循环、日志记录和评估。此外,您可以使用LoRA等多种工具或如cosine learning rate decay等tricks进行微调。以获得更高的训练精度、更快的训练速度或是更小的显存占用。transformers的Trainer API简化了训练的一切,算法研究者不再需要撰写复杂的训练工具库,简直是调包侠的最爱。

from transformers import TrainingArguments, Trainer# 定义训练参数training_args = TrainingArguments(    output_dir="my_model_output",    evaluation_strategy="epoch",    save_strategy="epoch",    learning_rate=5e-5,    per_device_train_batch_size=16,    num_train_epochs=3,    load_best_model_at_end=True,    push_to_hub=True,)# 定义Trainertrainer = Trainer(    model=model,    args=training_args,    train_dataset=train_ds,    eval_dataset=val_ds,    data_collator=data_collator,    compute_metrics=compute_metrics,)# 原码 启动!trainer.train()

对于LoRA微调,也和LLM的LoRA微调使用方法保持完全一致

from peft import LoraConfig, get_peft_modelmodel = AutoModelForImageClassification.from_pretrained(checkpoint, num_labels=num_labels)lora_config = LoraConfig(    r=16,    lora_alpha=16,    target_modules=["qkv"],    lora_dropout=0.1,    bias="none",    modules_to_save=["head"],)# Wrap the model with PEFTlora_model = get_peft_model(model, lora_config)lora_model.print_trainable_parameters()

支持互转

使用timm同样能载入transformers训练出来的timm模型,如下为一个样例:

checkpoint = "ariG23498/vit_base_patch16_224.augreg2_in21k_ft_in1k.ft_food101"config = AutoConfig.from_pretrained(checkpoint)model = timm.create_model(f"hf_hub:{checkpoint}", pretrained=True) # Load the model with timmmodel = model.eval()image = load_image("https://cdn.britannica.com/52/128652-050-14AD19CA/Maki-zushi.jpg")data_config = timm.data.resolve_model_data_config(model)transforms = timm.data.create_transform(**data_config, is_training=False)output = model(transforms(image).unsqueeze(0))top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1* 100, k=5)for prob, idx in zip(top5_probabilities[0], top5_class_indices[0]):    print(f"Label: {config.id2label[idx.item()] :20} Score: {prob/100 :0.2f}%")

Torch.Compile

使用PyTorch 2.x中的torch.compile,只需一行代码即可编译模型,从而实现更快的推理。TimmWrapper与torch.compile完全兼容。

timm与transformers的集成,以最小的代价让transformers使用上timm中最先进的视觉模型。也为微调、量化、推理都提供了统一的API来简化CV工程师的工作流程。让我们以更低的上手难度,开启计算机视觉新的可能!

点击👇关注 “思源数据科学”

👇点个“赞”和“在看”吧

思源数据科学
Towards AGI
 最新文章