Demystifying Self-RAG: A New Wave in Generation Quality for Large Language Models!


Self-RAG (Self-Reflective Retrieval-Augmented Generation) is a retrieval-augmented generation framework designed to improve the generation quality and accuracy of large language models (LLMs).

Self-RAG introduces "reflection tokens" that let the model decide dynamically, per request, whether to retrieve external information. This cuts unnecessary retrieval calls while improving the accuracy and relevance of the generated content.

Compared with traditional RAG, Self-RAG not only makes information filtering and retrieval strategies more flexible, but also handles long-tail questions and diverse information needs better, yielding more precise and efficient results.
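Concretely, the Self-RAG paper defines four families of reflection tokens. The listing below sketches their surface forms as they appear inline in the model's decoded output (the dictionary itself is only for illustration; the family names follow the paper's Retrieve/IsREL/IsSUP/IsUSE notation):

    # Reflection-token families from the Self-RAG paper and the surface forms
    # the fine-tuned model emits inline while decoding (illustrative listing).
    REFLECTION_TOKENS = {
        "Retrieve": ["[Retrieval]", "[No Retrieval]", "[Continue to Use Evidence]"],
        "IsREL": ["[Relevant]", "[Irrelevant]"],
        "IsSUP": ["[Fully supported]", "[Partially supported]", "[No support / Contradictory]"],
        "IsUSE": ["[Utility:1]", "[Utility:2]", "[Utility:3]", "[Utility:4]", "[Utility:5]"],
    }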

The Self-RAG workflow consists of the following steps (a control-flow sketch follows the list):

  1. Retrieval decision: the model first generates a retrieval token to judge whether information needs to be fetched from an external source. If retrieval is unnecessary, the model answers directly; if it is needed, relevant documents are retrieved.

  2. Content generation: once relevant documents are retrieved, the model generates content grounded in them and uses critique tokens to assess whether the generated answer is accurate and useful.

  3. Self-assessment: Self-RAG then evaluates its own output to ensure quality and factual accuracy. This self-reflective mechanism lets the model keep refining its output throughout generation.
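Taken together, the three steps amount to the control flow sketched below. This is an illustration only: the real model makes these decisions by emitting reflection tokens inline during a single decoding pass, and `predicts_retrieval_needed`, `search`, and `critique` are hypothetical helpers, not an actual API.

    def self_rag_answer(query, model, retriever):
        # Step 1: retrieval decision (the model emits [Retrieval] / [No Retrieval])
        if model.predicts_retrieval_needed(query):  # hypothetical helper
            candidates = []
            for doc in retriever.search(query):  # hypothetical retriever API
                # Step 2: generate an answer grounded in the retrieved passage
                answer = model.generate(query, evidence=doc)
                # Step 3: self-assess relevance, support, and utility via critique tokens
                score = model.critique(query, doc, answer)  # hypothetical helper
                candidates.append((score, answer))
            # Keep the best-scoring candidate
            return max(candidates, key=lambda pair: pair[0])[1]
        # No retrieval needed: answer directly
        return model.generate(query)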


The paper's open-source project, self-rag (https://github.com/AkariAsai/self-rag), implements self-reflective RAG. The trained Self-RAG models can be downloaded from the HuggingFace Hub, and for inference the project recommends vLLM for efficiency. For everything else, readers can consult the project directly.

    from vllm import LLM, SamplingParams

    model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
    sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

    def format_prompt(input, paragraph=None):
        prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
        if paragraph is not None:
            prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
        return prompt

    query_1 = "Leave odd one out: twitter, instagram, whatsapp."
    query_2 = "Can you tell me the difference between llamas and alpacas?"
    queries = [query_1, query_2]

    # for a query that doesn't require retrieval
    preds = model.generate([format_prompt(query) for query in queries], sampling_params)
    for pred in preds:
        print("Model prediction: {0}".format(pred.outputs[0].text))
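For a query that does need factual grounding, the same `format_prompt` helper accepts a `paragraph` argument and wraps the evidence in the `[Retrieval]<paragraph>...</paragraph>` markup the model was trained on. A usage sketch (the passage text is a stand-in, not real retrieved output):

    # For a query that needs factual grounding, pass a retrieved passage via
    # the paragraph argument; it is wrapped in [Retrieval]<paragraph>... markup.
    paragraph = (
        "The alpaca (Lama pacos) is a South American camelid mammal, similar to "
        "and often confused with the llama, but considerably smaller."
    )  # stand-in evidence text
    preds = model.generate([format_prompt(query_2, paragraph=paragraph)], sampling_params)
    print("Model prediction: {0}".format(preds[0].outputs[0].text))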

Beyond the reference implementation, the LangChain ecosystem also implements Self-RAG: LangChain supplies the building blocks of the retrieval-augmented generation (RAG) pipeline, while LangGraph is used to construct the graph workflow from scratch.

This workflow design lets each node handle its own task independently and lets the execution order adapt dynamically to context, improving the quality and adaptability of the generated answers. Each node represents a specific processing step (retrieval, generation, grading, and so on), while edges define the dependencies and data flow between steps; a sketch of wiring the full graph follows the key code below.

Key code
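All the node functions below read and write a shared graph state. In the LangGraph example this is a TypedDict; a minimal sketch consistent with the keys (`question`, `documents`, `generation`) used in the code:

    from typing import List
    from typing_extensions import TypedDict

    class GraphState(TypedDict):
        """Shared state passed between graph nodes."""
        question: str         # current (possibly re-written) user question
        generation: str       # answer produced by the generate node
        documents: List[str]  # retrieved, later filtered, documents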

    def retrieve(state):
        """
        Retrieve documents

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, documents, that contains retrieved documents
        """
        print("---RETRIEVE---")
        question = state["question"]

        # Retrieval
        documents = retriever.get_relevant_documents(question)
        return {"documents": documents, "question": question}

    def generate(state):
        """
        Generate answer

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): New key added to state, generation, that contains LLM generation
        """
        print("---GENERATE---")
        question = state["question"]
        documents = state["documents"]

        # RAG generation
        generation = rag_chain.invoke({"context": documents, "question": question})
        return {"documents": documents, "question": question, "generation": generation}

    def grade_documents(state):
        """
        Determines whether the retrieved documents are relevant to the question.

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Updates documents key with only filtered relevant documents
        """
        print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
        question = state["question"]
        documents = state["documents"]

        # Score each doc
        filtered_docs = []
        for d in documents:
            score = retrieval_grader.invoke(
                {"question": question, "document": d.page_content}
            )
            grade = score.binary_score
            if grade == "yes":
                print("---GRADE: DOCUMENT RELEVANT---")
                filtered_docs.append(d)
            else:
                print("---GRADE: DOCUMENT NOT RELEVANT---")
                continue
        return {"documents": filtered_docs, "question": question}

    def transform_query(state):
        """
        Transform the query to produce a better question.

        Args:
            state (dict): The current graph state

        Returns:
            state (dict): Updates question key with a re-phrased question
        """
        print("---TRANSFORM QUERY---")
        question = state["question"]
        documents = state["documents"]

        # Re-write question
        better_question = question_rewriter.invoke({"question": question})
        return {"documents": documents, "question": better_question}

    ### Edges

    def decide_to_generate(state):
        """
        Determines whether to generate an answer, or re-generate a question.

        Args:
            state (dict): The current graph state

        Returns:
            str: Binary decision for next node to call
        """
        print("---ASSESS GRADED DOCUMENTS---")
        filtered_documents = state["documents"]

        if not filtered_documents:
            # All documents were filtered out as not relevant,
            # so we re-generate a new query
            print(
                "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
            )
            return "transform_query"
        else:
            # We have relevant documents, so generate answer
            print("---DECISION: GENERATE---")
            return "generate"

    def grade_generation_v_documents_and_question(state):
        """
        Determines whether the generation is grounded in the document and answers question.

        Args:
            state (dict): The current graph state

        Returns:
            str: Decision for next node to call
        """
        print("---CHECK HALLUCINATIONS---")
        question = state["question"]
        documents = state["documents"]
        generation = state["generation"]

        score = hallucination_grader.invoke(
            {"documents": documents, "generation": generation}
        )
        grade = score.binary_score

        # Check hallucination
        if grade == "yes":
            print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
            # Check question-answering
            print("---GRADE GENERATION vs QUESTION---")
            score = answer_grader.invoke({"question": question, "generation": generation})
            grade = score.binary_score
            if grade == "yes":
                print("---DECISION: GENERATION ADDRESSES QUESTION---")
                return "useful"
            else:
                print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
                return "not useful"
        else:
            print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
            return "not supported"
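Building on the nodes and edge functions above, the example wires everything into a StateGraph: retrieval feeds document grading, grading routes either to generation or to query rewriting, and each generation is re-checked until it is both grounded and useful. A wiring sketch (the routing maps follow the return values of the functions above; minor details may differ from the notebook):

    from langgraph.graph import END, START, StateGraph

    workflow = StateGraph(GraphState)

    # Nodes
    workflow.add_node("retrieve", retrieve)
    workflow.add_node("grade_documents", grade_documents)
    workflow.add_node("generate", generate)
    workflow.add_node("transform_query", transform_query)

    # Edges
    workflow.add_edge(START, "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {"transform_query": "transform_query", "generate": "generate"},
    )
    workflow.add_edge("transform_query", "retrieve")
    workflow.add_conditional_edges(
        "generate",
        grade_generation_v_documents_and_question,
        {"not supported": "generate", "useful": END, "not useful": "transform_query"},
    )

    app = workflow.compile()

    # Example invocation with an illustrative question
    result = app.invoke({"question": "What are the types of agent memory?"})
    print(result["generation"])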

Full code: https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_self_rag.ipynb
