大型语言模型(LLMs)在生成文本时不可避免地会出现幻觉现象,因为其生成内容的准确性无法单靠模型参数中的知识来保证。尽管检索增强生成(RAG)是 LLMs 的一种实用补充,但其效果在很大程度上取决于检索到的文档的相关性,这也引发了人们对检索出错时模型表现的担忧。
为此,有学者提出了一种名为 Corrective Retrieval Augmented Generation(CRAG)的策略,以提升生成的鲁棒性。
具体而言,CRAG 包括一个轻量级的检索评估器,用于评估查询结果的整体文档质量,并返回一个置信度评分,根据该评分触发不同的知识检索操作。由于从静态、有限的语料库中检索到的文档可能并不理想,CRAG 还通过大规模 Web 搜索来扩展和增强检索结果。
此外,论文设计了一种“分解-重组”(decompose-then-recompose)算法,能够对检索到的文档进行选择性处理,聚焦于关键信息并过滤掉无关内容。
CRAG 具备即插即用的特性,可与各种基于 RAG 的方法无缝结合。实验结果表明,在涵盖短文本和长文本生成任务的四个数据集上,CRAG 显著提升了 RAG 方法的性能。
CRAG通过纠正策略来提升生成的鲁棒性,其工作流程如下图所示。
这个过程展示了 CRAG(Corrective Retrieval Augmented Generation)在推理阶段的操作流程。首先,给定一个查询(如“谁是《死亡蝙蝠侠》的编剧?”),系统会进行初步的文档检索,返回一组检索到的文档(如
接着,检索评估器会对这些检索到的文档与查询的相关性进行评估,判断它们是否能正确回答查询问题,并估计出一个置信度等级。根据评估结果,系统会触发不同的知识检索操作,分为三种情况:正确(Correct)、模糊(Ambiguous) 和 错误(Incorrect)。
对于评估为正确的文档,系统会直接将检索到的文档及其相关知识传递给生成器进行生成。
如果评估为模糊,系统会进入知识细化阶段(Knowledge Refinement)。在此阶段,首先对文档进行分解和清理,提取出可能有用的片段,然后经过过滤过程筛除无关信息,再将提炼后的信息重新组合成新的知识项,传递给生成器进行生成。
当文档评估为错误时,系统会启动知识搜索阶段(Knowledge Searching)。在这个阶段,会对原始查询进行重写,添加更多的上下文信息,并使用扩展后的查询进行大规模的 Web 搜索,以找到更相关的文档。通过对搜索结果进行筛选,最终选出更符合需求的文档传递给生成器。
在整个流程中,生成器会根据不同的知识来源(正确、模糊、错误)生成最终的响应,以提供更加准确和可靠的答案。
算法伪代码:
Corrective Retrieval Augmented Generation (CRAG) 旨在提升生成的鲁棒性,其核心是通过轻量级检索评估器来区分和触发三种不同的知识检索操作。借助 Web 搜索的扩展和优化知识利用,CRAG 显著增强了自动自我纠正的能力,并有效地利用检索到的文档信息。实验结果广泛证明了 CRAG 对 RAG 方法的适应性,以及在短格式和长格式生成任务中的泛化能力。
虽然 CRAG 主要从纠错的角度对 RAG 框架进行改进,并能与各种 RAG 方法无缝结合,但仍需要对外部检索评估器进行微调。未来的研究将集中于如何淘汰这一外部评估器,为 LLMs 配备更强大的检索评估能力,以进一步提升系统的智能性和性能。
论文开源项目CRAG(https://github.com/HuskyInSalt/CRAG)实现了纠正RAG。该项目运行需要Python 3.11环境,其他更多内容,感兴趣的读者可访问该项目自行阅读。
此外,LangChain框架也实现了CRAG的应用,其中LangChain框架被用来处理检索增强生成(RAG)的复杂流程,LangGraph则是用于从头构建图工作流的工具。
在 CRAG 的操作流程中,如果至少有一篇文档的相关性超过预设阈值,那么系统就会继续生成响应。在生成之前,还会执行知识细化步骤,将文档分割为“知识片段”,对每个片段进行评分,并过滤掉不相关的内容。
如果所有文档的相关性都低于阈值,或者评估器无法确定相关性,系统将寻求额外的数据源进行补充。这时,CRAG 会使用网络搜索来增强原有的检索结果,从而提高信息的全面性和准确性。
在实现过程中,一些步骤可以被简化或调整。例如,初次尝试时可以跳过知识细化阶段,如果需要,可以在后续版本中作为独立节点添加回去。当某些文档被判定为不相关时,可以选择通过网络搜索来补充检索,优化查询以获得更相关的结果。
关键代码如下:
from langchain.schema import Document
def retrieve(state):
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
question = state["question"]
# Retrieval
documents = retriever.get_relevant_documents(question)
return {"documents": documents, "question": question}
def generate(state):
"""
Generate answer
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG generation
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""
Determines whether the retrieved documents are relevant to the question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates documents key with only filtered relevant documents
"""
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# Score each doc
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
def transform_query(state):
"""
Transform the query to produce a better question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates question key with a re-phrased question
"""
print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]
# Re-write question
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
def web_search(state):
"""
Web search based on the re-phrased question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates documents key with appended web results
"""
print("---WEB SEARCH---")
question = state["question"]
documents = state["documents"]
# Web search
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)
return {"documents": documents, "question": question}
### Edges
def decide_to_generate(state):
"""
Determines whether to generate an answer, or re-generate a question.
Args:
state (dict): The current graph state
Returns:
str: Binary decision for next node to call
"""
print("---ASSESS GRADED DOCUMENTS---")
state["question"]
web_search = state["web_search"]
state["documents"]
if web_search == "Yes":
# All documents have been filtered check_relevance
# We will re-generate a new query
print(
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
)
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"
CRAG 通过这些步骤,显著提升了系统在复杂信息环境下的鲁棒性和灵活性,增强了对检索文档的有效利用。
完整代码参考资料:https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_crag.ipynb