LoRA 的原理和用 PyTorch 从零到一的代码实现

教育   2025-01-12 15:42   江苏  


 

备注:本文首发于 https://bruceyuan.com/hands-on-code/hands-on-lora.html 后续有修改会更新于 博客中(实在是博客没人看啊)

来自:chaofa用代码打点酱油

LLM所有细分领域群、投稿群从这里进入!

背景

无论是火热的大模型(LLM)还是文生图模型(Stable Diffusion)微调的时候,都需要大量的GPU显存,个人的显卡上很难实现,因此各种参数高效(Parameter-Efficient)的方法层出不穷,最受大家欢迎的就是 LoRA《LoRA: Low-Rank Adaptation of Large Language Models》[1]

LoRA 有很多的优点,节约显存,训练快,效果损失较小(相对于全参数微调),推理的时候不增加耗时,可以做一个插入式组件使用。缺点当然也有,那就是还是会有一些效果的损失(笑)。

减少显存占用的主要原因是训练参数变小了(比如只对 qkv 层做 LoRA)

不喜欢看文字的同学可以看 B站视频-chaofa用代码打点酱油[2],


核心原理

核心原理非常的简单,任意一个矩阵 ,都可以对它进行低秩分解,把一个很大的矩阵拆分成两个小矩矩阵[^1](),在训练的过程中不去改变参数,而是去改变。具体可以表示为最终在训练计算的时候是其中 甚至可以设置成 1。

  • • 为什么说只优化 AB 两个矩阵就可以了呢?这里面的假设是什么?
  • •  不是满秩的,里面有大量参数是冗余的,那么其实可以用更接近满秩的矩阵 AB 代替。

    矩阵都可以表示为若干个线性无关向量,最大的线性无关向量个数就是秩

PyTorch 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


classLinearLoRALayer(nn.Module):
    def__init__(self, 
        in_features, 
        out_features,
        merge=False,
        rank=8,
        lora_alpha=16,
        dropout=0.1,
    
):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.merge = merge
        self.rank = rank

        # linear weight 的 Shape 是 (out_features, in_features), 正确的做法是 xW^T
        self.linear = nn.Linear(in_features, out_features)
        # 这里非常的重要,这里是实现的小细节
        
        if rank > 0:
            # 这里是为了标记 lora_a 和 lora_b 是可训练的参数
            self.lora_a = nn.Parameter(
                torch.zeros(out_features, rank)
            )
            # lora_a 需要初始化为 高斯分布
            # @春风归无期 提醒我 @用代码打点酱油的chaofa : 在调用凯明初始化的时候注释里写的高斯分布,调用的却是均匀分布,而且参数a的值设置的是根号5,但a表示的是leaky relu的负斜率系数,一般是0.01这样的小值,不可能超过1
            nn.init.kaiming_normal_(self.lora_a, a=0.01)

            self.lora_b = nn.Parameter(
                torch.zeros(rank, in_features)
            )
            self.scale = lora_alpha / rank

            # linear 需要设置为不可以训练
            self.linear.weight.requires_grad = False
        
        self.dropout = nn.Dropout(
            dropout
        ) if dropout > 0else nn.Identity()

        # 如果采用 merge 进行推理,
        # 那么会把 lora_a 和 lora_b 两个小矩阵的参数直接放到 linear.weight 中
        if merge:
            self.merge_weight()

    
    defforward(self, X):
        # X shape is (batch, seq_len, in_feature)
        # lora_a 是 out_features * rank
        ifself.rank > 0andnotself.merge:
            output = self.linear(X) + self.scale * ( X @ (self.lora_a @ self.lora_b).T )
        elifself.rank > 0andself.merge:
            output = self.linear(X)
        else:
            output = self.linear(X)
        
        returnself.dropout(output)

    defmerge_weight(self, ):
        ifself.merge andself.rank > 0:
            self.linear.weight.data += self.scale * (self.lora_a @ self.lora_b)
    
    defunmerge_weight(self, ):
        ifself.rank > 0:
            self.linear.weight.data -= self.scale * (self.lora_a @ self.lora_b)


# 写一段测试代码
# Test the LoRALinear layer
batch_size = 32
seq_len = 128
in_features = 768
out_features = 512
rank = 8
lora_alpha = 16
dropout = 0.1

# Create a test input
x = torch.randn(batch_size, seq_len, in_features)

# Test regular mode (no merge)
lora_layer = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    dropout=dropout,
    merge=False
)

# Forward pass
output = lora_layer(x)
print(f"Output shape (no merge): {output.shape}")  # Should be [batch_size, seq_len, out_features]

# Test merged mode
lora_layer_merged = LinearLoRALayer(
    in_features=in_features,
    out_features=out_features,
    rank=rank,
    lora_alpha=lora_alpha,
    dropout=dropout,
    merge=True
)

# Forward pass with merged weights
output_merged = lora_layer_merged(x)
print(f"Output shape (merged): {output_merged.shape}")  # Should be [batch_size, seq_len, out_features]

# Test weight merging/unmerging
lora_layer.merge_weight()
output_after_merge = lora_layer(x)
lora_layer.unmerge_weight()
output_after_unmerge = lora_layer(x)

print("Max difference after merge/unmerge cycle:"
      torch.max(torch.abs(output - output_after_unmerge)).item())

  • • Q: 大模型的 LoRA 实现真的这么简单吗?
  • • A: 原理是这么简单,但是实际实现过程中因为层很多,会有一些配置,比如 QKV 层做 LoRA 还是 FFN 层做 LoRA,这些都会增加代码的复杂性,但是核心原理就是上面的代码。

References

[^1]: 这里和PCA,SVD 有一些差别。前者是为了据降维/压缩,后者仅仅是为了学习低秩的矩阵(参数可以更新改变)

引用链接

[1] 《LoRA: Low-Rank Adaptation of Large Language Models》: https://papers.cool/arxiv/2106.09685
[2] B站视频-chaofa用代码打点酱油: https://www.bilibili.com/video/BV1fHmkYyE2w/
[3] 从 self-attention 到 multi-head self-attention: https://bruceyuan.com/hands-on-code/from-self-attention-to-multi-head-self-attention.html
[4] 手写 transformer decoder(CausalLM): https://bruceyuan.com/hands-on-code/hands-on-causallm-decoder.html
[5] LLM 大模型训练-推理显存占用分析: https://bruceyuan.com/post/llm-train-infer-memoery-usage-calculation.html

 



备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群


id:DLNLPer,记得备注呦

深度学习自然语言处理
一个热衷于深度学习与NLP前沿技术的平台,期待在知识的殿堂与你相遇~
 最新文章