Self-RAG(Self-Reflective Retrieval-Augmented Generation)是一种新型的检索增强生成框架,旨在提高大型语言模型(LLM)的生成质量和准确性。
Self-RAG通过引入“反思标记”(reflection tokens),使得模型能够根据具体需求动态决定是否进行信息检索。这种方法不仅减少了不必要的检索操作,还提高了生成内容的准确性和相关性。
与传统RAG相比,Self-RAG不仅增强了信息过滤和检索策略的灵活性,还能够更好地处理长尾问题和多样化信息需求,使得生成结果更加精准和高效。
Self-RAG的工作流程主要包括以下几个步骤:
检索决策:
模型首先生成一个检索标记,以评估是否需要从外部资源检索信息。如果不需要检索,模型将直接生成答案;如果需要,则进行相关文档的检索。
生成内容:
在检索到相关文档后,模型会生成基于这些文档的内容,并使用批判标记(critique tokens)来评估生成的答案是否准确和有用。
自我评估:
Self-RAG还会对生成的内容进行自我评估,确保输出的质量和事实准确性。这种自我反思的机制使得模型能够在生成过程中不断优化其输出。
论文开源项目self-rag(https://github.com/AkariAsai/self-rag)实现了自反射RAG。可以从HuggingFace Hub下载Self-RAG。对于推理,该项目建议使用VLLM来提高推理的效率。其他更多内容,读者可访问该项目自行阅读。
from vllm import LLM, SamplingParams
model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)
def format_prompt(input, paragraph=None):
prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
if paragraph is not None:
prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
return prompt
query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]
# for a query that doesn't require retrieval
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
print("Model prediction: {0}".format(pred.outputs[0].text))
此外,LangChain框架也实现了Self-RAG的应用,其中LangChain框架被用来处理检索增强生成(RAG)的复杂流程,LangGraph则是用于从头构建图工作流的工具。
这种工作流的设计使得各节点可以独立处理不同的任务,并根据上下文动态调整执行顺序,提高了模型生成回答的质量和适应性。每个节点代表一个特定的处理步骤(如检索、生成、评估等),而边则定义了各步骤之间的依赖关系和数据流向。
关键代码:
def retrieve(state):
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
question = state["question"]
# Retrieval
documents = retriever.get_relevant_documents(question)
return {"documents": documents, "question": question}
def generate(state):
"""
Generate answer
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, generation, that contains LLM generation
"""
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG generation
generation = rag_chain.invoke({"context": documents, "question": question})
return {"documents": documents, "question": question, "generation": generation}
def grade_documents(state):
"""
Determines whether the retrieved documents are relevant to the question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates documents key with only filtered relevant documents
"""
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# Score each doc
filtered_docs = []
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
continue
return {"documents": filtered_docs, "question": question}
def transform_query(state):
"""
Transform the query to produce a better question.
Args:
state (dict): The current graph state
Returns:
state (dict): Updates question key with a re-phrased question
"""
print("---TRANSFORM QUERY---")
question = state["question"]
documents = state["documents"]
# Re-write question
better_question = question_rewriter.invoke({"question": question})
return {"documents": documents, "question": better_question}
### Edges
def decide_to_generate(state):
"""
Determines whether to generate an answer, or re-generate a question.
Args:
state (dict): The current graph state
Returns:
str: Binary decision for next node to call
"""
print("---ASSESS GRADED DOCUMENTS---")
state["question"]
filtered_documents = state["documents"]
if not filtered_documents:
# All documents have been filtered check_relevance
# We will re-generate a new query
print(
"---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
)
return "transform_query"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "generate"
def grade_generation_v_documents_and_question(state):
"""
Determines whether the generation is grounded in the document and answers question.
Args:
state (dict): The current graph state
Returns:
str: Decision for next node to call
"""
print("---CHECK HALLUCINATIONS---")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
grade = score.binary_score
# Check hallucination
if grade == "yes":
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
# Check question-answering
print("---GRADE GENERATION vs QUESTION---")
score = answer_grader.invoke({"question": question, "generation": generation})
grade = score.binary_score
if grade == "yes":
print("---DECISION: GENERATION ADDRESSES QUESTION---")
return "useful"
else:
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
return "not useful"
else:
pprint("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
return "not supported"
完整代码参考:https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_self_rag.ipynb