
A RAG demo built with langchain


TL;DR

  • A Semi-structured RAG demo built with langchain, covering the full PDF pipeline: parsing -> embedding -> storage -> retrieval -> answering
  • Taken from an official langchain cookbook; well worth a read

Overall pipeline

[Figure: langchain_rag_demo.png — the overall RAG pipeline]

  1. Parse the PDF. Use the partition_pdf tool to split the PDF into chunks of two kinds: text and tables
  2. Summarize the chunks. Call an LLM API to produce a summary for each chunk; any provider's model will do
  3. Embed the summaries. Call an embedding-model API to vectorize each summary and use the vector as the index key (the value is the original text/table), stored in a Chroma database
  4. Retrieve automatically at question time. During QA, the question embedding retrieves the most similar summaries from Chroma, and the original text/table behind each matched summary is passed to the LLM as the context field of the prompt to produce the answer

Implementation code

1. Parse the PDF

from typing import Any

from pydantic import BaseModel
from unstructured.partition.pdf import partition_pdf

path = "/Users/rlm/Desktop/Papers/LLaMA2/"
# Get elements
raw_pdf_elements = partition_pdf(
    filename=path + "LLaMA2.pdf",
    # Unstructured first finds embedded image blocks
    extract_images_in_pdf=False,
    # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
    # Titles are any sub-section of the document
    infer_table_structure=True,
    # Post-processing to aggregate text once we have the title
    chunking_strategy="by_title",
    # Chunking params to aggregate text blocks:
    # start a new chunk after 3800 chars, hard cap at 4000,
    # and merge chunks smaller than 2000 chars
    max_characters=4000,
    new_after_n_chars=3800,
    combine_text_under_n_chars=2000,
    image_output_dir_path=path,
)
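
The snippet in step 2 below references table_elements and text_elements, which the cookbook derives from raw_pdf_elements by element type; this is also what the Any / BaseModel imports above are for. A sketch following the cookbook's categorization step:

class Element(BaseModel):
    type: str
    text: Any

# Categorize parsed elements by type: tables vs. composite text chunks
categorized_elements = []
for element in raw_pdf_elements:
    if "unstructured.documents.elements.Table" in str(type(element)):
        categorized_elements.append(Element(type="table", text=str(element)))
    elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
        categorized_elements.append(Element(type="text", text=str(element)))

table_elements = [e for e in categorized_elements if e.type == "table"]
text_elements = [e for e in categorized_elements if e.type == "text"]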

2. Summarize the chunks

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Prompt
prompt_text = """You are an assistant tasked with summarizing tables and text. \
Give a concise summary of the table or text. Table or text chunk: {element} """
prompt = ChatPromptTemplate.from_template(prompt_text)

# Summary chain
model = ChatOpenAI(temperature=0, model="gpt-4")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

# Apply to tables
tables = [i.text for i in table_elements]
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})

# Apply to texts
texts = [i.text for i in text_elements]
text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})

In langchain, the | operator chains multiple tools into a single chain, which is very convenient.
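
Under the hood, | builds a RunnableSequence, and a plain dict like {"element": lambda x: x} is coerced into a RunnableParallel. A rough equivalence sketch for the summarize_chain above (not from the cookbook, just to illustrate the desugaring):

from langchain_core.runnables import RunnableLambda, RunnableParallel

# Roughly equivalent to: {"element": lambda x: x} | prompt | model | StrOutputParser()
equivalent_chain = (
    RunnableParallel({"element": RunnableLambda(lambda x: x)})
    .pipe(prompt)
    .pipe(model)
    .pipe(StrOutputParser())
)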

3. Embed the summaries and store them in Chroma

import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

# The vectorstore to use to index the child chunks
vectorstore = Chroma(collection_name="summaries", embedding_function=OpenAIEmbeddings())

# The storage layer for the parent documents
store = InMemoryStore()
id_key = "doc_id"

# The retriever (empty to start)
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [
    Document(page_content=s, metadata={id_key: doc_ids[i]})
    for i, s in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=s, metadata={id_key: table_ids[i]})
    for i, s in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))
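
Before wiring up the full QA chain, the retriever can be queried on its own to confirm what it returns: it matches the question against the summary embeddings but hands back the raw parent text/table stored under the same doc_id. A quick sanity check (the query string is my own example):

# Search hits the summary vectors; the result is the raw chunk (a plain
# string here, since the docstore was populated with the original texts)
docs = retriever.invoke("How many training tokens were used for LLaMA2?")
print(docs[0][:500])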

4. Automatic retrieval at QA time

from langchain_core.runnables import RunnablePassthrough

# Prompt template
template = """Answer the question based only on the following context, which can include text and tables:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# LLM
model = ChatOpenAI(temperature=0, model="gpt-4")

# RAG pipeline
chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

# run the chain
chain.invoke("What is the number of training tokens for LLaMA2?")

Summary

  1. langchain ships a large set of tools and a way to compose them with very little friction
  2. More importantly, langsmith provides a very convenient trace feature: for each QA run you can see exactly which models/tools/actions were involved, along with their inputs, outputs, and latency (a minimal setup sketch follows)
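
Tracing needs no code changes; it is switched on with a few environment variables before running the chain (the key below is a placeholder):

import os

# Enable LangSmith tracing; subsequent chain.invoke() calls are traced
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = "<your-langsmith-api-key>"
os.environ["LANGCHAIN_PROJECT"] = "rag-demo"  # optional project name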