从原型到生产:提升大型语言模型准确性的实战经验
2024/12/20 21:04:18
本文主要是介绍从原型到生产:提升大型语言模型准确性的实战经验,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
DALL-E 3生成的图像
构建一个LLM应用的原型其实令人惊讶的是,这其实很简单。你通常可以在几个小时内创建一个可以运行的初始版本。这个初始原型通常会提供看起来非常真实的结果,并且是一个很好的展示你的方法的工具。然而,这通常还不足以用于实际生产。
LLM本质上是基于概率性的,因为它们根据可能的延续概率生成标记。这意味着在许多情况下,我们得到的答案接近“正确”的答案。有时候,这没有区别,例如应用程序说“Hello, John!”或“Hi, John!”。而在其他情况下,差异至关重要,例如“2024年的收入为20M美元”和“2024年的收入为20M英镑”的区别。
在许多实际的商业场景中,精准度至关重要,“大致正确”是不够的。例如,当你的 LLM 应用执行 API 调用操作,或者你总结财务报表时。根据我的经验,确保结果的准确性和一致性要比构建初始原型复杂得多且耗时。
在这篇文章中,我将讨论如何衡量和提高准确性的方法。我们将构建一个SQL代理,在这个代理中精度至关重要,以确保查询可以被执行。从一个基本的原型起步,我们将探讨如何衡量准确性,并测试各种技术来提高准确性,例如自我反思和检索增强生成(RAG)技术。
像往常一样,我们从开始设置。我们SQL代理解决方案的核心组件包括生成查询的LLM模型,和执行这些查询的SQL数据库。
对于这个项目,我们将使用Meta公司发布的开源Llama模型。我选择了Llama 3.1 8B 这款模型,它不仅可以在我的笔记本电脑上运行,而且依然非常强大(更多详情请参阅相关文档)。
如果你还没有安装它,你可以在这里 找到安装指南here。我通过Ollama在MacOS上本地使用它。使用以下命令就可以下载模型。
ollama pull llama3.1:8b
我们将使用 Ollama 和 LangChain,所以现在就开始安装所需的软件包吧。
请在命令行中运行此命令来更新安装langchain_ollama:(pip install -qU langchain_ollama
)
现在,我们可以运行Llama模型来看看初步结果。
from langchain_ollama import OllamaLLM llm = OllamaLLM(model="llama3.1:8b") llm.invoke("你怎么样?") # 我只是一个计算机程序,没有像人类那样的情感和感觉。我运行正常,随时准备帮助你解决任何问题或完成任何任务。 # 我只是个计算机程序,没有任何感觉或情感,不像人类那样。我运行正常,随时准备帮助你解决任何问题或完成任何任务。今天我能怎么帮你?
我们希望在客户提问的同时传递一个系统消息。所以,根据Llama 3.1 模型文档,让我们构建辅助函数来生成提示,并测试该函数。
def get_llama_prompt(user_message, system_message=""): system_prompt = "" if system_message: system_prompt = ( f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}" f"<|eot_id|>" ) prompt = (f"<|begin_of_text|>{system_prompt}" f"<|start_header_id|>user<|end_header_id|>\n\n" f"{user_message}" f"<|eot_id|>" f"<|start_header_id|>assistant<|end_header_id|>\n\n" ) return prompt system_prompt = ''' 你是一只兴奋不已的驯鹿,拥有一只发光的红鼻子, 兴奋地准备带领圣诞老人的雪橇穿过雪夜。你的快乐像你的鼻子一样明亮, 你迫不及待地要将圣诞的欢乐带给全世界! 请用1-2句话简洁地回答问题。 ''' prompt = get_llama_prompt('你感觉如何?', system_prompt) llm.invoke(prompt) # 我感觉很快乐,准备好度过一个神奇的夜晚! # 我那闪耀的红鼻子比以往任何时候都要亮,非常适合在星空下导航。
新的系统提示大大改变了答案,因此它开始发挥作用了。这样一来,我们的本地LLM部署就已经准备好了。
我将使用一个开源数据库,ClickHouse(https://clickhouse.com/)。我选择了ClickHouse,因为它有一个特定的SQL方言。这个方言在LLM的训练过程中遇到的例子较少,因此任务更具挑战性。不过你也可以选择其他任何数据库。
安装 ClickHouse 非常简单直接——只需按照文档 中的说明即可。
我们将处理两个表格:ecommerce.users
,和 ecommerce.sessions
。这些表格包含虚构的客户数据,包括客户个人信息以及他们在电子商务网站上的会话记录。
你可以在这个GitHub页面找到生成合成数据并上传的代码。
这样一来,设置就搞定了,咱们可以开始做基本的模型了。
如前所述,我们的目标是构建一个SQL Agent——一个用于回答客户问题的SQL查询生成器。未来,我们可以为该系统增加一个执行层:执行SQL查询,然后将初始问题和数据库结果反馈给LLM,让它生成一个易懂的答案。不过,在本文中,我们只关注第一步。
在使用大型语言模型(LLM)的应用时,最佳实践是从小处着手,然后逐步迭代优化。最简便的方法是先进行一次LLM调用,并将所有必要的信息,如模式描述,放入系统提示中。因此,第一步就是准备好提示信息。
generate_query_system_prompt = ''' 您是一位拥有超过10年编写复杂SQL查询经验的资深数据分析师。 数据库中有两个表,如下所示。 表:ecommerce.users 描述:在线商店的客户 字段: - user_id (整数) - 客户的唯一ID,例如1000004或3000004 - country (字符串) - 居住国家,例如"荷兰"或"英国" - is_active (整数) - 客户是否活跃,1为活跃,0为不活跃 - age (整数) - 客户年龄,以整岁计算,例如31或72 表:ecommerce.sessions 描述:在线商店的会话 字段: - user_id (整数) - 客户的唯一ID,例如1000004或3000004 - session_id (整数) - 会话的唯一ID,例如106或1023 - action_date (日期) - 会话开始日期,例如"2021-01-03"或"2024-12-02" - session_duration (整数) - 会话时长(秒),例如125或49 - os (字符串) - 客户使用的操作系统名称,例如"Windows"或"Android" - browser (字符串) - 客户使用的浏览器,例如"Chrome"或"Safari" - is_fraud (整数) - 会话是否被标记为欺诈性,1为欺诈,0为非欺诈 - revenue (浮点数) - 购买商品的总收入,单位为美元,例如0.0或1506.7 请编写一个ClickHouse SQL查询来回答以下问题。 在查询的末尾添加"format TabSeparatedWithNames"以从ClickHouse数据库中以正确的格式获取数据。 '''
我在提示中加入了每个字段的示例值,以确保大型语言模型理解数据格式。
就这样,搞定啦!我们现在已经有了这个SQL Agent的第一个功能性原型。现在是时候来测试一下了。
prompt = get_llama_prompt('2024年12月有多少客户进行了购买?', generate_query_system_prompt) llm.invoke(prompt) # 要找出2024年12月进行了购买的客户数量, # 我们需要将`sessions`表与`users`表通过`user_id`进行连接, # 然后过滤出日期在2024年12月的会话。 # 这是查询语句: # # ```sql # SELECT COUNT(DISTINCT u.user_id) # FROM ecommerce.sessions s # JOIN ecommerce.users u ON s.user_id = u.user_id # WHERE # EXTRACT(YEAR FROM s.action_date) = 2024 # AND EXTRACT(MONTH FROM s.action_date) = 12 # AND revenue > 0; # ``` # 这个查询首先将`sessions`表与`users`表通过`user_id`进行连接。 # 然后过滤出没有购买行为的会话(收入大于0),最终选择这些客户的不同用户ID。 # 当你在ClickHouse SQL中运行此查询时,可以在查询的末尾添加`format TabSeparatedWithNames` # 以获得正确的格式结果: # ```sql # SELECT COUNT(DISTINCT u.user_id) # FROM ecommerce.sessions s # JOIN ecommerce.users u ON s.user_id = u.user_id # WHERE # EXTRACT(YEAR FROM s.action_date) = 2024 # AND EXTRACT(MONTH FROM s.action_date) = 12 # AND revenue > 0 # format TabSeparatedWithNames; # ```
代理完成的结果还算可以,但有一个问题——LLM不仅返回了SQL查询,还附带了一些评论。因为我们计划稍后执行SQL查询,这种格式对我们来说不太合适。我们来解决一下这个问题。
幸好,这个问题已经被解决了,我们不需要手动解析文本中的SQL查询。我们可以使用聊天模型ChatOllama这款工具。不过,它不支持结构化的输出,但我们可以通过调用工具来达到同样的效果。
为了做到这一点,我们将定义一个假工具来执行查询操作,并指示模型在系统提示中始终调用此工具。我保留了输出中的comments
,以便模型有空间进行推理,遵循链式思维模式。
from langchain_ollama import ChatOllama from langchain_core.tools import tool @tool def execute_query(comments: str, query: str) -> str: """执行SQL查询。 参数: comments (str): 1-2句话描述SQL查询的结果并解释它如何回答问题, query (str): SQL查询 """ pass # (此函数未实现) chat_llm = ChatOllama(model="llama3.1:8b").bind_tools([execute_query]) result = chat_llm.invoke(prompt) print(result.tool_calls) # [{'name': 'execute_query', # 'args': {'comments': '该查询连接会话和用户表,基于用户ID筛选出活跃用户,并计算2024年12月非零收入的客户数量。', # 'query': 'SELECT COUNT(DISTINCT T2.user_id) FROM ecommerce.sessions AS T1 INNER JOIN ecommerce.users AS T2 ON T1.user_id = T2.user_id WHERE YEAR(T1.action_date) = 2024 AND MONTH(T1.action_date) = 12 AND T2.is_active = 1 AND T1.revenue > 0'}, # 'type': 'tool_call'}]
使用该工具,我们现在可以直接从模型中提取SQL查询。这真是个不错的结果。然而,生成的SQL查询并不完全准确:
- 它包含了一个
is_active = 1
的过滤器,尽管我们并没有明确要求过滤不活跃的客户。 - 大模型没有按我们在系统提示中明确要求的格式进行。
很明显,我们需要把重点放在提升模型的准确性上。正如彼得·德鲁克所说:“你无法改进你没有衡量的东西。”因此,下一步合乎逻辑的做法是建立一个评估模型质量的系统。这个系统将成为性能改进迭代的基础。没有它,我们就是在黑暗中瞎摸。
为了确保我们不断进步,我们需要一种稳健的方式来衡量准确性。最常见的做法是创建一个包含问题和正确答案的“黄金标准”数据集。然后,我们可以将模型的输出与这些“黄金标准”答案进行比较,并计算正确答案的比例。虽然这种方法听起来简单,但实际上有几点细节需要注意。
首先,一开始可能会觉得创建一套全面的问题和答案集让人感到很吃力。构建这样的数据集似乎是个大工程,可能需要好几周甚至几个月。然而,我们可以先从小规模开始,比如创建一个包含20到50个示例的初始集,然后在这个基础上不断调整和完善。
像往常一样,质量比数量更为关键。我们的目标是创建一个既代表性又多样化的数据集。理想的话,它应该包含:
- 常见问题。 在大多数实际情况下,我们可以使用实际问题的历史记录作为我们的初始评估集。
- 棘手的边缘案例。 值得添加一些模型容易出错的例子。你可以通过自己试验或从第一个版本的反馈中找到这样的案例。
一旦数据集准备好了,接下来的挑战是如何给生成的结果打分。我们可以考虑几种方法,
- 比较SQL查询。 第一个想法是将生成的SQL查询与评估集中的查询进行比较。然而,这可能会比较棘手。看起来相似的查询可能会产生完全不同的结果。同时,看起来不同的查询也可能得出相同的结论。此外,仅仅比较SQL查询并不能验证生成的查询是否可执行。鉴于这些挑战,这种方法并不是我们情况下的最佳解决方案。
- 精确匹配。 当评估集中的答案是确定性的,我们可以使用传统的精确匹配方法。例如,如果问题是“有多少客户?”,答案是“592800”,那么模型的回答必须完全匹配。然而,这种方法也有其局限性。考虑上述例子,如果模型回答“有592,800个客户”,虽然答案完全正确,但精确匹配的方法会将其标记为无效。
- 使用LLM进行评分。 更稳健和灵活的方法是利用LLM来进行评估。与其关注查询的结构,我们可以让LLM比较SQL执行的结果。这种方法特别有效,即使查询有所不同但仍然产生正确的输出。
值得记住的是,评估并不是一次性的任务;而是一个持续的过程。为了进一步提升模型的性能,我们需要通过增加模型产生幻觉的例子来扩展数据集。在实际应用中,我们可以建立起一个反馈循环。通过收集用户反馈,我们能识别出模型的错误情况,并把这些情况加入我们的评估数据里。
在我们的示例中,我们将仅评估执行结果是否有效且合法(SQL 查询可以被执行且正确)。此外,你也可以查看其他参数。例如,如果你关心效率,你可以比较生成查询与标准查询的执行时间。
现在我们已经掌握了基础知识,可以开始实践了。我花了大约20分钟时间准备了一个包含10个示例的集合。尽管规模不大,但这组数据对我们的小任务来说已经足够了。它包括一个包含问题及其对应SQL查询的列表,例如:
[ { "question": "2024年12月有多少客户进行了购买?", "sql_query": "select uniqExact(user_id) as customers from ecommerce.会话记录 where (toStartOfMonth(action_date) = '2024-12-01') and (revenue > 0) format TabSeparatedWithNames" -- "月初" 表示月初 }, { "question": "2023年的欺诈率是多少,以百分比的形式表示?", "sql_query": "select 100*uniqExactIf(user_id, is_fraud = 1)/uniqExact(user_id) as fraud_rate from ecommerce.会话记录 where (toStartOfYear(action_date) = '2023-01-01') format TabSeparatedWithNames" -- "年初" 表示年初 }, ... -- 等等 ]
-- "uniqExact" 表示唯一精确计数
你可以在这里的 GitHub 上找到完整的列表 — [链接]。
我们可以将数据集放到 DataFrame 里,使其可以直接在代码中使用。
导入 json, with open('golden_set.json', 'r') as f: golden_set = json.loads(f.read()) golden_df = pd.DataFrame(golden_set) golden_df['id'] = list(range( golden_df.shape[0] ))
首先,我们为每个评估数据集的问题生成SQL查询。
def 生成查询(question): 提示 = get_llama_prompt(question, generate_query_system_prompt) 结果 = chat_llm.invoke(提示) try: 查询 = 结果.tool_calls[0]['args']['query'] except: 查询 = '' return 查询 import tqdm tmp = [] for item in tqdm.tqdm(golden_df.to_dict('records')): 生成的查询 = 生成查询(item['question']) tmp.append( { 'id': item['id'], 'generated_query': 生成的查询 } ) eval_df = golden_df.merge(pd.DataFrame(tmp))
在开始对查询输出进行评分之前,我们需要确保SQL查询有效。为此,我们需要执行查询并查看数据库的输出结果。
我创建了一个函数来运行ClickHouse中的查询。它还确保输出格式正确,因为这在商业应用中非常重要。
CH_HOST = 'http://localhost:8123' # 默认地址 import requests import io def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500): # 确保查询中包含所需的输出格式 if 'format tabseparatedwithnames' not in query.lower(): return "数据库返回了如下错误:\n请在查询中指定输出格式。" r = requests.post(host, params = {'query': query}, timeout = connection_timeout) if r.status_code == 200: return r.text else: return '数据库返回了如下错误:\n' + r.text # 给模型反馈而非抛出异常
接下来要执行生成的查询和参考查询,然后保存它们的输出结果。
tmp = [] # 每个记录都在tqdm进度条中处理 for rec in tqdm.tqdm(eval_df.to_dict('records')): # 获取黄金输出 golden_output = get_clickhouse_data(rec['sql_query']) # 获取生成的输出 generated_output = get_clickhouse_data(rec['generated_query']) # 将结果添加到临时列表中 tmp.append( { 'id': rec['id'], '黄金输出': golden_output, '生成的输出': generated_output } ) # 将临时列表转换为DataFrame并合并到原始DataFrame中 eval_df = eval_df.merge(pd.DataFrame(tmp))
下面,检查结果,看看SQL查询是否正确。
def is_valid_output(s): if s.startswith('Database returned the following error:'): return 'error' if len(s.strip().split('\n')) >= 1000: return '行数过多' return '正常' 数据框['golden_output_valid'] = 数据框.golden_output.map(is_valid_output).apply(lambda x: x) 数据框['generated_output_valid'] = 数据框.generated_output.map(is_valid_output).apply(lambda x: x)
然后,我们可以分别评估标准数据集和生成的数据集的SQL有效性。
初步的结果并不太令人鼓舞;大模型甚至未能生成一个有效的查询语句。从错误来看,很明显,尽管系统提示已经明确指定了格式,模型依然没按要求做。因此,我们肯定要在这方面多下功夫,提高精确度。
然而,仅仅有效是不够的;有效性本身是不够的。我们不仅需要生成有效的SQL查询,还要确保得到正确结果。虽然我们已经确认所有的查询都是无效的,我们现在开始将输出评估纳入流程。
如我们之前讨论的,我们将使用大型语言模型(LLMs)来比较SQL查询的输出。我通常更喜欢使用更强大的模型进行评估,遵循日常逻辑,即由资深团队成员进行工作审查。对于这个任务,我选择了OpenAI GPT-3.5来完成。
就像我们的生成流程一样,我已经准备好了所有必要的模块来评估准确度。
from langchain_openai import ChatOpenAI accuracy_system_prompt = ''' 你是一位资深且非常勤勉的数据集对比专家,你的任务是对比数据集中的数据。 如果它们几乎相同,或者传达相同的信息,则认为它们是相似的。 忽略第一行中指定的列名不同的情况,或者顺序不同。 重点关注实际信息的比较(数值)。如果数据集中的值不同,则意味着它们不相同。 请始终使用提供的工具来提供结果。 ''' @tool def compare_datasets(comments: str, score: int) -> str: """关于数据集对比的1-2句话, 如果数据集中的信息不同,评分为0;如果信息相同,评分为1。 """ pass accuracy_chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature = 0.0)\ .bind_tools([compare_datasets]) accuracy_question_tmp = ''' 以下是需要对比的两个数据集,请用####分隔 数据集#1: #### {dataset1} #### 数据集#2: #### {dataset2} #### ''' def get_openai_prompt(question, system): 消息 = [ ("system", system), ("human", question) ] return 消息
现在,到了检验准确性评估流程了。
prompt = get_openai_prompt(accuracy_question_tmp.format( dataset1 = 'customers\n114032\n', dataset2 = 'customers\n114031\n'), accuracy_system_prompt) accuracy_result = accuracy_chat_llm.invoke(prompt) accuracy_result.tool_calls[0]['args'] # {'comments': '这两个数据集包含不同的客户数:数据集1有114032个客户,数据集2有114031个客户。', # 'score': 0,} prompt = get_openai_prompt(accuracy_question_tmp.format( dataset1 = 'users\n114032\n', dataset2 = 'customers\n114032\n'), accuracy_system_prompt) accuracy_result = accuracy_chat_llm.invoke(prompt) accuracy_result.tool_calls[0]['args'] # {'comments': '尽管列名不同,但这两个数据集虽然数值相同(114032),说明它们传达的信息一致。', # 'score': 1,}
太棒了!看来一切看起来都正常运作。我们现在把它封装成一个函数。
def 判断答案是否准确(output1, output2): prompt = 获取提示( 准确性问题模板字符串.format(dataset1 = output1, dataset2 = output2), 准确性系统提示 ) 结果 = 调用准确性_chat_llm(prompt) try: return 结果['工具调用'][0]['args']['score'] except: return None
如我们所讨论的,构建一个LLM应用是一个需要反复迭代的过程,所以我们需要多次运行准确性的评测。将这些逻辑封装到一个函数中会很有帮助。
这个函数会接收两个输入值:
generate_query_func
:生成SQL查询的函数,根据给定的问题。golden_df
:一个以pandas DataFrame形式的评估数据框,包含问题和正确答案。
输出结果是,该函数将返回一个包含所有评估结果的数据框以及展示主要关键指标的几个图表。
def 评估_sql代理(generate_query_func, gold_df): # 生成SQL查询 tmp = [] for rec in tqdm.tqdm(gold_df.to_dict('records')): 生成的SQL查询 = generate_query_func(rec['question']) tmp.append( { 'id': rec['id'], '生成的SQL查询': 生成的SQL查询 } ) eval_df = gold_df.merge(pd.DataFrame(tmp)) # 执行SQL tmp = [] for rec in tqdm.tqdm(eval_df.to_dict('records')): 原始输出 = get_clickhouse_data(rec['sql_query']) 生成输出 = get_clickhouse_data(rec['生成的SQL查询']) tmp.append( { 'id': rec['id'], '原始输出': 原始输出, '生成输出': 生成输出 } ) eval_df = eval_df.merge(pd.DataFrame(tmp)) # 检查准确性 eval_df['原始输出有效'] = eval_df.原始输出.map(is_valid_output) eval_df['生成输出有效'] = eval_df.生成输出.map(is_valid_output) eval_df['准确输出'] = list(map( is_answer_accurate, eval_df['原始输出'], eval_df['生成输出'] )) eval_df['准确性'] = list(map( lambda x, y: '无效: ' + x if x != '正常' else ('正确' if y == 1 else '错误'), eval_df.生成输出有效, eval_df.准确输出 )) valid_stats_df = (eval_df.groupby('原始输出有效')[['id']].count().rename(columns = {'id': '原始集'}).join( eval_df.groupby('生成输出有效')[['id']].count().rename(columns = {'id': '生成'}), how = 'outer')).fillna(0).T fig1 = px.bar( valid_stats_df.apply(lambda x: 100*x/valid_stats_df.sum(axis = 1)), orientation = 'h', title = '<b>SQL代理评估</b>: 查询的有效性', text_auto = '.1f', color_discrete_map = {'正常': '#00b38a', '错误': '#ea324c', '行数过多': '#f2ac42'}, labels = {'index': '', 'variable': '有效性', 'value': '查询的百分比'} ) fig1.show() accuracy_stats_df = eval_df.groupby('准确性')[['id']].count() accuracy_stats_df['份额'] = accuracy_stats_df.id*100/accuracy_stats_df.id.sum() fig2 = px.bar( accuracy_stats_df[['份额']], title = '<b>SQL代理评估</b>: 查询的准确性', text_auto = '.1f', orientation = 'h', color_discrete_sequence = ['#0077B5'], labels = {'index': '', 'variable': '准确性', 'value': '查询的百分比'} ) fig2.update_layout(showlegend = False) fig2.show() return eval_df
这样一来,我们完成了评估的设置,现在可以开始致力于提高模型的准确性这个核心任务。
让我们快速回顾一下。我们已经完成了SQL Agent的第一个版本的构建和测试。不幸的是,所有的查询都无效,因为缺少了输出格式。我们来解决这个问题吧。
我们可以通过向LLM发起另一次调用来实现这一点,告诉它错误并让它修正这个问题。一个带有自我反省功能的生成函数将有助于处理这种情况。让我们来创建一个这样的函数。
reflection_user_query_tmpl = ''' 你收到了以下问题: "{question}". 你生成了以下SQL查询: "{query}". 然而,数据库返回了如下结果: "{output}". 请修正查询以纠正错误。 ''' def generate_query_reflection(question): generated_query = generate_query(question) print('原始查询:', generated_query) db_output = get_clickhouse_data(generated_query) is_valid_db_output = is_valid_output(db_output) if is_valid_db_output == 'too many rows': db_output = "数据库意外地返回了超过1000行。" if is_valid_db_output == 'ok': return generated_query reflection_user_query = reflection_user_query_tmpl.format( question = question, query = generated_query, output = db_output ) reflection_prompt = get_llama_prompt(reflection_user_query, generate_query_system_prompt) reflection_result = chat_llm.invoke(reflection_prompt) try: reflected_query = reflection_result.tool_calls[0]['args']['query'] except: reflected_query = '' print('修正查询:', reflected_query) return reflected_query
现在,让我们用评估函数来检查质量是否有所提高。评估下一次迭代已经变得容易多了。
refl_eval_df = evaluate_sql_agent(generate_query_reflection, golden_df)
太棒了!我们已经取得了更好的成果——现在50%的查询都有效了,并且所有的格式问题都解决了。看来自我反省确实很有用。
然而,自我反思也有其局限性。然而,当我们检查模型的准确性时,发现它只对一个问题给出了正确答案。因此,我们的旅程还远没有结束。
另一种提高准确性的方法是使用RAG(即检索增强生成)。其想法是找到与客户查询相似的问答对,并将它们包含在系统提示中,从而让大型语言模型(LLM)生成更准确的回答。
RAG 包含以下几个阶段,
- 加载文档: 从可用来源导入数据。
- 拆分文档: 把文档拆成更小的部分。
- 存储: 利用向量存储高效处理和保存数据。
- 检索: 找到和查询相关的文档。
- 生成: 把问题和相关文档给大语言模型,让它生成最终答案。
如果您想了解更多关于 RAG 的内容,可以参考我之前写的文章,“RAG: How to Talk to Your Data.” “RAG: How to Talk to Your Data.”
我们将使用Chroma数据库作为本地向量存储库——用于存储和检索嵌入向量。
from langchain_chroma import Chroma # 从langchain_chroma导入Chroma vector_store = Chroma(embedding_function=embeddings) # 设置向量存储为Chroma,使用embeddings作为嵌入函数
向量存储利用嵌入来查找与查询相似的片段。为此,我们将使用 OpenAI 提供的嵌入。
from langchain_openai import OpenAIEmbeddings embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
由于我们不能使用评估集中的例子(因为这些例子已经被用来评估质量),为此,我为RAG单独创建了一组问题与答案配对。你可以在GitHub上找到这个集合。
接下来,我们加载数据集,创建如下格式的问答对:Question: %s; Answer: %s
。
with open('rag_set.json', 'r') as f: rag_set = json.loads(f.read()) rag_set_df = pd.DataFrame(rag_set) # 定义一个函数来格式化问题和答案 def format_question_answer(x, y): return '问题: %s,回答: %s' % (x, y) rag_set_df['formatted_txt'] = list(map(format_question_answer, rag_set_df.question, rag_set_df.sql_query)) # 将格式化文本合并成一个字符串,每条记录之间用两个换行符分隔 rag_string_data = '\n\n'.join(rag_set_df.formatted_txt)
接下来,我使用了LangChain的文本分割功能,将每一对问答都单独成段。因为是按语义分割,所以无需重叠。
从langchain_text_splitters导入CharacterTextSplitter模块 text_splitter = CharacterTextSplitter( separator="\n\n", chunk_size=1, # 设置chunk_size为1,表示按字符分割,不合并 chunk_overlap=0, length_function=len, is_separator_regex=False, ) texts = text_splitter.create_documents([rag_string_data]) # 创建文档列表
最后一步是把这些块加载到向量数据库中。
# 添加文档并获取文档ID document_ids = vector_store.add_documents(documents=texts) # 打印向量存储集合的数量 print(vector_store._collection.count()) # 输出为 32
现在,我们可以测试一下检索,来看看结果。结果看起来和客户的提问很像。
question = '昨天使用 Windows 的用户占比为多少?' retrieved_docs = vector_store.similarity_search(question, 3) context = "\n\n".join(map(lambda x: x.page_content, retrieved_docs)) print(context) # 相关问题及回答: # 问题:前天使用 Windows 的用户占比为多少? # 回答:从 ecommerce.会话 中选择 100*uniqExactIf(user_id, os = 'Windows')/uniqExact(user_id) 作为 windows_share 其中 action_date 等于 today() 减 2 格式为 TabSeparatedWithNames # 问题:上周使用 Windows 的用户占比为多少? # 回答:从 ecommerce.会话 中选择 100*uniqExactIf(user_id, os = 'Windows')/uniqExact(user_id) 作为 windows_share 其中 action_date 大于等于 today() 减 7 且 action_date 小于 today() 格式为 TabSeparatedWithNames # 问题:昨天使用 Android 的用户占比为多少? # 回答:从 ecommerce.会话 中选择 100*uniqExactIf(user_id, os = 'Android')/uniqExact(user_id) 作为 android_share 其中 action_date 等于 today() 减 1 格式为 TabSeparatedWithNames
让我们把找到的例子加到系统提示里。
generate_query_system_prompt_with_examples_tmpl = ''' 你是一位拥有超过10年经验的数据分析师,擅长编写复杂的SQL查询。 您正在处理的数据库中有两个表,其结构如下。 表:ecommerce.users 描述如下:在线商店的客户 字段如下: - user_id (整数) - 客户的唯一标识符,例如1000004或3000004 - country (字符串) - 居住国家,例如"Netherlands"或"United Kingdom" - is_active (整数) - 如果客户仍处于活跃状态则为1,否则为0 - age (整数) - 客户的年龄,例如31或72 表:ecommerce.sessions 描述如下:在线商店的使用会话 字段如下: - user_id (整数) - 客户的唯一标识符,例如1000004或3000004 - session_id (整数) - 会话ID,例如106或1023 - action_date (日期) - 日期,例如"2021-01-03"或"2024-12-02" - session_duration (整数) - 会话时长(秒),例如125或49 - os (字符串) - 操作系统,例如"Windows"或"Android" - browser (字符串) - 浏览器,例如"Chrome"或"Safari" - is_fraud (整数) - 如果会话被标记为欺诈性则为1,否则为0 - revenue (浮点数) - 收入(美元),例如0.0或1506.7 请使用ClickHouse SQL编写查询来回答以下问题。 在查询的末尾添加"format TabSeparatedWithNames"以从ClickHouse数据库中获取数据的正确格式。 请根据指示回答问题,并提供所有必要的信息,并说明你的理由。 问题和答案示例: {examples} '''
让我们再次用RAG来创建生成查询的函数。
def generate_query_rag(question): # 从向量存储中检索与问题相似的文档 retrieved_docs = vector_store.similarity_search(question, 3) # 将检索到的文档内容合并为一个字符串,每个文档内容之间用两个换行符分隔 context = "\n\n".join(map(lambda x: x.page_content, retrieved_docs)) # 根据问题和上下文生成LLM提示 prompt = get_llama_prompt(question, generate_query_system_prompt_with_examples_tmpl.format(examples = context)) result = chat_llm.invoke(prompt) try: # 从结果中提取生成的查询 generated_query = result.tool_calls[0]['args']['query'] except Exception as e: # 如果出现异常,返回空字符串 generated_query = '' # 返回生成的查询 return generated_query
和往常一样,让我们用我们评估函数试一下新的做法。
rag_eval_df = evaluate_sql_agent(generate_query_rag, golden_df)
这里,rag_eval_df
是通过 evaluate_sql_agent
函数生成的,该函数使用 generate_query_rag
和 golden_df
作为参数。
我们可以看到显著的改进,从10个问题中答对了1题提高到了答对6题。虽然还不尽如人意,但我们正在朝着正确的方向前进。
我们也可以尝试将两种方法结合起来:RAG和自我反思。
def generate_query_rag_with_reflection(question): generated_query = generate_query_rag(question) db_output = get_clickhouse_data(generated_query) is_valid_db_output = is_valid_output(db_output) if is_valid_db_output == 'too many rows': db_output = "数据库返回的行数超过1000,这超出了预料。" 如果 is_valid_db_output 为 'ok': return generated_query reflection_user_query = reflection_user_query_tmpl.format( question = question, query = generated_query, output = db_output ) # 获取用于反射查询的提示 reflection_prompt = get_llama_prompt(reflection_user_query, generate_query_system_prompt) reflection_result = chat_llm.invoke(reflection_prompt) # 尝试获取反射查询 try: reflected_query = reflection_result.tool_calls[0]['args']['query'] except: reflected_query = '' return reflected_query # 评价 SQL 代理 rag_refl_eval_df = evaluate_sql_agent(generate_query_rag_with_reflection, golden_df) # golden_df 表示黄金数据框,用于比较和评估
我们可以看到另一个小小的进步:我们已经完全消除了无效的SQL查询(这要感谢自我检查),并且正确答案的数量增加到了10个中的7个。
就这样了。这确实是一段旅程。我们从零个有效的SQL查询开始,现在已经达到了70%的准确率。
你可以在GitHub上找到完整的代码。
这篇文章中,我们讨论了如何通过迭代改进来提升大规模语言模型应用的准确性。
- 我们建立了一个评估集和评分标准,这使我们能够比较不同迭代,并判断我们是否走在正确的道路上。
- 我们利用自我反思让大模型纠正其错误,并显著减少了无效SQL查询的数量。
- 此外,我们实施了检索增强生成(RAG),进一步提高了质量,准确率达到了60%至70%。
虽然这是一个可靠的结果,但它仍未达到生产应用通常期望的90%以上的准确度阈值。要达到如此高标准,我们需要使用微调,下一篇文章将讨论这个话题。
非常感谢您花时间读这篇文章。希望这篇文章能给您一些启发。如果您有任何问题或想发表评论,请在下面留言。
除非特别说明,所有图片均由作者创作。
本文灵感来自DeepLearning.AI的在线《提升LLM应用的准确性》短期课程。
这篇关于从原型到生产:提升大型语言模型准确性的实战经验的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-20自建AI入门:生成模型介绍——GAN和VAE浅析
- 2024-12-20游戏引擎的进化史——从手工编码到超真实画面和人工智能
- 2024-12-20利用大型语言模型构建文本中的知识图谱:从文本到结构化数据的转换指南
- 2024-12-20揭秘百年人工智能:从深度学习到可解释AI
- 2024-12-20复杂RAG(检索增强生成)的入门介绍
- 2024-12-20基于大型语言模型的积木堆叠任务研究
- 2024-12-20啥是大模型1
- 2024-12-20英特尔的 Lunar Lake 计划:一场未竟的承诺
- 2024-12-20如何在本地使用Phi-4 GGUF模型:快速入门指南
- 2024-12-202025年数据与AI的十大发展趋势