From e80f0ec0587c644b85dfc495be9afb583e7e974b Mon Sep 17 00:00:00 2001 From: Ethan-Zhang Date: Thu, 6 Nov 2025 15:49:37 +0800 Subject: [PATCH] =?UTF-8?q?Fix:=20chat=E6=96=B0=E5=A2=9Erecord=5Fid?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6&=E9=81=BF=E5=85=8DObjectId&=E6=97=A0?= =?UTF-8?q?=E5=BA=94=E7=94=A8=E5=AF=B9=E8=AF=9DFunctionCall=E9=99=8D?= =?UTF-8?q?=E7=BA=A7=E9=A1=BA=E5=BA=8F=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 6 +- apps/scheduler/scheduler/context.py | 15 ++-- apps/services/llm.py | 104 +++++----------------------- apps/services/rag.py | 7 +- 4 files changed, 36 insertions(+), 96 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 80354acf..e52905ad 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -266,7 +266,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) return # 创建新Record,存入数据库 - await save_data(task, user_sub, post_body) + record_id = await save_data(task, user_sub, post_body) if post_body.app and post_body.app.flow_id: await FlowManager.update_flow_debug_by_app_and_flow_id( @@ -275,6 +275,10 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) debug=True, ) + # 发送 record_id 给前端 + if record_id: + yield f"data: [RECORD_ID]{record_id}\n\n" + yield "data: [DONE]\n\n" except Exception: diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 38207736..e444fb91 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -106,8 +106,12 @@ async def generate_facts(task: Task, question: str) -> tuple[Task, list[str]]: return task, facts -async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: - """保存当前Executor、Task、Record等的数据""" +async def save_data(task: Task, user_sub: str, post_body: RequestData) -> str: + """保存当前Executor、Task、Record等的数据 + + Returns: + str: record_id + """ # 构造RecordContent used_docs = [] order_to_id = {} @@ -164,7 +168,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) except Exception: logger.exception("[Scheduler] 问答对加密错误") - return + return "" # 保存Flow信息 if task.state: @@ -204,7 +208,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: ) if not record_group_id: logger.error("[Scheduler] 创建问答组失败") - return + return "" else: record_group_id = task.ids.group_id @@ -224,3 +228,6 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: await TaskManager.delete_task_by_task_id(task.id) else: await TaskManager.save_task(task.id, task) + + # 返回 record_id + return task.ids.record_id diff --git a/apps/services/llm.py b/apps/services/llm.py index 70608e06..1d150491 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -65,32 +65,15 @@ class LLMManager: :param llm_id: 大模型ID :return: 大模型对象 """ - from bson import ObjectId - llm_collection = MongoDB().get_collection("llm") - # 尝试同时使用字符串和ObjectId查询,以兼容不同的存储格式 - result = None - try: - # 首先尝试作为字符串查询 - result = await llm_collection.find_one({"_id": llm_id}) - - # 如果字符串查询失败,尝试转换为ObjectId查询 - if not result and ObjectId.is_valid(llm_id): - result = await llm_collection.find_one({"_id": ObjectId(llm_id)}) - - except Exception as e: - logger.warning(f"[LLMManager] 查询LLM时发生错误: {e}") + result = await llm_collection.find_one({"_id": llm_id}) if not result: err = f"[LLMManager] LLM {llm_id} 不存在" logger.error(err) raise ValueError(err) - # 将ObjectId转换为字符串,以兼容LLM模型的验证 - if isinstance(result.get("_id"), ObjectId): - result["_id"] = str(result["_id"]) - return LLM.model_validate(result) @staticmethod @@ -102,32 +85,15 @@ class LLMManager: :param llm_id: 大模型ID :return: 大模型对象 """ - from bson import ObjectId - llm_collection = MongoDB().get_collection("llm") - # 尝试同时使用字符串和ObjectId查询,以兼容不同的存储格式 - result = None - try: - # 首先尝试作为字符串查询 - result = await llm_collection.find_one({"_id": llm_id, "user_sub": user_sub}) - - # 如果字符串查询失败,尝试转换为ObjectId查询 - if not result and ObjectId.is_valid(llm_id): - result = await llm_collection.find_one({"_id": ObjectId(llm_id), "user_sub": user_sub}) - - except Exception as e: - logger.warning(f"[LLMManager] 查询LLM时发生错误: {e}") + result = await llm_collection.find_one({"_id": llm_id, "user_sub": user_sub}) if not result: err = f"[LLMManager] LLM {llm_id} 不存在" logger.error(err) raise ValueError(err) - # 将ObjectId转换为字符串,以兼容LLM模型的验证 - if isinstance(result.get("_id"), ObjectId): - result["_id"] = str(result["_id"]) - return LLM.model_validate(result) @staticmethod @@ -178,7 +144,7 @@ class LLMManager: llm_type = [llm_type] llm_item = LLMProviderInfo( - llmId=str(llm["_id"]), # 转换ObjectId为字符串 + llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], openaiBaseUrl=llm["openai_base_url"], openaiApiKey=llm["openai_api_key"], @@ -436,7 +402,7 @@ class LLMManager: llm_type = [llm_type] llm_item = LLMProviderInfo( - llmId=str(llm["_id"]), # 转换ObjectId为字符串 + llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], openaiBaseUrl=llm["openai_base_url"], openaiApiKey=llm["openai_api_key"], @@ -475,7 +441,7 @@ class LLMManager: llm_type = [llm_type] llm_item = LLMProviderInfo( - llmId=str(llm["_id"]), # 转换ObjectId为字符串 + llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], openaiBaseUrl=llm["openai_base_url"], openaiApiKey=llm["openai_api_key"], @@ -510,7 +476,7 @@ class LLMManager: llm_type = [llm_type] llm_item = LLMProviderInfo( - llmId=str(llm["_id"]), # 转换ObjectId为字符串 + llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], openaiBaseUrl=llm["openai_base_url"], openaiApiKey=llm["openai_api_key"], @@ -549,7 +515,7 @@ class LLMManager: llm_type = [llm_type] llm_item = LLMProviderInfo( - llmId=str(llm["_id"]), # 转换ObjectId为字符串 + llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], openaiBaseUrl=llm["openai_base_url"], openaiApiKey=llm["openai_api_key"], @@ -596,8 +562,8 @@ class LLMManager: notes=model_info.notes if model_info else "", ) - # 排除_id字段让MongoDB自动生成_id,避免冲突 - insert_data = chat_llm.model_dump(exclude={"_id"}) + # 使用by_alias=True将id字段作为_id插入,保持UUID字符串格式 + insert_data = chat_llm.model_dump(by_alias=True) await llm_collection.insert_one(insert_data) logger.info(f"已初始化系统chat模型: {config.llm.model}") @@ -641,26 +607,10 @@ class LLMManager: notes=model_info.notes if model_info else "", ) - # 使用upsert模式:如果model_name已存在就更新,否则插入 - filter_query = { - "user_sub": "", - "model_name": model_config.model - } - - # 排除id和_id字段以避免MongoDB的不可变_id字段错误 - model_data = system_llm.model_dump(by_alias=True, exclude={"id", "_id"}) - - # 使用update_one替代replace_one,更安全 - result = await llm_collection.update_one( - filter_query, - {"$set": model_data}, - upsert=True - ) - - if result.upserted_id: - logger.info(f"[LLMManager] 创建系统{model_type}模型: {model_config.model}") - else: - logger.info(f"[LLMManager] 更新系统{model_type}模型: {model_config.model}") + # 使用by_alias=True将id字段作为_id插入,保持UUID字符串格式 + insert_data = system_llm.model_dump(by_alias=True) + await llm_collection.insert_one(insert_data) + logger.info(f"[LLMManager] 创建系统{model_type}模型: {model_config.model}") @staticmethod async def get_function_call_model_id(user_sub: str, app_llm_id: str | None = None) -> str | None: @@ -821,36 +771,16 @@ class LLMManager: mongo = MongoDB() llm_collection = mongo.get_collection("llm") - # 尝试将llm_id转换为ObjectId(如果适用) - from bson import ObjectId - from bson.errors import InvalidId - - # 先尝试作为字符串查询,如果失败再尝试ObjectId - try: - # 首先尝试作为字符串ID查询(UUID格式) - result = await llm_collection.find_one({ - "_id": llm_id, - "$or": [{"user_sub": user_sub}, {"user_sub": ""}] - }) - - # 如果没找到且llm_id可以转换为ObjectId,则尝试作为ObjectId查询 - if not result and ObjectId.is_valid(llm_id): - result = await llm_collection.find_one({ - "_id": ObjectId(llm_id), - "$or": [{"user_sub": user_sub}, {"user_sub": ""}] - }) - except InvalidId: - result = None + result = await llm_collection.find_one({ + "_id": llm_id, + "$or": [{"user_sub": user_sub}, {"user_sub": ""}] + }) if not result: err = f"[LLMManager] LLM {llm_id} 不存在或无权限访问" logger.error(err) raise ValueError(err) - # 将ObjectId转换为字符串,以兼容LLM模型的验证 - if isinstance(result.get("_id"), ObjectId): - result["_id"] = str(result["_id"]) - llm = LLM.model_validate(result) # 从注册表获取模型能力 diff --git a/apps/services/rag.py b/apps/services/rag.py index ac9ac0db..881ea491 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -374,13 +374,12 @@ Please generate a detailed, well-structured, and clearly formatted answer based # 用于问题改写的LLM if history: try: - # 获取function call场景使用的模型ID用于问题改写 - # 在RAG场景中,将对话模型ID作为应用配置模型传递(最高优先级) - # 降级顺序:对话模型 -> 用户偏好的function call模型 -> 系统默认function call模型 -> 系统默认chat模型 + # 🔑 对于无应用对话(RAG场景),使用用户preference或系统默认的函数调用模型 + # 降级顺序:用户偏好的function call模型 -> 系统默认function call模型 -> 系统默认chat模型 from apps.services.llm import LLMManager function_call_model_id = await LLMManager.get_function_call_model_id( user_sub, - app_llm_id=llm.id # 传递对话模型ID作为应用配置(最高优先级) + app_llm_id=None # 无应用对话不传递应用模型ID,使用用户preference或系统默认 ) # 如果没有找到函数调用模型,使用对话模型作为降级方案 -- Gitee