diff --git a/apps/common/postgres.py b/apps/common/postgres.py index ff2d03ed41c6971579728435f0ee9b354c2dde26..60bbc604adb8aada449452fb9b50f282d17572a9 100644 --- a/apps/common/postgres.py +++ b/apps/common/postgres.py @@ -60,7 +60,7 @@ class Postgres: try: yield session except Exception: - logger.exception("[Postgres] 会话错误") + logger.warning("[Postgres] 会话错误,可能为SQL执行失败") # 发生异常时回滚 await session.rollback() raise diff --git a/apps/common/process_handler.py b/apps/common/process_handler.py index 01ea94027c9d7a9cc8afc33589292665eb65f4f0..b4bddf2c6c585cd8981c53b2e45f3aad40c26ebf 100644 --- a/apps/common/process_handler.py +++ b/apps/common/process_handler.py @@ -30,7 +30,17 @@ class ProcessHandler: """子进程目标函数""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - loop.run_until_complete(target(*args, **kwargs)) + try: + loop.run_until_complete(target(*args, **kwargs)) + finally: + # 等待所有pending tasks完成 + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + if pending: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() @staticmethod def get_all_task_ids() -> list[str]: diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index 5401531abed4f0e4c0c39bf7eadbe24e78a6dfbc..e32642821c217a8a84da066fc68a222d4fec8d21 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -4,6 +4,8 @@ from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, Self +from jinja2.loaders import BaseLoader +from jinja2.sandbox import SandboxedEnvironment from pydantic import Field from apps.llm import json_generator @@ -11,6 +13,7 @@ from apps.models import LanguageType, NodeInfo from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars +from apps.services.tag import TagManager 
from apps.services.user_tag import UserTagManager from .prompt import DOMAIN_FUNCTION, DOMAIN_PROMPT, FACTS_FUNCTION, FACTS_PROMPT @@ -93,8 +96,19 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ) facts_obj = FactsGen.model_validate(facts_result) + # 获取所有标签 + all_tags = await TagManager.get_all_tag() + tag_names = [tag.name for tag in all_tags] + + # 使用jinja2渲染DOMAIN_PROMPT + jinja_env = SandboxedEnvironment( + loader=BaseLoader(), + autoescape=False, + ) + domain_prompt_template = jinja_env.from_string(DOMAIN_PROMPT[self._sys_vars.language]) + domain_prompt = domain_prompt_template.render(available_keywords=tag_names) + # 组装conversation消息 - domain_prompt = DOMAIN_PROMPT[self._sys_vars.language] domain_conversation = [ {"role": "system", "content": "You are a helpful assistant."}, *data.message, @@ -110,7 +124,12 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): domain_list = DomainGen.model_validate(domain_result) for domain in domain_list.keywords: - await UserTagManager.update_user_domain_by_user_and_domain_name(data.user_id, domain) + # 先检查标签是否存在 + tag = await TagManager.get_tag_by_name(domain) + if tag: + # 标签存在,更新用户标签 + await UserTagManager.update_user_domain_by_user_and_domain_name(data.user_id, domain) + # 标签不存在,跳过处理 yield CallOutputChunk( type=CallOutputType.DATA, diff --git a/apps/scheduler/call/facts/prompt.py b/apps/scheduler/call/facts/prompt.py index 22fab49f8e68479e699c2463134fe937645ade11..a2420f87e951a3c14d4d4cb2ac277cead093f953 100644 --- a/apps/scheduler/call/facts/prompt.py +++ b/apps/scheduler/call/facts/prompt.py @@ -10,61 +10,73 @@ DOMAIN_PROMPT: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" # 任务说明 - 根据对话历史,提取推荐系统所需的关键词标签。这些标签将用于内容推荐、用户画像构建和个性化服务。 + 根据对话历史,从下面的"备选提示词列表"中选择最合适的标签。这些标签将用于内容推荐、用户画像构建和个性化服务。 - ## 提取要求 + ## 备选提示词列表 + {{ available_keywords }} + + ## 选择要求 - 1. **关键词类型**:可以是实体名词(人名、地名、组织名)、技术术语、产品名称、时间范围、领域概念等 - 2. **话题相关性**:至少提取一个与对话主题直接相关的关键词 - 3. 
**质量标准**: - - 标签应精准且简洁,每个标签不超过10个字 - - 避免重复或高度相似的标签 - - 优先提取具有区分度的关键词 - - 提取3-8个关键词为宜 - 4. **输出格式**:返回JSON对象,包含keywords字段,值为字符串数组 + 1. **精准匹配**:只能从备选提示词列表中选择,不要自创新标签 + 2. **话题相关性**:选择与对话主题直接相关的标签 + 3. **数量控制**:选择3-8个最相关的标签 + 4. **质量标准**: + - 避免选择重复或高度相似的标签 + - 优先选择具有区分度的标签 + - 按相关性从高到低排序 + 5. **输出格式**:返回JSON对象,包含keywords字段,值为字符串数组 ## 示例 + 假设备选提示词列表包含: + ["北京", "上海", "天气", "气温", "Python", "Java", "装饰器", "设计模式", "餐厅", "美食"] + **示例1:天气查询** - 用户:"北京天气如何?" - 助手:"北京今天晴。" - - 提取结果:["北京", "天气"] + - 选择结果:["北京", "天气", "气温"] - **示例2:技术讨论** - - 用户:"介绍一下Python的装饰器" - - 助手:"Python装饰器是一种设计模式。" - - 提取结果:["Python", "装饰器", "设计模式"] + **示例2:如果对话内容与备选列表无关** + - 用户:"今天心情不错" + - 助手:"很高兴听到这个消息。" + - 选择结果:[](如果备选列表中没有相关标签,返回空数组) """, ), LanguageType.ENGLISH: dedent( r""" # Task Description - Extract keyword tags for the recommendation system based on conversation history. These tags will be used \ -for content recommendation, user profiling, and personalized services. + Based on conversation history, select the most appropriate tags from the "Available Keywords List" below. \ +These tags will be used for content recommendation, user profiling, and personalized services. - ## Extraction Requirements + ## Available Keywords List + {{ available_keywords }} + + ## Selection Requirements - 1. **Keyword Types**: Can be entity nouns (names, locations, organizations), technical terms, \ -product names, time ranges, domain concepts, etc. - 2. **Topic Relevance**: Extract at least one keyword directly related to the conversation topic - 3. **Quality Standards**: - - Tags should be precise and concise, each tag not exceeding 10 characters - - Avoid duplicate or highly similar tags - - Prioritize extracting distinctive keywords - - Extract 3-8 keywords as appropriate - 4. **Output Format**: Return JSON object containing keywords field with string array value + 1. **Exact Match**: Only select from the available keywords list, do not create new tags + 2. 
**Topic Relevance**: Select tags directly related to the conversation topic + 3. **Quantity Control**: Select 3-8 most relevant tags + 4. **Quality Standards**: + - Avoid selecting duplicate or highly similar tags + - Prioritize selecting distinctive tags + - Sort by relevance from high to low + 5. **Output Format**: Return JSON object containing keywords field with string array value ## Examples + Assume the available keywords list contains: + ["Beijing", "Shanghai", "weather", "temperature", "Python", "Java", "decorator", "design pattern", \ +"restaurant", "food"] + **Example 1: Weather Query** - User: "What's the weather like in Beijing?" - Assistant: "Beijing is sunny today." - - Extraction result: ["Beijing", "weather"] + - Selection result: ["Beijing", "weather", "temperature"] - **Example 2: Technical Discussion** - - User: "Tell me about Python decorators" - - Assistant: "Python decorators are a design pattern." - - Extraction result: ["Python", "decorator", "design pattern"] + **Example 2: If conversation content is unrelated to the available list** + - User: "I'm feeling good today" + - Assistant: "Glad to hear that." 
+ - Selection result: [] (return empty array if no relevant tags in the available list) """, ), } diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 5f51166d7a2c3074d38dacddb910567c217e5732..0948259fd4e2081a7f78444835d98f7afd08cdee 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -10,7 +10,7 @@ import asyncer from anyio import Path from sqlalchemy import and_, delete, select, update -from apps.common.postgres import postgres +from apps.common.postgres import Postgres, postgres from apps.common.process_handler import ProcessHandler from apps.constants import MCP_PATH from apps.llm.embedding import embedding @@ -80,6 +80,10 @@ class MCPLoader: :param MCPServerConfig config: MCP配置 :return: 无 """ + # 直接初始化PostgreSQL实例,不使用单例 + pg = Postgres() + await pg.init() + mcp_id = next(iter(item.mcpServers.keys())) mcp_config = item.mcpServers[mcp_id] @@ -96,7 +100,7 @@ class MCPLoader: if mcp_config is None: logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id) - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED, pg) return item.mcpServers[mcp_id] = mcp_config @@ -111,13 +115,15 @@ class MCPLoader: logger.info("[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: %s", mcp_id) item.mcpServers[mcp_id].autoInstall = False - await MCPLoader._insert_template_tool(mcp_id, item) - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY) + await MCPLoader._insert_template_tool(mcp_id, item, pg) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY, pg) logger.info("[Installer] MCP模板安装成功: %s", mcp_id) except Exception: logger.exception("[MCPLoader] MCP模板安装失败: %s", mcp_id) - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED, pg) raise + finally: + await pg.close() @staticmethod @@ -166,7 +172,7 @@ class 
MCPLoader: await Path.mkdir(template_path, parents=True, exist_ok=True) # 安装MCP模板 - if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config): + if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, config): err = f"安装任务无法执行,请稍后重试: {mcp_id}" logger.error(err) raise RuntimeError(err) @@ -309,19 +315,20 @@ class MCPLoader: @staticmethod - async def _insert_template_tool(mcp_id: str, config: MCPServerConfig) -> None: + async def _insert_template_tool(mcp_id: str, config: MCPServerConfig, pg: Postgres = postgres) -> None: """ 插入单个MCP Server工具信息到数据库 :param str mcp_id: MCP模板ID :param MCPServerSSEConfig | MCPServerStdioConfig config: MCP配置 + :param Postgres pg: PostgreSQL实例,默认为全局单例postgres :return: 无 """ # 获取工具列表 tool_list = await MCPLoader._get_template_tool(mcp_id, config) # 基本信息插入数据库 - async with postgres.session() as session: + async with pg.session() as session: # 删除旧的工具 await session.execute(delete(MCPTools).where(MCPTools.mcpId == mcp_id)) # 插入新的工具 @@ -571,20 +578,22 @@ class MCPLoader: @staticmethod - async def update_template_status(mcp_id: str, status: MCPInstallStatus) -> None: + async def update_template_status(mcp_id: str, status: MCPInstallStatus, pg: Postgres = postgres) -> None: """ 更新数据库中MCP模板状态 :param str mcp_id: MCP模板ID :param MCPInstallStatus status: MCP模板状态 + :param Postgres pg: PostgreSQL实例,默认为全局单例postgres :return: 无 """ - async with postgres.session() as session: + async with pg.session() as session: mcp_data = (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() if mcp_data: logger.info("[MCPLoader] 更新MCP模板状态: %s -> %s", mcp_id, status) mcp_data.status = status await session.merge(mcp_data) + await session.commit() @staticmethod diff --git a/apps/scheduler/scheduler/data.py b/apps/scheduler/scheduler/data.py index 1acfedf87a3711bf2279d3e5c8bc49c0a8990f39..9ad19b2232e9b7aa9e462035e215ab77e5c60049 100644 --- a/apps/scheduler/scheduler/data.py +++ 
b/apps/scheduler/scheduler/data.py @@ -2,10 +2,12 @@ """数据管理相关的Mixin类""" import logging +import uuid from datetime import UTC, datetime +from typing import Any from apps.common.security import Security -from apps.models import Record, RecordMetadata, StepStatus +from apps.models import Conversation, ExecutorStatus, Record, RecordMetadata, StepStatus from apps.schemas.record import FlowHistory, RecordContent from apps.schemas.request_data import RequestData from apps.schemas.task import TaskData @@ -16,6 +18,9 @@ from apps.services.task import TaskManager _logger = logging.getLogger(__name__) +# 对话标题最大长度 +_CONVERSATION_TITLE_MAX_LENGTH = 30 + class DataMixin: """处理数据保存和管理相关的逻辑""" @@ -28,10 +33,9 @@ class DataMixin: used_docs = [] task = self.task - if hasattr(task.runtime, "documents") and task.runtime.documents: - for docs in task.runtime.documents: - doc_dict = docs if isinstance(docs, dict) else (docs.model_dump() if hasattr(docs, "model_dump") else docs) - used_docs.append(doc_dict) + for docs in task.runtime.document: + doc_dict = docs if isinstance(docs, dict) else (docs.model_dump() if hasattr(docs, "model_dump") else docs) + used_docs.append(doc_dict) return used_docs @@ -40,22 +44,13 @@ class DataMixin: task = self.task return RecordContent( - question=task.runtime.question if hasattr(task.runtime, "question") else "", - answer=task.runtime.answer if hasattr(task.runtime, "answer") else "", - facts=task.runtime.facts if hasattr(task.runtime, "facts") else [], + question=task.runtime.userInput, + answer=task.runtime.fullAnswer, + facts=task.runtime.fact, data={}, ) - def _encrypt_record_content(self, record_content: RecordContent) -> tuple[str, str] | None: - """加密记录内容""" - try: - encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) - return encrypt_data, encrypt_config - except Exception: - _logger.exception("[Scheduler] 问答对加密错误") - return None - - def _build_record(self, encrypt_data: str, encrypt_config: str, 
current_time: float) -> Record: + def _build_record(self, encrypt_data: str, encrypt_config: dict[str, Any], current_time: float) -> Record: """构建记录对象""" task = self.task user_id = task.metadata.userId @@ -93,33 +88,89 @@ class DataMixin: if record_group and used_docs: await DocumentManager.save_answer_doc(user_id, record_group, used_docs) - async def _save_record_data(self, record: Record) -> None: - """保存记录数据""" - user_id = self.task.metadata.userId - post_body = self.post_body - - if post_body.conversation_id: - await RecordManager.insert_record_data(user_id, post_body.conversation_id, record) + async def _save_record_data(self, record_content: RecordContent, current_time: float) -> None: + """加密并保存记录数据""" + # 加密记录内容 + try: + encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) + except Exception: + _logger.exception("[Scheduler] 问答对加密错误") + return - async def _update_app_center(self) -> None: - """更新应用中心最近使用的应用""" - user_id = self.task.metadata.userId - post_body = self.post_body + # 构建记录对象 + record = self._build_record(encrypt_data, encrypt_config, current_time) - if post_body.app and post_body.app.app_id: - await AppCenterManager.update_recent_app(user_id, post_body.app.app_id) + # 保存记录 + if self.post_body.conversation_id: + await RecordManager.insert_record_data( + self.task.metadata.userId, + self.post_body.conversation_id, + record, + ) async def _handle_task_state(self) -> None: - """处理任务状态管理""" + """根据任务状态判断删除或保存Task""" task = self.task - if not task.state or task.state.flow_status in [StepStatus.SUCCESS, StepStatus.ERROR, StepStatus.CANCELLED]: + if not task.state or task.state.executorStatus in [ + ExecutorStatus.SUCCESS, + ExecutorStatus.ERROR, + ExecutorStatus.CANCELLED, + ]: await TaskManager.delete_task_by_task_id(task.metadata.id) else: - await TaskManager.save_task(task.metadata.id, task.metadata) - await TaskManager.save_task_runtime(task.runtime) - if task.state: - await 
TaskManager.save_executor_checkpoint(task.state) + await TaskManager.save_task(task) + + async def _ensure_conversation_exists(self) -> None: + """确保存在conversation,如果不存在则创建""" + # 如果已经有 conversation_id,无需创建 + if self.task.metadata.conversationId: + return + + _logger.info("[Scheduler] 当前无 conversation_id,创建新对话") + + # 确定标题:直接使用问题的前 _CONVERSATION_TITLE_MAX_LENGTH 个字符 + title = "" + if hasattr(self.task.runtime, "question"): + question_attr = getattr(self.task.runtime, "question", "") + if question_attr and isinstance(question_attr, str): + question = question_attr.strip() + if question: + # 截取前 N 个字符作为标题 + if len(question) > _CONVERSATION_TITLE_MAX_LENGTH: + title = question[:_CONVERSATION_TITLE_MAX_LENGTH] + "..." + else: + title = question + + # 确定 app_id + app_id: uuid.UUID | None = None + if self.post_body.app and self.post_body.app.app_id: + app_id = self.post_body.app.app_id + + # 确定是否为调试模式 + debug = getattr(self.post_body, "debug", False) + + try: + # 调用 InitMixin 中的 _create_new_conversation 方法创建对话 + new_conversation: Conversation = await self._create_new_conversation( # type: ignore[attr-defined] + title=title, + user_id=self.task.metadata.userId, + app_id=app_id, + debug=debug, + ) + + # 更新 task 和 post_body 中的 conversation_id + self.task.metadata.conversationId = new_conversation.id + self.post_body.conversation_id = new_conversation.id + + _logger.info( + "[Scheduler] 成功创建新对话,conversation_id: %s, title: %s", + new_conversation.id, + title, + ) + except Exception: + _logger.exception("[Scheduler] 创建新对话失败") + raise async def _save_data(self) -> None: """保存当前Executor、Task、Record等的数据""" @@ -129,21 +180,20 @@ class DataMixin: used_docs = self._extract_used_documents() record_content = self._build_record_content() - encrypted_result = self._encrypt_record_content(record_content) - if encrypted_result is None: - return - encrypt_data, encrypt_config = encrypted_result if task.state: await TaskManager.save_flow_context(task.context) + # 在保存 Record 之前,确保存在 
conversation + await self._ensure_conversation_exists() + current_time = round(datetime.now(UTC).timestamp(), 2) - record = self._build_record(encrypt_data, encrypt_config, current_time) await self._handle_document_management(record_group, used_docs) + await self._save_record_data(record_content, current_time) - await self._save_record_data(record) - - await self._update_app_center() + # 更新应用中心最近使用的应用 + if self.post_body.app and self.post_body.app.app_id: + await AppCenterManager.update_recent_app(self.task.metadata.userId, self.post_body.app.app_id) await self._handle_task_state() diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 41d4558ba394748a89b92bce5a0d27ade9ade1cd..5e5d42815ebd23024cb7fcde7636a462ea38caa3 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -3,7 +3,6 @@ import asyncio import logging -import uuid from apps.common.queue import MessageQueue from apps.schemas.request_data import RequestData diff --git a/apps/services/task.py b/apps/services/task.py index 8034ea66ed257ced6e58fde4949fbae81c2f2a67..4cd35bf903f43bdca04101b670575b43ee18e147 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -155,64 +155,15 @@ class TaskManager: @staticmethod - async def save_task(task_id: uuid.UUID, task: Task) -> None: - """保存Task数据到PostgreSQL""" + async def save_task(task_data: TaskData) -> None: + """保存Task、TaskRuntime和ExecutorCheckpoint数据到PostgreSQL""" async with postgres.session() as session: - # 查询是否存在该Task - existing_task = (await session.scalars( - select(Task).where(Task.id == task_id), - )).one_or_none() - - if existing_task: - # 更新现有Task - for key, value in task.__dict__.items(): - if not key.startswith("_"): - setattr(existing_task, key, value) - else: - # 插入新Task - session.add(task) - - await session.commit() - - - @staticmethod - async def save_task_runtime(task_runtime: TaskRuntime) -> None: - """保存TaskRuntime数据到PostgreSQL""" - async with 
postgres.session() as session: - # 查询是否存在该TaskRuntime - existing_runtime = (await session.scalars( - select(TaskRuntime).where(TaskRuntime.taskId == task_runtime.taskId), - )).one_or_none() + await session.merge(task_data.metadata) + await session.merge(task_data.runtime) - if existing_runtime: - # 更新现有TaskRuntime - for key, value in task_runtime.__dict__.items(): - if not key.startswith("_"): - setattr(existing_runtime, key, value) - else: - # 插入新TaskRuntime - session.add(task_runtime) - - await session.commit() - - - @staticmethod - async def save_executor_checkpoint(checkpoint: ExecutorCheckpoint) -> None: - """保存ExecutorCheckpoint数据到PostgreSQL""" - async with postgres.session() as session: - # 查询是否存在该Checkpoint - existing_checkpoint = (await session.scalars( - select(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == checkpoint.taskId), - )).one_or_none() - - if existing_checkpoint: - # 更新现有Checkpoint - for key, value in checkpoint.__dict__.items(): - if not key.startswith("_"): - setattr(existing_checkpoint, key, value) - else: - # 插入新Checkpoint - session.add(checkpoint) + # 保存ExecutorCheckpoint(如果存在) + if task_data.state: + await session.merge(task_data.state) await session.commit()