微调实战项目-01

文摘   2025-01-04 08:00   四川  
很多初学者学习了东西,苦于没有实际的动手机会去真正体验一把,一直处于“纸上谈兵”的痛苦状态。
AI 怎么能如此无聊呢?!我们需要发挥自己的想象并且结合实际的应用场景做点有意思的东西出来!
项目背景
我们其实很多开发者都有这样的一个实际诉求:
本地有一个公司的内部项目代码,这是公司内部人员编写的业务系统代码。代码分了很多包,代码也很多,各种牛鬼蛇神的业务逻辑都在里面。
作为一个新人往往不知道代码的逻辑,如果是陈年老代码,估计有很大一部分都不知道它的内部逻辑了。
而且这种通过现成的基座大模型无法解决,局部的代码交给 ChatGPT,它往往无法很好地处理,或者说基本上无法处理。
基于这种情况,我想将这个项目代码的内容作为一种知识通过 lora 微调的方式让本地的 model 学习。

当然了,理论上这种使用RAG也是可以实现的,但是我不太确定这种情况,RAG能做到什么地步,实际上RAG的上限不在于基座大模型本身,它的上限在于开发者的上层RAG的构建和组织能力。我们后面会使用最新版 langchain 0.3 版本的方式实现一遍,对比一下二者的效果。

有兴趣的读者点个关注不要错过哦。


微调方案分析

1. 通过“无监督”微调(仅用代码)

这种方法的微调dataset格式是:

[  {"text""code"},  {"text""code"}]
  • 优点

    • 简单直接:这种方式不需要为每段代码提供额外的描述或上下文,直接将代码作为输入进行文字填空训练,能让模型从大量代码中提取出模式。

    • 能够处理大量代码:如果项目代码量非常庞大,可以直接将代码片段作为训练样本,不需要人工标注功能描述。

    • 减少标注工作量:不需要为每个代码片段编写功能描述,适合大规模代码的学习。

  • 缺点

    • 缺乏上下文理解:模型会学习到代码的语法和某些模式,但不一定能够理解每段代码的业务逻辑或功能,尤其是对于复杂的、跨多个模块或包的逻辑,模型的表现可能不如期望。

    • 难以捕捉复杂功能:如果代码背后的业务逻辑比较复杂,仅仅通过代码的形式,模型难以推测出其真正的目的或应用场景。

2. 通过“有监督指令微调”(代码+功能描述)

这种方法的微调dataset格式是:

[  {"instruction""code 功能描述","output""code"},  {"instruction""code 功能描述","output""code"},]
  • 优点

    • 明确的上下文:通过明确的功能描述,模型能够学到每段代码背后的目的和作用,能够更加精确地理解和生成代码。

    • 提高模型理解度:通过描述功能,模型能将代码与实际的业务需求联系起来,从而生成更加符合预期的代码。

    • 适用于复杂业务:对于复杂的业务系统,功能描述能够帮助模型更好地捕捉跨模块的逻辑关系。

  • 缺点

    • 人工标注工作量大:需要为每段代码编写详细的功能描述,增加了数据准备的复杂性和时间成本。

    • 数据量可能有限:如果没有足够的功能描述样本,训练效果可能会受到限制。

3. 综合考虑:两种方式结合使用

在我们的场景中,最理想的方式其实将两者结合使用,结合能想到的所有方案的优点,通常能“堆出”一个好的系统,这也是大模型model的设计精髓。具体做法如下:

  • 基础微调:首先使用无监督的方式,通过大量的代码片段来让模型学习项目中常见的模块结构,业务属性、业务逻辑和常见函数。这样模型能够从广泛的代码片段中提取基础的业务知识。

  • 增强微调:接着通过有监督的方式对模型进行进一步微调,将每段代码与其功能描述配对。这样可以帮助模型更好地理解每段代码的上下文和具体功能,从而提升模型的准确性和可解释性,特别是在需要模型生成或理解复杂代码时。


无监督训练dataset构建

这节,我们先来尝试构建无监督训练的 dataset,也就是从某个代码项目的根目录下找到所有的项目代码文件比如 .java 或者 .py 的文件,然后进行text split。

这里我计划使用Langchain提供的便捷的语义分割器: RecursiveCharacterTextSplitter。基于这个工具我们可以定义出多样的代码切割工具,比如下面是我定义的 Java 代码的 split class :

from langchain_text_splitters import RecursiveCharacterTextSplitter, Language
class JavaCodeTextSplitter(RecursiveCharacterTextSplitter):    def __init__(self, **kwargs: Any) -> None:        separators = self.get_separators_for_language(Language.JAVA)        super().__init__(separators=separators, **kwargs)

拿Java来说,基本上它文件最上面的 package 导入就能提供模块与模块或者class 和class之间的隐性关联。即使这样简单地拆分,在无监督训练应该也是能学习到很多东西的,下面是完整的构建无监督训练dataset的代码

from typing import AnyListimport osimport jsonfrom langchain_text_splitters import RecursiveCharacterTextSplitter, Languagefrom langchain_text_splitters.python import PythonCodeTextSplitter
class JavaCodeTextSplitter(RecursiveCharacterTextSplitter):    def __init__(self, **kwargs: Any) -> None:        separators = self.get_separators_for_language(Language.JAVA)        super().__init__(separators=separators, **kwargs)
def save_docs(docs: List[str], path: str = './ca_pre_train_data.json') -> None:    """将文档对象保存到 JSON 文件"""    with open(path, 'a', encoding="utf-8"as f:        for doc in docs:            json.dump({"text": doc}, f, ensure_ascii=False)            f.write(',\n')
class PreTrainDataBuilder:    def __init__(            self,            root_path: str = '../data',            file_extension: str = '.py') -> None:        self.root_path = root_path        self.file_extension = file_extension        self.files = self._get_all_files()        def _get_all_files(self) -> List[str]:        """递归获取指定目录下的所有文件"""        return [            os.path.join(dir_path, filename)            for dir_path, _, filenames in os.walk(self.root_path)            for filename in filenames if filename.endswith(self.file_extension)        ]        def _code_to_docs(self, file_path: str) -> List[str]:        """将代码文件转换为文档对象列表"""        with open(file_path, 'r', encoding='utf-8'as f:            text = f.read()        splitter = (            PythonCodeTextSplitter(chunk_size=1024, chunk_overlap=100)            if self.file_extension == '.py' else            JavaCodeTextSplitter(chunk_size=1024, chunk_overlap=100)        )        return splitter.split_text(text)        def run(self,            output_path: str = './ca_pre_train_data.json') -> None:        """处理所有文件并保存提取的文档"""        for file_path in self.files:            docs = self._code_to_docs(file_path)            save_docs(docs, path=output_path)
if __name__ == "__main__":    builder = PreTrainDataBuilder(file_extension='.java')    builder.run()
然后你就可以拿着这个dataset 去llamafactory(关于llamafactory,请参照这篇)上面去做预训练啦!
测试结果
经过pre_train的微调,我实际测试了一把,它确实学到了我本地代码的一些知识,我们的预训练方式的微调是生效的!
注意,当你向微调好的model提问的时候给出的prompt至关重要,不然它还是会有很多幻觉的,这些幻觉会让你产生一种挫败感:为什么我的微调没有起作用
因为是我们自己拿着准备好的数据微调,因此,我们知道数据的分布,还是比较容易构造高质量的prompt的。经过我的尝试,你可以尝试添加 如下格式的 system message:
请从 xxxx package中获取相关问题的信息
这个约束相当于让它从我们自己的dataset中找答案,而不是它之前的model中的知识中找答案。因为我这里的项目背景是对公司内部的Java项目熟悉,因此xxxx package可以是项目的根包名,这个包名应该在所有的Java代码文件中都存在。
参考链接:
https://python.langchain.com/docs/how_to/code_splitter

半夏决明
读书,摄影,随笔
 最新文章