自大数据时代的到来以来,大型语言模型(LLMs)取得了显著进展,展现了前所未有的应用场景和出色的泛化能力。这些进展为各类智能应用奠定了基础,涵盖从自然语言处理到复杂的推理任务等多个领域。
为了进一步提升模型的能力,研究者们开始引入视觉图像作为输入,推动了多模态大型语言模型(MLLMs)的发展。这类模型不仅能生成具有连贯性的语言响应,还能在跨模态理解方面展现出卓越的能力,能够处理诸如图像标题生成、视觉问题回答以及图像中不同对象的定位等任务。
在现有的多模态语言模型中,研究者们探索了不同的策略,以提升LLMs对视觉指令的响应能力。首先,有的研究通过在预训练阶段冻结LLMs,仅使用一个投影网络来进行视觉语言对齐。
例如,LLaMA-Adapter V2通过引入一个简单的MLP层,而mPLUG-Owl则基于注意力机制设计了视觉摘要器。其次,部分方法通过构建新的训练任务数据,赋予模型新的视觉理解能力。例如,Kosmos-2引入了指称对话任务,而Shikra则通过区域级定位来增强模型的视觉理解能力。另外,也有模型通过引入高级图像编码器来提取视觉嵌入,如LLaVA使用了CLIP编码器,而MiniGPT-4则采用了Q-Former。
在视觉语言模型(VLM)的丛林中,SPHINX 的诞生方式相当令人印象深刻。想象一下,像阿尔伯特·爱因斯坦、艾萨克·牛顿和尼古拉·特斯拉这样的专家一起合作在一个项目上!这正是 SPHINX 所做的事情——它将顶级 AI 模型的力量结合到一个“篮子”中,以实现多任务处理的流畅性。
本文提出了一种创新的多模态语言模型——SPHINX,它结合了四个关键元素:模型权重、调优任务、视觉嵌入和高分辨率子图像。通过这种多维度的融合,SPHINX展示了在多个应用场景中的强大表现。接下来,我们将详细介绍这一方法的主要特点和实验发现。
SPHINX模型的整体混合范式,采用了两阶段训练流程:第一阶段是视觉-语言对齐的预训练,第二阶段是视觉指令跟随的微调。每个阶段都应用了提出的模型权重混合策略和调优任务。整个模型由一个大型语言模型(如LLaMA-2)、一个视觉编码器的混合结构以及两个线性投影层组成。
在阶段一的预训练过程中,我们解冻了LLM以进行视觉-语言对齐。与现有的多模态大型语言模型(如Zhu等,2023;Li等,2023d;Dai等,2023)通常采用的冻结LLM方法不同,后者在预训练阶段通常冻结整个LLM,仅训练中间的投影层来实现视觉-语言对齐。
这种策略虽然能够避免LLMs过度拟合生成简短的句子,但也限制了其在大规模视觉-语言数据上的跨模态学习潜力。为了解决这一问题,SPHINX解冻了整个LLM以及可学习的线性投影层,从而实现了更充分的视觉-语言适应。同时,为了保持高质量的图像表示,视觉编码器在这一阶段保持冻结状态。
与LLMs不同,它只处理文本数据,视觉语言 MLLMs 还必须处理图像,这意味着输入到LLM解码器的不仅仅是文本,还包括图像。然而,这是LLMs本身并不理解的信息类型。
这既有优点也有缺点:
优点:它可以防止LLMs过度拟合,仅生成简短文本,因为图像标题数据集中的文本通常非常简洁。
缺点:LLMs无法从图像中学习信息或理解文本与图像之间的关系,这意味着在生成文本时,它们会错失图像中的大量信息。
如何兼顾利弊?
解决方案
答案在于这个思路:解冻并预训练LLM,在这一过程中让模型从额外的数据集中学习(这是SPHINX的独特功能)。
该阶段的详细步骤如下:
我们将使用真实世界的数据集(LAION-400M)对LLM(具体为LLama2)进行预训练。此时,输入将是视觉嵌入(来自视觉编码器模块的输出)。
在初步的LLM预训练之后,我们将以这些预训练权重为起点,在合成数据集(LAION-COCO)上进一步训练相同的LLM。这个步骤的原因将在后面解释。
此外,我们还将使用另一个仅包含文本的数据集——RefinedWeb,对LLM进行预训练。
在每次训练过程中,我们将从RefinedWeb中采样一个数据点,并从图像标题数据集中采样几个数据点。然后,我们使用两个损失函数依次进行训练,一个损失函数要求模型保持生成长文本的能力,另一个则确保模型能够准确描述图像内容。
这种方法非常有效。没有与RefinedWeb一起训练时,模型会严重过拟合(如橙色所示)。相反,采用RefinedWeb进行训练时,模型不会过度拟合生成短文本(如绿色所示)。
通过组合LLM权重来综合知识
理由
一个合成数据集可能包含真实数据集所缺乏的独特信息。因此,作者希望LLM能够同时处理这两种类型的信息。
解决方案
我们可以在两个数据集上同时训练模型,但作者认为这种做法会使模型难以收敛,并且可能过于苛刻。
from SPHINX import SPHINXModel
from PIL import Image
import torch
# Besides loading the `consolidated.*.pth` model weights, from_pretrained will also try to
# use `tokenizer.model', 'meta.json', and 'config.json' under `pretrained_path` to configure
# the `tokenizer_path`, `llama_type`, and `llama_config` of the model. You may also override
# the configurations by explicitly specifying the arguments
model = SPHINXModel.from_pretrained(pretrained_path="path/to/checkpoint", with_visual=True)
image = Image.open("examples/1.jpg")
qas = [["What's in the image?", None]]
response = model.generate_response(qas, image, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
print(response)
# if you wanna continue
qas[-1][-1] = response
qas.append(["Then how does it look like?", None])
response2 = model.generate_response(qas, image, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
print(response2)
from SPHINX import SPHINXModel
from PIL import Image
import torch
import torch.distributed as dist
import multiprocessing as mp
def main(world_size, rank) -> None:
dist.init_process_group(
backend="nccl", rank=rank, world_size=world_size,
init_method=f"tcp://127.0.0.1:23560",
)
torch.cuda.set_device(rank)
# mp_group tells the model which ranks will work together
# through model parallel to compose a complete model.
# When mp_group is None, a single-rank process group will
# be created and used, which means model parallel size = 1 (not enabled)
model = SPHINXModel.from_pretrained(
pretrained_path="path/to/checkpoint", with_visual=True,
mp_group=dist.new_group(ranks=list(range(world_size)))
)
# it's important to make sure that ranks within the same
# model parallel group should always receive the same input simultaneously
image = Image.open("examples/1.jpg")
qas = [["What's in the image?", None]]
response = model.generate_response(qas, image, max_gen_len=1024, temperature=0.9, top_p=0.5, seed=0)
if __name__ == "__main__":
N_GPU = 2
assert N_GPU in [1, 2, 4, 8]
if N_GPU == 1:
main(world_size=1, rank=0)
else:
# You can use whatever method, e.g. torchrun, slurm, etc. for distributed launch
# Just be sure to initialize torch distributed (by invoking dist.init_process_group)
# before creating the SPHINX model if model parallel size > 1 is used
mp.set_start_method("spawn")
for rank in range(N_GPU):
process = mp.Process(target=main, args=(N_GPU, rank))
process.start()