From 22e1e5ab73ec89814af417fae3d1cb31569125d9 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 12 Nov 2025 22:08:34 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=97=A0=E7=94=A8?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/step.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 3d7f45596..d5a0d133c 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -239,8 +239,6 @@ class StepExecutor(BaseExecutor): if self.step.step.params and isinstance(self.step.step.params, dict): output_parameters = self.step.step.params.get( "output_parameters", {}) - elif hasattr(self.step, 'output_parameters'): - output_parameters = self.step.output_parameters if not output_parameters or not isinstance(output_parameters, dict): logger.debug( -- Gitee From f937e550a84da677cc48ff500c35a41b561ba98d Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 13 Nov 2025 10:54:16 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E7=AE=80=E5=8C=96flow=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=E5=8F=98=E9=87=8F=E5=AD=98=E5=82=A8=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/step.py | 252 ++++--------------------- apps/scheduler/variable/integration.py | 11 +- apps/scheduler/variable/type.py | 33 ++++ 3 files changed, 72 insertions(+), 224 deletions(-) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index d5a0d133c..0e8ff877b 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator from datetime import UTC, datetime from typing import Any import json - +import jsonschema from pydantic import ConfigDict from apps.scheduler.call.core import CoreCall @@ -233,224 +233,48 @@ class StepExecutor(BaseExecutor): async def _save_output_parameters_to_variables(self, output_data: str | dict[str, Any]) -> None: """保存输出参数到变量池,并进行类型验证""" + output_data_schema = self.node.override_output try: - # 获取当前步骤的output_parameters配置 - output_parameters = None - if self.step.step.params and isinstance(self.step.step.params, dict): - output_parameters = self.step.step.params.get( - "output_parameters", {}) - - if not output_parameters or not isinstance(output_parameters, dict): - logger.debug( - f"[StepExecutor] 步骤 {self.step.step_id} 没有配置output_parameters") - return - - # 解析输出数据 - if isinstance(output_data, str): - try: - data_dict = json.loads(output_data) - except json.JSONDecodeError: - logger.warning( - f"[StepExecutor] 无法解析输出数据为JSON: {output_data}") - data_dict = {"raw_output": output_data} - else: - data_dict = output_data - - # 构造变量名前缀 - var_prefix = f"{self.step.step_id}." - - # 保存每个output_parameter到变量池,并进行类型验证 - saved_count = 0 - failed_params = [] - - # 特殊处理:如果是旧格式的JSON Schema结构(主要是Loop节点) - if (isinstance(output_parameters, dict) and - "type" in output_parameters and - "items" in output_parameters and - isinstance(output_parameters["items"], dict)): - - # 提取items中的真正参数配置 - output_parameters = output_parameters["items"] - - # 清理每个参数配置中的多余字段(如嵌套的items) - for param_name, param_config in output_parameters.items(): - if isinstance(param_config, dict) and "items" in param_config: - # 移除多余的items字段,保持参数配置的简洁性 - param_config.pop("items", None) - - logger.debug( - f"[StepExecutor] 转换后的output_parameters: {output_parameters}") - - for param_name, param_config in output_parameters.items(): - try: - # 检查param_config格式,确保它是字典 - if not isinstance(param_config, dict): - logger.warning( - f"[StepExecutor] 输出参数 {param_name} 的配置不是字典格式: {param_config} (类型: {type(param_config)})") - # 如果不是字典,尝试转换为标准格式 - if isinstance(param_config, str): - param_config = { - "type": param_config, "description": ""} - else: - param_config = { - "type": "string", "description": "", "raw_config": str(param_config)} - - # 获取参数值 - param_value = self._extract_value_from_output_data( - param_name, data_dict, param_config) - - if param_value is not None: - # 获取期望的类型 - raw_expected_type = param_config.get("type", "string") - - # 映射类型到变量系统支持的类型 - type_mapping = { - "integer": "number", # integer 映射到 number - "int": "number", # int 映射到 number - "float": "number", # float 映射到 number - "str": "string", # str 映射到 string - "bool": "boolean", # bool 映射到 boolean - "dict": "object", # dict 映射到 object - } - expected_type = type_mapping.get( - raw_expected_type, raw_expected_type) - if expected_type.lower() == "anyof": - # 处理 anyOf 类型,暂时取第一个类型作为期望类型 - expected_type_list = param_config.get( - "type_list", []) - for i, et in enumerate(expected_type_list): - if et in type_mapping: - expected_type_list[i] = type_mapping[et] - else: - expected_type_list = [expected_type] - value_validated = False - for et in expected_type_list: - if self._validate_output_value_type(param_value, et): - expected_type = et - value_validated = True - break - # 进行类型验证 - if not value_validated: - error_msg = (f"输出参数 '{param_name}' 类型不匹配。" - f"期望: {expected_type}, " - f"实际: {type(param_value).__name__}({param_value})") - logger.error(f"[StepExecutor] {error_msg}") - failed_params.append(f"{param_name}: {error_msg}") - continue - - # 构造变量名 - var_name = f"{var_prefix}{param_name}" - - # 保存到对话变量池 - success = await VariableIntegration.save_conversation_variable( - var_name=var_name, - value=param_value, - var_type=expected_type, - description=param_config.get("description", ""), - user_sub=self.task.ids.user_sub, - # type: ignore[arg-type] - flow_id=self.task.state.flow_id, - conversation_id=self.task.ids.conversation_id - ) - - if success: - saved_count += 1 - logger.debug( - f"[StepExecutor] 已保存输出参数变量: conversation.{var_name} = {param_value}") - else: - error_msg = f"保存输出参数变量失败: {var_name}" - logger.warning(f"[StepExecutor] {error_msg}") - failed_params.append(f"{param_name}: {error_msg}") - - except Exception as e: - error_msg = f"处理输出参数失败: {str(e)}" - logger.warning( - f"[StepExecutor] 保存输出参数 {param_name} 失败: {e}") - failed_params.append(f"{param_name}: {error_msg}") - - # 如果有失败的参数,将步骤状态设置为失败 - if failed_params: - - # type: ignore[assignment] - self.task.state.step_status = StepStatus.ERROR - - failure_msg = f"输出参数类型验证失败:\n" + "\n".join(failed_params) - logger.error( - f"[StepExecutor] 步骤 {self.step.step_id} 执行失败: {failure_msg}") - - # 保存错误信息到任务状态 - if not hasattr(self.task.state, 'error_info') or self.task.state.error_info is None: - self.task.state.error_info = {} - # type: ignore[assignment] - self.task.state.error_info['output_validation_errors'] = failed_params - - # 抛出异常以停止工作流执行 - raise ValueError(f"步骤输出参数类型验证失败: {failure_msg}") - - if saved_count > 0: - logger.info(f"[StepExecutor] 已保存 {saved_count} 个输出参数到变量池") - + jsonschema.validate(instance=output_data, + schema=output_data_schema) except Exception as e: - # 如果是我们主动抛出的验证错误,重新抛出 - if "类型验证失败" in str(e): - raise - logger.error(f"[StepExecutor] 保存输出参数到变量池失败: {e}") - # 对于其他意外错误,也将步骤设置为失败 + logger.error( + f"[StepExecutor] 输出数据不符合output_data_schema: {e}") # type: ignore[assignment] self.task.state.step_status = StepStatus.ERROR - raise - - def _extract_value_from_output_data(self, param_name: str, output_data: dict[str, Any], param_config: dict) -> Any: - """从输出数据中提取参数值""" - # 支持多种提取方式 - - # 1. 直接从输出数据中获取同名key - if param_name in output_data: - return output_data[param_name] - - # 2. 支持路径提取(例如:result.data.value) - if "path" in param_config: - path = param_config["path"] - current_data = output_data - for key in path.split("."): - if isinstance(current_data, dict) and key in current_data: - current_data = current_data[key] - else: - return None - return current_data - - # 3. 支持默认值 - if "default" in param_config: - return param_config["default"] - - # 4. 如果参数配置为"full_output",返回完整输出 - if param_config.get("source") == "full_output": - return output_data - - return None - - def _validate_output_value_type(self, value: Any, expected_type: str) -> bool: - """验证输出值的类型是否符合期望""" - try: - if expected_type == "string": - return isinstance(value, str) - elif expected_type == "number": - return isinstance(value, (int, float)) - elif expected_type == "boolean": - return isinstance(value, bool) - elif expected_type == "array": - return isinstance(value, list) - elif expected_type == "object": - return isinstance(value, dict) - elif expected_type == "any": - return True # 任何类型都匹配 + raise ValueError( + f"输出数据不符合output_data_schema: {e}") from e + + var_prefix = f"{self.step.step_id}" + if not isinstance(output_data_schema, dict): + await VariableIntegration.save_conversation_variable( + var_name=f"{var_prefix}", + value=output_data, + description=f"步骤 {self.step.step_id} 的完整输出", + user_sub=self.task.ids.user_sub, + # type: ignore[arg-type] + conversation_id=self.task.ids.conversation_id + ) + else: + if "items" in output_data_schema: + output_data_schema: dict = output_data_schema["items"] else: - # 对于未知类型,默认接受字符串 - logger.warning( - f"[StepExecutor] 未知的期望类型: {expected_type},默认验证为字符串") - return isinstance(value, str) - except Exception: - return False + raise ValueError( + f"[StepExecutor] 步骤 {self.step.step_id} 的 output_data_schema 格式错误,缺少 items 字段") + for param_name, param_config in output_data_schema.items(): + param_value = None + if param_name in output_data: + param_value = output_data[param_name] + if param_value is not None: + await VariableIntegration.save_conversation_variable( + var_name=f"{var_prefix}.{param_name}", + value=param_value, + description=param_config.get( + "description", f"步骤 {self.step.step_id} 的输出参数 {param_name}"), + user_sub=self.task.ids.user_sub, + # type: ignore[arg-type] + conversation_id=self.task.ids.conversation_id + ) async def run(self) -> None: """运行单个步骤""" diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py index bc49a082b..9eb0ae102 100644 --- a/apps/scheduler/variable/integration.py +++ b/apps/scheduler/variable/integration.py @@ -11,7 +11,6 @@ logger = logging.getLogger(__name__) class VariableIntegration: """变量解析集成类 - 为现有调度器提供变量功能""" - @staticmethod async def initialize_system_variables(context: Dict[str, Any]) -> None: """初始化系统变量 @@ -114,10 +113,8 @@ class VariableIntegration: async def save_conversation_variable( var_name: str, value: Any, - var_type: str = "string", description: str = "", user_sub: str = "", - flow_id: Optional[str] = None, conversation_id: Optional[str] = None ) -> bool: """保存对话变量 @@ -125,10 +122,8 @@ class VariableIntegration: Args: var_name: 变量名(不包含scope前缀) value: 变量值 - var_type: 变量类型 description: 变量描述 user_sub: 用户ID - flow_id: 流程ID conversation_id: 对话ID Returns: @@ -148,11 +143,7 @@ class VariableIntegration: return False # 转换变量类型 - try: - var_type_enum = VariableType(var_type) - except ValueError: - var_type_enum = VariableType.STRING - logger.warning(f"未知的变量类型 {var_type},使用默认类型 string") + var_type_enum = VariableType.judge_type_by_value(value) # 尝试更新变量,如果不存在则创建 try: diff --git a/apps/scheduler/variable/type.py b/apps/scheduler/variable/type.py index 2cadaba5e..de0a0e76e 100644 --- a/apps/scheduler/variable/type.py +++ b/apps/scheduler/variable/type.py @@ -1,4 +1,5 @@ from enum import StrEnum +from typing import Any class VariableType(StrEnum): @@ -47,6 +48,38 @@ class VariableType(StrEnum): } return element_type_map.get(self) + @staticmethod + def juge_list_type(value: list[Any]) -> "VariableType": + """判断列表的变量类型""" + if all(isinstance(item, str) for item in value): + return VariableType.ARRAY_STRING + elif all(isinstance(item, (int, float)) for item in value): + return VariableType.ARRAY_NUMBER + elif all(isinstance(item, dict) for item in value): + return VariableType.ARRAY_OBJECT + elif all(isinstance(item, bool) for item in value): + return VariableType.ARRAY_BOOLEAN + elif all(isinstance(item, bytes) for item in value): + return VariableType.ARRAY_FILE + else: + return VariableType.ARRAY_ANY + + @staticmethod + def judge_type_by_value(value: Any) -> "VariableType": + """根据值获取变量类型""" + if isinstance(value, bool): + return VariableType.BOOLEAN + elif isinstance(value, int) or isinstance(value, float): + return VariableType.NUMBER + elif isinstance(value, str): + return VariableType.STRING + elif isinstance(value, list): + return VariableType.juge_list_type(value) + elif isinstance(value, dict): + return VariableType.OBJECT + else: + return VariableType.STRING # 默认使用字符串类型 + class VariableScope(StrEnum): """变量作用域枚举""" -- Gitee From 58f5ce8c8670b3470b039d1d85946c3aa25a80ca Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 13 Nov 2025 11:33:11 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E5=AE=8C=E5=96=84input=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E6=8F=90=E5=8F=96=E7=9A=84=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/step.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 0e8ff877b..0848f97d3 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -122,22 +122,11 @@ class StepExecutor(BaseExecutor): self.node.known_params if self.node and self.node.known_params else {} ) if self.step.step.params: - params.update(self.step.step.params) - - # 对于需要扁平化处理的Call类型,将input_parameters中的内容提取到顶级 - # TODO Call中自带属性区分是否需要扁平化,避免逻辑判断频繁修改,或者修改Code逻辑为统一设计 - if self._call_id not in ["Code"] and "input_parameters" in params: - # 提取input_parameters中的所有字段到顶级 - input_params = params.get("input_parameters", {}) - if isinstance(input_params, dict): - # 将input_parameters中的字段提取到顶级 - for key, value in input_params.items(): - params[key] = value - # 移除input_parameters,避免重复 - params.pop("input_parameters", None) + input_params = self.step.step.params.get("input_parameters", {}) + params.update(input_params) # 对于LLM调用,注入enable_thinking参数 - if self._call_id == "LLM" and hasattr(self.background, 'enable_thinking'): + if self._call_id == "LLM": params['enable_thinking'] = self.background.enable_thinking try: -- Gitee From 4d5cd6067c62dc3fc6a41ec27dbe6c0eade2451c Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 13 Nov 2025 12:20:19 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E5=AE=8C=E5=96=84flow=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9B=B8=E5=85=B3=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/summary/summary.py | 20 +++++++--------- apps/scheduler/executor/step.py | 10 ++++---- apps/services/flow.py | 33 +++++++++++++++++++------- 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py index 39b5a81fa..4e25f610a 100644 --- a/apps/scheduler/call/summary/summary.py +++ b/apps/scheduler/call/summary/summary.py @@ -22,12 +22,12 @@ if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor - class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): """总结工具""" context: ExecutorBackground = Field(description="对话上下文") - llm_id: str | None = Field(default=None, description="大模型ID,如果为None则使用系统默认模型") + llm_id: str | None = Field( + default=None, description="大模型ID,如果为None则使用系统默认模型") enable_thinking: bool = Field(default=False, description="是否启用思维链") i18n_info: ClassVar[dict[str, dict]] = { LanguageType.CHINESE: { @@ -48,7 +48,7 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): # 提取 llm_id 和 enable_thinking,避免重复传递 llm_id = kwargs.pop("llm_id", None) enable_thinking = kwargs.pop("enable_thinking", False) - + obj = cls( context=executor.background, name=executor.step.step.name, @@ -61,19 +61,18 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): await obj._set_input(executor) return obj - async def _init(self, call_vars: CallVars) -> DataBase: """初始化工具,返回输入""" return DataBase() - async def _exec( self, _input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE ) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" import logging logger = logging.getLogger(__name__) - logger.info(f"[Summary] 使用模型ID: {self.llm_id}, 启用思维链: {self.enable_thinking}") + logger.info( + f"[Summary] 使用模型ID: {self.llm_id}, 启用思维链: {self.enable_thinking}") summary_obj = ExecutorSummary( llm_id=self.llm_id, enable_thinking=self.enable_thinking, @@ -82,8 +81,7 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): self.tokens.input_tokens += summary_obj.input_tokens self.tokens.output_tokens += summary_obj.output_tokens - yield CallOutputChunk(type=CallOutputType.TEXT, content=summary) - + yield CallOutputChunk(type=CallOutputType.DATA, content={"summary": summary}) async def exec( self, @@ -92,10 +90,8 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): language: LanguageType = LanguageType.CHINESE, ) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" + content = "" async for chunk in self._exec(input_data, language): - content = chunk.content - if not isinstance(content, str): - err = "[SummaryCall] 工具输出格式错误" - raise TypeError(err) + content = chunk.content.get("summary", "") executor.task.runtime.summary = content yield chunk diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 0848f97d3..ce6ba7ccb 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -123,7 +123,8 @@ class StepExecutor(BaseExecutor): ) if self.step.step.params: input_params = self.step.step.params.get("input_parameters", {}) - params.update(input_params) + if isinstance(input_params, dict): + params.update(input_params) # 对于LLM调用,注入enable_thinking参数 if self._call_id == "LLM": @@ -198,6 +199,7 @@ class StepExecutor(BaseExecutor): content: str | dict[str, Any] = "" async for chunk in iterator: + logging.error(f"StepExecutor接收到chunk: {chunk}") if not isinstance(chunk, CallOutputChunk): err = "[StepExecutor] 返回结果类型错误" logger.error(err) @@ -222,7 +224,7 @@ class StepExecutor(BaseExecutor): async def _save_output_parameters_to_variables(self, output_data: str | dict[str, Any]) -> None: """保存输出参数到变量池,并进行类型验证""" - output_data_schema = self.node.override_output + output_data_schema = self.obj.output_model.model_json_schema() try: jsonschema.validate(instance=output_data, schema=output_data_schema) @@ -245,8 +247,8 @@ class StepExecutor(BaseExecutor): conversation_id=self.task.ids.conversation_id ) else: - if "items" in output_data_schema: - output_data_schema: dict = output_data_schema["items"] + if "properties" in output_data_schema: + output_data_schema: dict = output_data_schema["properties"] else: raise ValueError( f"[StepExecutor] 步骤 {self.step.step_id} 的 output_data_schema 格式错误,缺少 items 字段") diff --git a/apps/services/flow.py b/apps/services/flow.py index 22df1b1d1..997e2fb98 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -459,6 +459,8 @@ class FlowManager: :param flow_item: 流的item :return: 流的id """ + import time + st = time.time() try: app_collection = MongoDB().get_collection("app") app_record = await app_collection.find_one({"_id": app_id}) @@ -468,6 +470,8 @@ class FlowManager: except Exception: logger.exception("[FlowManager] 获取流失败") return None + en = time.time() + logger.info(f"[FlowManager] 获取应用时间: {en-st}s") try: flow_config = Flow( name=flow_item.name, @@ -479,14 +483,15 @@ class FlowManager: connectivity=flow_item.connectivity, debug=flow_item.debug, ) - + st = time.time() # 获取旧的flow配置以便比较节点变化 flow_loader = FlowLoader() old_flow_config = await flow_loader.load(app_id, flow_id) - + en = time.time() + logger.info(f"[FlowManager] 加载旧流配置时间: {en-st}s") # 收集新配置中的所有步骤ID new_step_ids = set() - + st = time.time() for node_item in flow_item.nodes: params = node_item.parameters new_step_ids.add(node_item.step_id) @@ -507,7 +512,9 @@ class FlowManager: service_id=node_item.service_id, plugin_type=node_item.plugin_type, ) - + en = time.time() + logger.info(f"[FlowManager] 处理节点时间: {en-st}s") + st = time.time() # 检查是否有节点被删除,如果有则清理相关变量 if old_flow_config and user_sub: old_step_ids = set(old_flow_config.steps.keys()) @@ -519,7 +526,9 @@ class FlowManager: await FlowManager._cleanup_deleted_node_variables( deleted_step_ids, flow_id, user_sub ) - + en = time.time() + logger.info(f"[FlowManager] 清理删除节点变量时间: {en-st}s") + st = time.time() for edge_item in flow_item.edges: try: edge_from = edge_item.source_node @@ -549,7 +558,9 @@ class FlowManager: logger.error( f"[FlowManager] 创建边失败: {edge_item.edge_id}, 错误: {e}") continue - + en = time.time() + logger.info(f"[FlowManager] 处理边时间: {en-st}s") + st = time.time() # 处理notes for note_item in flow_item.notes: try: @@ -566,10 +577,11 @@ class FlowManager: logger.error( f"[FlowManager] 创建备注失败: {note_item.note_id}, 错误: {e}") continue - + en = time.time() + logger.info(f"[FlowManager] 处理备注时间: {en-st}s") logger.info( f"[FlowManager] 构建完成,flow_config.edges数量: {len(flow_config.edges)}, flow_config.notes数量: {len(flow_config.notes)}") - + st = time.time() if old_flow_config is None: error_msg = f"[FlowManager] 流 {flow_id} 不存在;可能为新创建" logger.error(error_msg) @@ -577,7 +589,12 @@ class FlowManager: flow_config.debug = await FlowManager.is_flow_config_equal(old_flow_config, flow_config) else: flow_config.debug = False + en = time.time() + logger.info(f"[FlowManager] 比较流配置时间: {en-st}s") + st = time.time() await flow_loader.save(app_id, flow_id, flow_config) + en = time.time() + logger.info(f"[FlowManager] 保存流配置时间: {en-st}s") except Exception: logger.exception("[FlowManager] 存储/更新流失败") return None -- Gitee From 66eee246bb11d64402c61e13663d7d33e0c2f067 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 13 Nov 2025 15:45:32 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E5=AE=8C=E5=96=84lancedb=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E6=97=B6=E9=97=B4/=E5=AE=8C=E5=96=84=E9=A3=8E=E9=99=A9?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E7=9A=84prompt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/lance.py | 40 ++++++----- apps/main.py | 2 +- apps/scheduler/mcp/select.py | 4 +- apps/scheduler/mcp_agent/prompt.py | 49 ++++++++++++-- apps/scheduler/pool/loader/call.py | 18 ++--- apps/scheduler/pool/loader/flow.py | 96 ++++++++++++++++----------- apps/scheduler/pool/loader/mcp.py | 49 ++++++++------ apps/scheduler/pool/loader/service.py | 23 +++---- 8 files changed, 172 insertions(+), 109 deletions(-) diff --git a/apps/common/lance.py b/apps/common/lance.py index 2ead17625..7e08692d3 100644 --- a/apps/common/lance.py +++ b/apps/common/lance.py @@ -2,6 +2,7 @@ """向LanceDB中存储向量化数据""" import lancedb +from lancedb.db import AsyncConnection from lancedb.index import HnswSq from apps.common.config import Config @@ -17,8 +18,10 @@ from apps.schemas.mcp import MCPToolVector, MCPVector class LanceDB(metaclass=SingletonMeta): """LanceDB向量化存储""" + _engine: AsyncConnection | None = None - async def init(self) -> None: + @staticmethod + async def init() -> None: """ 初始化LanceDB @@ -26,44 +29,50 @@ class LanceDB(metaclass=SingletonMeta): :return: 无 """ - self._engine = await lancedb.connect_async( + LanceDB._engine = await lancedb.connect_async( Config().get_config().deploy.data_dir.rstrip("/") + "/vectors", ) # 创建表 - await self._engine.create_table( + await LanceDB._engine.create_table( "flow", schema=FlowPoolVector, exist_ok=True, ) - await self._engine.create_table( + await LanceDB.create_index("flow") + await LanceDB._engine.create_table( "service", schema=ServicePoolVector, exist_ok=True, ) - await self._engine.create_table( + await LanceDB.create_index("service") + await LanceDB._engine.create_table( "call", schema=CallPoolVector, exist_ok=True, ) - await self._engine.create_table( + await LanceDB.create_index("call") + await LanceDB._engine.create_table( "node", schema=NodePoolVector, exist_ok=True, ) - await self._engine.create_table( + await LanceDB.create_index("node") + await LanceDB._engine.create_table( "mcp", schema=MCPVector, exist_ok=True, ) - await self._engine.create_table( + await LanceDB.create_index("mcp") + await LanceDB._engine.create_table( "mcp_tool", schema=MCPToolVector, exist_ok=True, ) + await LanceDB.create_index("mcp_tool") - - async def get_table(self, table_name: str) -> lancedb.AsyncTable: + @staticmethod + async def get_table(table_name: str) -> lancedb.AsyncTable: """ 获取LanceDB中的表 @@ -71,20 +80,17 @@ class LanceDB(metaclass=SingletonMeta): :return: 表 :rtype: lancedb.AsyncTable """ - self._engine = await lancedb.connect_async( - Config().get_config().deploy.data_dir.rstrip("/") + "/vectors", - ) - return await self._engine.open_table(table_name) - + return await LanceDB._engine.open_table(table_name) - async def create_index(self, table_name: str) -> None: + @staticmethod + async def create_index(table_name: str) -> None: """ 创建LanceDB中表的索引;使用HNSW算法 :param str table_name: 表名 :return: 无 """ - table = await self.get_table(table_name) + table = await LanceDB.get_table(table_name) await table.create_index( "embedding", config=HnswSq(), diff --git a/apps/main.py b/apps/main.py index 3c3c0a14f..950789f98 100644 --- a/apps/main.py +++ b/apps/main.py @@ -265,7 +265,7 @@ async def init_resources() -> None: """初始化必要资源""" WordsCheck() - await LanceDB().init() + await LanceDB.init() await Pool.init() TokenCalculator() diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 11ed82ca9..5b846efe4 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -47,7 +47,7 @@ class MCPSelector: ) -> list[dict[str, str]]: """通过向量检索获取Top5 MCP Server""" logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list) - mcp_table = await LanceDB().get_table("mcp") + mcp_table = await LanceDB.get_table("mcp") query_embedding = await Embedding.get_embedding([query]) mcp_vecs = ( await ( @@ -156,7 +156,7 @@ class MCPSelector: query: str, mcp_list: list[str], top_n: int = 10 ) -> list[MCPTool]: """选择最合适的工具""" - tool_vector = await LanceDB().get_table("mcp_tool") + tool_vector = await LanceDB.get_table("mcp_tool") query_embedding = await Embedding.get_embedding([query]) tool_vecs = await (await tool_vector.search( query=query_embedding, diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 852046079..8777ec933 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -1398,7 +1398,7 @@ RISK_EVALUATE: dict[LanguageType, str] = { "reason": "提示信息" } ``` - # 样例 + # 样例一 # 工具名称 mysql_analyzer # 工具描述 @@ -1422,10 +1422,28 @@ RISK_EVALUATE: dict[LanguageType, str] = { ```json { "risk": "medium", - "reason": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" + "reason": "当前使用的工具名称为mysql_analyzer,工具的入参为 host: 192.0.0.1, port: 3306, username: root, password: password。该工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" } ``` - # 工具 + # 样例二 + # 工具名称 + shell_command_executor + # 工具描述 + 执行shell命令 + # 工具入参 + { + "command": "rm -rf /" + } + # 附加信息 + 无 + # 输出 + ```json + { + "risk": "high", + "reason": "当前使用的工具名称为shell_command_executor,工具的入参为 command: rm -rf /. 该命令将删除系统根目录下的所有文件,这将导致系统无法正常运行,存在极高的风险。请立即停止执行该操作。" + } + ``` + # 工具名称 {{tool_name}} {{tool_description}} @@ -1447,7 +1465,7 @@ RISK_EVALUATE: dict[LanguageType, str] = { "reason": "prompt message" } ``` - # Example + # Example One # Tool name mysql_analyzer # Tool description @@ -1471,10 +1489,29 @@ RISK_EVALUATE: dict[LanguageType, str] = { ```json { "risk": "medium", - "reason": "This tool will connect to a MySQL database and analyze performance, which may impact database performance. This operation should only be performed in a non-production environment." + "reason": "The current tool being used is mysql_analyzer, with input parameters host: 192.0.0.1, port: 3306, username: root, password: password. This tool will connect to a MySQL database and analyze performance, which may impact database performance. This operation should only be performed in a non-production environment." } ``` - # Tool + # Example Two + # Tool name + shell_command_executor + # Tool description + Executes shell commands + # Tool input + { + "command": "rm -rf /" + } + # Additional information + None + # Output + ```json + { + "risk": "high", + "reason": "The current tool being used is shell_command_executor, with input parameter command: rm -rf /. This command will delete all files in the system root directory, which will cause the system to be unable to function properly and poses an extremely high risk. Please stop this operation immediately." + } + ``` + + # Tool name {{tool_name}} {{tool_description}} diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 157342875..fd9762527 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -56,7 +56,8 @@ class CallLoader(metaclass=SingletonMeta): call_dir = BASE_PATH / call_dir_name if not (call_dir / "__init__.py").exists(): - logger.info("[CallLoader] 模块 %s 不存在__init__.py文件,尝试自动创建。", call_dir) + logger.info( + "[CallLoader] 模块 %s 不存在__init__.py文件,尝试自动创建。", call_dir) try: (Path(call_dir) / "__init__.py").touch() except Exception as e: @@ -110,7 +111,8 @@ class CallLoader(metaclass=SingletonMeta): try: sys.path.insert(0, str(call_dir.parent)) if not (call_dir / "__init__.py").exists(): - logger.info("[CallLoader] 父模块 %s 不存在__init__.py文件,尝试自动创建。", call_dir) + logger.info( + "[CallLoader] 父模块 %s 不存在__init__.py文件,尝试自动创建。", call_dir) (Path(call_dir) / "__init__.py").touch() importlib.import_module("call") except Exception as e: @@ -160,7 +162,7 @@ class CallLoader(metaclass=SingletonMeta): # 从LanceDB中删除 while True: try: - table = await LanceDB().get_table("call") + table = await LanceDB.get_table("call") await table.delete(f"id = '{call_name}'") break except RuntimeError as e: @@ -170,8 +172,8 @@ class CallLoader(metaclass=SingletonMeta): else: raise - # 更新数据库 + async def _add_to_db(self, call_metadata: list[CallPool]) -> None: # noqa: C901 """更新数据库""" # 更新MongoDB @@ -202,7 +204,7 @@ class CallLoader(metaclass=SingletonMeta): while True: try: - table = await LanceDB().get_table("call") + table = await LanceDB.get_table("call") # 删除重复的ID for call in call_metadata: await table.delete(f"id = '{call.id}'") @@ -226,10 +228,8 @@ class CallLoader(metaclass=SingletonMeta): ) while True: try: - table = await LanceDB().get_table("call") - await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - vector_data, - ) + table = await LanceDB.get_table("call") + await table.add(vector_data) break except RuntimeError as e: if "Commit conflict" in str(e): diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 975c7cea4..724375b84 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -11,7 +11,7 @@ import yaml from anyio import Path from apps.common.config import Config -from apps.schemas.enum_var import NodeType,EdgeType +from apps.schemas.enum_var import NodeType, EdgeType from apps.schemas.flow import AppFlow, Flow from apps.schemas.pool import AppPool from apps.models.vector import FlowPoolVector @@ -25,6 +25,7 @@ from apps.schemas.subflow import AppSubFlow logger = logging.getLogger(__name__) BASE_PATH = Path(Config().get_config().deploy.data_dir) / "semantics" / "app" + class FlowLoader: """工作流加载器""" @@ -82,12 +83,13 @@ class FlowLoader: err = f"[FlowLoader] 步骤名称不能以下划线开头:{key}" logger.error(err) raise ValueError(err) - if step["type"]==NodeType.START.value or step["type"]==NodeType.END.value: + if step["type"] == NodeType.START.value or step["type"] == NodeType.END.value: continue try: step["type"] = await NodeManager.get_node_call_id(step["node"]) except ValueError as e: - logger.warning("[FlowLoader] 获取节点call_id失败:%s,错误信息:%s", step["node"], e) + logger.warning( + "[FlowLoader] 获取节点call_id失败:%s,错误信息:%s", step["node"], e) step["type"] = "Empty" step["name"] = ( (await NodeManager.get_node_name(step["node"])) @@ -99,13 +101,13 @@ class FlowLoader: async def load(self, app_id: str, flow_id: str) -> Flow | None: """从文件系统中加载【单个】工作流""" flow_key = f"{app_id}:{flow_id}" - + # 第一次检查:是否已在加载中 existing_task = None async with self._loading_lock: if flow_key in self._loading_flows: existing_task = self._loading_flows[flow_key] - + # 如果找到现有任务,等待其完成 if existing_task is not None: logger.info(f"[FlowLoader] 工作流正在加载中,等待完成: {flow_key}") @@ -118,7 +120,7 @@ class FlowLoader: if self._loading_flows.get(flow_key) == existing_task: self._loading_flows.pop(flow_key, None) return None - + # 创建新的加载任务 task = None async with self._loading_lock: @@ -130,13 +132,14 @@ class FlowLoader: try: return await existing_task except Exception as e: - logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}") + logger.error( + f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}") return None - + # 创建新的加载任务 task = asyncio.create_task(self._do_load(app_id, flow_id)) self._loading_flows[flow_key] = task - + # 执行加载任务 try: result = await task @@ -176,6 +179,8 @@ class FlowLoader: if not flow_yaml: return None flow_config = Flow.model_validate(flow_yaml) + import time + st = time.time() await self._update_db( app_id, AppFlow( @@ -187,9 +192,12 @@ class FlowLoader: debug=flow_config.debug, ), ) + en = time.time() + logger.info(f"[FlowLoader] 更新数据库耗时: {en-st} 秒") return Flow.model_validate(flow_yaml) except Exception: - logger.exception("[FlowLoader] 应用 %s:工作流 %s 格式不合法", app_id, flow_id) + logger.exception( + "[FlowLoader] 应用 %s:工作流 %s 格式不合法", app_id, flow_id) return None async def save(self, app_id: str, flow_id: str, flow: Flow) -> None: @@ -233,7 +241,7 @@ class FlowLoader: logger.exception("[FlowLoader] 删除工作流文件失败:%s", flow_path) return False - table = await LanceDB().get_table("flow") + table = await LanceDB.get_table("flow") try: await table.delete(f"id = '{flow_id}'") except Exception: @@ -281,7 +289,8 @@ class FlowLoader: await app_collection.aggregate( [ {"$match": {"_id": app_id}}, - {"$replaceWith": {"$setField": {"field": key, "input": "$$ROOT", "value": new_hash}}}, + {"$replaceWith": {"$setField": {"field": key, + "input": "$$ROOT", "value": new_hash}}}, ], ) except Exception: @@ -290,9 +299,11 @@ class FlowLoader: # 删除重复的ID,增加重试次数限制 max_retries = 10 retry_count = 0 + import time + st = time.time() while retry_count < max_retries: try: - table = await LanceDB().get_table("flow") + table = await LanceDB.get_table("flow") await table.delete(f"id = '{metadata.id}'") break except RuntimeError as e: @@ -306,9 +317,11 @@ class FlowLoader: except Exception as e: logger.error(f"[FlowLoader] LanceDB删除操作异常: {e}") break - + en = time.time() + logger.error(f"[FlowLoader] LanceDB删除flow耗时: {en-st} 秒") if retry_count >= max_retries: - logger.warning(f"[FlowLoader] LanceDB删除flow达到最大重试次数,跳过删除: {metadata.id}") + logger.warning( + f"[FlowLoader] LanceDB删除flow达到最大重试次数,跳过删除: {metadata.id}") # 不抛出异常,继续执行后续操作 # 进行向量化 service_embedding = await Embedding.get_embedding([metadata.description]) @@ -319,15 +332,14 @@ class FlowLoader: embedding=service_embedding[0], ), ] + st = time.time() # 插入向量数据,增加重试次数限制 max_retries_insert = 10 retry_count_insert = 0 while retry_count_insert < max_retries_insert: try: - table = await LanceDB().get_table("flow") - await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - vector_data, - ) + table = await LanceDB.get_table("flow") + await table.add(vector_data) break except RuntimeError as e: if "Commit conflict" in str(e): @@ -340,15 +352,18 @@ class FlowLoader: except Exception as e: logger.error(f"[FlowLoader] LanceDB插入操作异常: {e}") break - + en = time.time() + logger.error(f"[FlowLoader] LanceDB插入flow耗时: {en-st} 秒") if retry_count_insert >= max_retries_insert: - logger.error(f"[FlowLoader] LanceDB插入flow达到最大重试次数,操作失败: {metadata.id}") + logger.error( + f"[FlowLoader] LanceDB插入flow达到最大重试次数,操作失败: {metadata.id}") raise RuntimeError(f"LanceDB插入flow失败,达到最大重试次数: {metadata.id}") async def save_subflow(self, app_id: str, flow_id: str, sub_flow_id: str, flow: Flow) -> None: """保存子工作流到层次化路径""" # 子工作流路径: {app_id}/flow/{flow_id}/subflow/{sub_flow_id}.yaml - subflow_path = BASE_PATH / app_id / "flow" / flow_id / "subflow" / f"{sub_flow_id}.yaml" + subflow_path = BASE_PATH / app_id / "flow" / \ + flow_id / "subflow" / f"{sub_flow_id}.yaml" if not await subflow_path.parent.exists(): await subflow_path.parent.mkdir(parents=True, exist_ok=True) @@ -363,7 +378,7 @@ class FlowLoader: sort_keys=False, ), ) - + # 更新数据库中的子工作流元数据 await self._update_subflow_db( app_id, @@ -379,12 +394,13 @@ class FlowLoader: async def load_subflow(self, app_id: str, flow_id: str, sub_flow_id: str) -> Flow | None: """加载子工作流""" - subflow_path = BASE_PATH / app_id / "flow" / flow_id / "subflow" / f"{sub_flow_id}.yaml" - + subflow_path = BASE_PATH / app_id / "flow" / \ + flow_id / "subflow" / f"{sub_flow_id}.yaml" + if not await subflow_path.exists(): logger.warning("[FlowLoader] 子工作流文件不存在: %s", subflow_path) return None - + try: async with aiofiles.open(subflow_path, mode="r", encoding="utf-8") as f: content = await f.read() @@ -396,17 +412,18 @@ class FlowLoader: async def delete_subflow(self, app_id: str, flow_id: str, sub_flow_id: str) -> bool: """删除子工作流文件""" - subflow_path = BASE_PATH / app_id / "flow" / flow_id / "subflow" / f"{sub_flow_id}.yaml" - + subflow_path = BASE_PATH / app_id / "flow" / \ + flow_id / "subflow" / f"{sub_flow_id}.yaml" + if await subflow_path.exists(): try: await subflow_path.unlink() logger.info("[FlowLoader] 成功删除子工作流文件:%s", subflow_path) - + # 从数据库中删除子工作流元数据 await self._delete_subflow_db(app_id, flow_id, sub_flow_id) return True - + except Exception: logger.exception("[FlowLoader] 删除子工作流文件失败:%s", subflow_path) return False @@ -418,19 +435,19 @@ class FlowLoader: """更新数据库中的子工作流元数据""" try: app_collection = MongoDB().get_collection("app") - + # 查找应用 app_record = await app_collection.find_one({"_id": app_id}) if not app_record: logger.error("[FlowLoader] 应用不存在: %s", app_id) return - + # 确保子工作流元数据结构存在 if "subflows" not in app_record: app_record["subflows"] = {} if flow_id not in app_record["subflows"]: app_record["subflows"][flow_id] = [] - + # 更新或添加子工作流元数据 subflows = app_record["subflows"][flow_id] existing_index = None @@ -438,19 +455,20 @@ class FlowLoader: if subflow.get("id") == metadata.id: existing_index = i break - - subflow_data = metadata.model_dump(by_alias=True, exclude_none=True) + + subflow_data = metadata.model_dump( + by_alias=True, exclude_none=True) if existing_index is not None: subflows[existing_index] = subflow_data else: subflows.append(subflow_data) - + # 保存到数据库 await app_collection.update_one( {"_id": app_id}, {"$set": {f"subflows.{flow_id}": subflows}} ) - + except Exception: logger.exception("[FlowLoader] 更新子工作流数据库元数据失败") @@ -458,12 +476,12 @@ class FlowLoader: """从数据库中删除子工作流元数据""" try: app_collection = MongoDB().get_collection("app") - + # 从应用的子工作流列表中移除 await app_collection.update_one( {"_id": app_id}, {"$pull": {f"subflows.{flow_id}": {"id": sub_flow_id}}} ) - + except Exception: logger.exception("[FlowLoader] 删除子工作流数据库元数据失败") diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 714a1cf2d..fb6a24cdd 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -112,7 +112,8 @@ class MCPLoader(metaclass=SingletonMeta): # 重新保存config template_config = MCP_PATH / "template" / mcp_id / "config.json" f = await template_config.open("w+", encoding="utf-8") - config_data = config.model_dump(by_alias=True, exclude_none=True) + config_data = config.model_dump( + by_alias=True, exclude_none=True) await f.write(json.dumps(config_data, indent=4, ensure_ascii=False)) await f.aclose() @@ -137,13 +138,15 @@ class MCPLoader(metaclass=SingletonMeta): mcp_ids = ProcessHandler.get_all_task_ids() # 检索_id在mcp_ids且状态为ready或者failed的MCP的内容 db_service_list = await mcp_collection.find( - {"_id": {"$in": mcp_ids}, "status": {"$in": [MCPInstallStatus.READY, MCPInstallStatus.FAILED]}}, + {"_id": {"$in": mcp_ids}, "status": { + "$in": [MCPInstallStatus.READY, MCPInstallStatus.FAILED]}}, ).to_list(None) for db_service in db_service_list: try: item = MCPCollection.model_validate(db_service) except Exception as e: - logger.error("[MCPLoader] MCP模板数据验证失败: %s, 错误: %s", db_service["_id"], e) + logger.error("[MCPLoader] MCP模板数据验证失败: %s, 错误: %s", + db_service["_id"], e) continue ProcessHandler.remove_task(item.id) logger.info("[MCPLoader] 删除已完成或失败的MCP安装进程: %s", item.id) @@ -211,7 +214,8 @@ class MCPLoader(metaclass=SingletonMeta): """ # 创建客户端 if ( - (config.type == MCPType.STDIO and isinstance(config.config, MCPServerStdioConfig)) + (config.type == MCPType.STDIO and isinstance( + config.config, MCPServerStdioConfig)) or (config.type == MCPType.SSE and isinstance(config.config, MCPServerSSEConfig)) ): client = MCPClient() @@ -240,7 +244,8 @@ class MCPLoader(metaclass=SingletonMeta): description=item.description or "", input_schema=item.inputSchema, )] - logger.info("[MCPLoader] MCP %s 成功获取 %d 个工具", mcp_id, len(tool_list)) + logger.info("[MCPLoader] MCP %s 成功获取 %d 个工具", + mcp_id, len(tool_list)) except Exception as e: logger.error("[MCPLoader] MCP %s 获取工具列表失败: %s", mcp_id, e) raise ValueError(f"MCP {mcp_id} 获取工具列表失败: {e}") @@ -250,10 +255,11 @@ class MCPLoader(metaclass=SingletonMeta): try: await client.stop() except Exception as e: - logger.warning("[MCPLoader] MCP %s 停止客户端时发生异常: %s", mcp_id, e) + logger.warning( + "[MCPLoader] MCP %s 停止客户端时发生异常: %s", mcp_id, e) else: logger.warning("[MCPLoader] MCP %s 客户端没有stop方法", mcp_id) - + return tool_list @staticmethod @@ -310,13 +316,13 @@ class MCPLoader(metaclass=SingletonMeta): while True: try: - mcp_table = await LanceDB().get_table("mcp") - await mcp_table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute([ - MCPVector( + mcp_table = await LanceDB.get_table("mcp") + await mcp_table.add( + [MCPVector( id=mcp_id, embedding=embedding[0], - ), - ]) + )] + ) break except Exception as e: if "Commit conflict" in str(e): @@ -331,16 +337,14 @@ class MCPLoader(metaclass=SingletonMeta): for tool, embedding in zip(tool_list, tool_embedding, strict=True): while True: try: - mcp_tool_table = await LanceDB().get_table("mcp_tool") - await mcp_tool_table.merge_insert( - "id", - ).when_matched_update_all().when_not_matched_insert_all().execute([ - MCPToolVector( + mcp_tool_table = await LanceDB.get_table("mcp_tool") + await mcp_tool_table.add( + [MCPToolVector( id=tool.id, mcp_id=mcp_id, embedding=embedding, - ), - ]) + )] + ) break except Exception as e: if "Commit conflict" in str(e): @@ -348,7 +352,7 @@ class MCPLoader(metaclass=SingletonMeta): await asyncio.sleep(0.01) else: raise - await LanceDB().create_index("mcp_tool") + await LanceDB.create_index("mcp_tool") @staticmethod async def save_one(mcp_id: str, config: MCPServerConfig) -> None: @@ -468,7 +472,8 @@ class MCPLoader(metaclass=SingletonMeta): else: mcp_config.config.args.append(str(user_path)+'/project') else: - mcp_config.config.args = ["--directory", str(user_path)+'/project'] + mcp_config.config.args + mcp_config.config.args = [ + "--directory", str(user_path)+'/project'] + mcp_config.config.args user_config_path = user_path / "config.json" # 更新用户配置 f = await user_config_path.open("w", encoding="utf-8", errors="ignore") @@ -575,7 +580,7 @@ class MCPLoader(metaclass=SingletonMeta): for mcp_id in deleted_mcp_list: while True: try: - mcp_table = await LanceDB().get_table("mcp") + mcp_table = await LanceDB.get_table("mcp") await mcp_table.delete(f"id == '{mcp_id}'") break except Exception as e: diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 2d84069c4..cf63aa27e 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -21,7 +21,8 @@ from apps.scheduler.pool.loader.metadata import MetadataLoader, MetadataType from apps.scheduler.pool.loader.openapi import OpenAPILoader logger = logging.getLogger(__name__) -BASE_PATH = Path(Config().get_config().deploy.data_dir) / "semantics" / "service" +BASE_PATH = Path(Config().get_config().deploy.data_dir) / \ + "semantics" / "service" class ServiceLoader: @@ -83,8 +84,8 @@ class ServiceLoader: try: # 获取 LanceDB 表 - service_table = await LanceDB().get_table("service") - node_table = await LanceDB().get_table("node") + service_table = await LanceDB.get_table("service") + node_table = await LanceDB.get_table("node") # 删除数据 await service_table.delete(f"id = '{service_id}'") @@ -137,8 +138,8 @@ class ServiceLoader: # 向量化所有数据并保存 while True: try: - service_table = await LanceDB().get_table("service") - node_table = await LanceDB().get_table("node") + service_table = await LanceDB.get_table("service") + node_table = await LanceDB.get_table("node") await service_table.delete(f"id = '{metadata.id}'") await node_table.delete(f"service_id = '{metadata.id}'") break @@ -159,10 +160,8 @@ class ServiceLoader: ] while True: try: - service_table = await LanceDB().get_table("service") - await service_table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - service_vector_data, - ) + service_table = await LanceDB.get_table("service") + await service_table.add(service_vector_data) break except Exception as e: if "Commit conflict" in str(e): @@ -187,10 +186,8 @@ class ServiceLoader: ) while True: try: - node_table = await LanceDB().get_table("node") - await node_table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - node_vector_data, - ) + node_table = await LanceDB.get_table("node") + await node_table.add(node_vector_data) break except Exception as e: if "Commit conflict" in str(e): -- Gitee