From 0fe0d457de39ce225fe38249af544599f793c9f5 Mon Sep 17 00:00:00 2001 From: sunchao <1299792067@qq.com> Date: Fri, 29 Aug 2025 15:08:49 +0800 Subject: [PATCH 1/4] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E4=BB=93=E5=BA=93=E6=A8=A1=E5=9D=97=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/app/repositories/__init__.py | 15 + .../app/repositories/graph_repo_base.py | 32 + .../server/app/repositories/graph_repo_db.py | 836 ++++++++++++++++++ .../server/app/repositories/graph_repo_vis.py | 83 ++ 4 files changed, 966 insertions(+) create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/__init__.py create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_base.py create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py create mode 100644 plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/__init__.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/__init__.py new file mode 100644 index 000000000..ee2432f47 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_base.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_base.py new file mode 100644 index 000000000..43d73304c --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_base.py @@ -0,0 +1,32 @@ + +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from abc import ABC, abstractmethod + + +class GraphRepo(ABC): + + @abstractmethod + def query_root_nodes(self, graph_type, rank, step): + pass + + @abstractmethod + def query_sub_nodes(self, node_name, graph_type, rank, step): + pass + + @abstractmethod + def query_up_nodes(self, node_name, graph_type, rank, step): + pass diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py new file mode 100644 index 000000000..b36dd655e --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py @@ -0,0 +1,836 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIE S OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import json +import time +import sqlite3 +from .graph_repo_base import GraphRepo +from ..utils.graph_utils import GraphUtils +from ..utils.global_state import GraphState, SINGLE, NPU, BENCH, DataType +from tensorboard.util import tb_logging + +logger = tb_logging.get_logger() +DB_TYPE = DataType.DB.value + + +class GraphRepoDB(GraphRepo): + + def __init__(self, db_path): + self.db_path = db_path + self.repo_type = DB_TYPE + self._initialize_db_connection() + + def _initialize_db_connection(self): + try: + # 目录安全校验 + dir = str(os.path.dirname(self.db_path)) + success, error = GraphUtils.safe_check_load_file_path(dir, True) + if not success: + raise PermissionError(error) + # 文件安全校验 + success, error = GraphUtils.safe_check_load_file_path(self.db_path) + if not success: + raise PermissionError(error) + self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self.conn.row_factory = sqlite3.Row + self.is_db_connected = self.conn is not None + # 提升性能的 PRAGMA 设置 + self.conn.execute("PRAGMA journal_mode = WAL;") + self.conn.execute("PRAGMA synchronous = NORMAL;") # 或 OFF(不安全) + self.conn.execute("PRAGMA cache_size = 40000;") + self.conn.execute("PRAGMA wal_autocheckpoint = 0;") + except: + logger.error("Failed to connect to database") + return None + + def get_db_connection(self): + return self.conn + + # DB: 查询配置表信息 + def query_config_info(self): + try: + query = f"SELECT * FROM tb_config" + start = time.perf_counter() + with self.conn as c: + cursor = c.execute(query) + rows = cursor.fetchall() + + record = dict(rows[0]) + # 构建最终的 data 对象 + config_info = { + "microSteps": record.get('micro_steps', 1) or 1, + "tooltips": GraphUtils.safe_json_loads(record.get('tool_tip')), + "overflowCheck": bool(record.get('overflow_check', 1) or 1), + "isSingleGraph": not record.get('graph_type') == 'compare', + "colors": GraphUtils.safe_json_loads(record.get('node_colors')), + "matchedConfigFiles": [], + "task": record.get('task', ''), + "ranks": GraphUtils.safe_json_loads(record.get('rank_list')), + "steps": GraphUtils.safe_json_loads(record.get('step_list')), + } + end = time.perf_counter() + print("query_config_info time:", end - start) + return config_info + except Exception as e: + logger.error(f"Failed to query config info: {e}") + return [] + + # DB:查询根节点信息 + def query_root_nodes(self, graph_type, rank, step): + try: + type = graph_type if graph_type != SINGLE else NPU + start = time.perf_counter() + query = """ + SELECT + node_name, + up_node, + sub_nodes, + node_type, + matched_node_link, + precision_index, + overflow_level, + matched_distributed + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND up_node = '' + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, type)) + rows = cursor.fetchall() + + end = time.perf_counter() + print("query_root_nodes time:", end - start) + if len(rows) > 0: + return self._convert_db_to_object(dict(rows[0])) + else: + return None + except Exception as e: + logger.error(f"Failed to query root nodes: {e}") + return [] + + # DB:查询当前节点的所有父节点信息 + def query_up_nodes(self, node_name, graph_type, rank, step): + try: + start = time.perf_counter() + type = graph_type if graph_type != SINGLE else NPU + # 现根据节点名称查询节点信息,根据up_node字段得到父节点名称 + # 再根据父节点名称查询父节点信息 + # 递归查询父节点,直到根节点 + query = """ + WITH RECURSIVE parent_chain AS ( + SELECT child.id, child.node_name, child.up_node, child.data_source, child.rank, child.step, 0 AS level + FROM + tb_nodes child + WHERE + child.step = ? + AND child.rank = ? + AND child.data_source = ? + AND child.node_name = ? + + UNION ALL + + SELECT + parent.id, + parent.node_name, + parent.up_node, + parent.data_source, + parent.rank, + parent.step, + pc.level + 1 + FROM + tb_nodes parent + INNER JOIN parent_chain pc + ON parent.data_source = pc.data_source + AND parent.node_name = pc.up_node + AND parent.rank = pc.rank + AND parent.step = pc.step + WHERE + pc.up_node IS NOT NULL + AND pc.up_node != '' + ) + SELECT + tb_nodes.id, + tb_nodes.data_source, + tb_nodes.node_name, + tb_nodes.up_node, + tb_nodes.sub_nodes, + tb_nodes.node_type, + tb_nodes.matched_node_link, + tb_nodes.precision_index, + tb_nodes.overflow_level, + tb_nodes.matched_distributed + FROM + tb_nodes + WHERE + id IN (SELECT id FROM parent_chain) + ORDER BY ( + SELECT + level + FROM + parent_chain pc + WHERE + pc.node_name = tb_nodes.node_name) + ASC + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, type, node_name)) + rows = cursor.fetchall() + + up_nodes = {} + for row in rows: + dict_row = self._convert_db_to_object(dict(row)) + up_nodes[row['node_name']] = dict_row + end = time.perf_counter() + print("query_up_nodes time:", end - start) + return up_nodes + except Exception as e: + logger.error(f"Failed to query up nodes: {e}") + return {} + + # DB: 查询待匹配节点的信息,构造graph data + def query_matched_nodes_info(self, npu_node_name, bench_node_name, rank, step): + try: + start = time.perf_counter() + query = """ + SELECT + id, + node_name, + node_type, + up_node, + sub_nodes, + data_source, + input_data, + output_data, + matched_node_link + FROM tb_nodes + WHERE step = ? AND rank = ? AND data_source = ? AND node_name = ? + """ + npu_nodes = {} + bench_nodes = {} + opposite_npu_node_name = GraphUtils.get_opposite_node_name(npu_node_name) + opposite_bench_node_name = GraphUtils.get_opposite_node_name(bench_node_name) + # 定义查询参数列表:(graph_type, node_name, target_dict_key) + queries = [ + (NPU, npu_node_name, 'npu'), + (NPU, opposite_npu_node_name, 'npu_opposite'), + (BENCH, bench_node_name, 'bench'), + (BENCH, opposite_bench_node_name, 'bench_opposite'), + ] + # 存储结果的字典 + nodes_dict = {} + with self.conn as c: + for graph_type, node_name, key in queries: + if not node_name: # 可选:跳过空 node_name + continue + cursor = c.execute(query, (step, rank, graph_type, node_name)) + rows = cursor.fetchall() + if rows: + node_obj = self._convert_db_to_object(dict(rows[0])) + nodes_dict[key] = {node_obj.get('node_name'): node_obj} + else: + nodes_dict[key] = {} + + npu_nodes = nodes_dict.get('npu', {}) | nodes_dict.get('npu_opposite', {}) + bench_nodes = nodes_dict.get('bench', {}) | nodes_dict.get('bench_opposite', {}) + result = self._convert_to_graph_json(npu_nodes, bench_nodes) + end = time.perf_counter() + print("query_matched_nodes_info time:", end - start) + return result + except Exception as e: + logger.error(f"Failed to query matched nodes info: {e}") + return self._convert_to_graph_json({}, {}) + + # DB: 查询待匹配节点及其子节点的信息,递归查询当前节点信息和其所有的子节点信息,一直叶子节点 + def query_node_and_sub_nodes(self, npu_node_name, bench_node_name, rank, step): + try: + start = time.perf_counter() + query = """ + WITH RECURSIVE descendants AS ( + -- 初始节点选择 + SELECT + id, + node_name, + node_type, + up_node, + sub_nodes, + data_source, + input_data, + output_data, + matched_node_link, + node_order, + step, + rank + FROM tb_nodes + WHERE step = ? AND rank = ? AND data_source = ? AND node_name = ? + + UNION ALL + + -- 递归部分 + SELECT + child.id, + child.node_name, + child.node_type, + child.up_node, + child.sub_nodes, + child.data_source, + child.input_data, + child.output_data, + child.matched_node_link, + child.node_order, + child.step, + child.rank + FROM descendants d + JOIN json_each(d.sub_nodes) AS je -- 将 sub_nodes JSON 数组展开为多行 + JOIN tb_nodes child + ON child.node_name = je.value -- 子节点名称匹配 + AND child.step = d.step + AND child.rank = d.rank + AND child.data_source = d.data_source + WHERE + d.sub_nodes IS NOT NULL -- 父节点的 sub_nodes 不为 NULL + AND d.sub_nodes != '' -- 不是空 + AND d.sub_nodes != '[]' + AND json_type(d.sub_nodes) = 'array' -- 确保是合法 JSON 数组 + ) + SELECT * FROM descendants + """ + + npu_nodes = {} + bench_nodes = {} + opposite_npu_node_name = GraphUtils.get_opposite_node_name(npu_node_name) + opposite_bench_node_name = GraphUtils.get_opposite_node_name(bench_node_name) + # 定义查询参数列表:(graph_type, node_name, target_dict_key) + queries = [ + (NPU, npu_node_name, 'npu'), + (NPU, opposite_npu_node_name, 'npu_opposite'), + (BENCH, bench_node_name, 'bench'), + (BENCH, opposite_bench_node_name, 'bench_opposite'), + ] + # 存储结果的字典 + nodes_dict = {} + with self.conn as c: + for graph_type, node_name, key in queries: + if not node_name: # 可选:跳过空 node_name + continue + cursor = c.execute(query, (step, rank, graph_type, node_name)) + nodes_dict[key] = self._fetch_and_convert_rows(cursor) + npu_nodes = nodes_dict.get('npu', {}) | nodes_dict.get('npu_opposite', {}) + bench_nodes = nodes_dict.get('bench', {}) | nodes_dict.get('bench_opposite', {}) + result = self._convert_to_graph_json(npu_nodes, bench_nodes) + end = time.perf_counter() + print("query_node_and_sub_nodes time:", end - start) + return result + except Exception as e: + logger.error(f"Failed to query node and sub nodes: {e}") + return {'NPU': {}, 'Bench': {}} + + # DB:查询配置文件中的待匹配节点信息 + def query_matched_nodes_info_by_config(self, match_node_links, rank, step): + try: + start = time.perf_counter() + query = """ + SELECT + id, + node_name, + node_type, + up_node, + sub_nodes, + data_source, + input_data, + output_data, + matched_node_link + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND node_name IN ({}) + """.format(','.join(['?'] * len(match_node_links))) + + with self.conn as c: + npu_node_names = list(match_node_links.keys()) + bench_node_names = list(match_node_links.values()) + npu_cursor = c.execute(query, (step, rank, NPU , *npu_node_names)) + bench_cursor = c.execute(query, (step, rank, BENCH, *bench_node_names)) + npu_nodes = self._fetch_and_convert_rows(npu_cursor) + bench_nodes = self._fetch_and_convert_rows(bench_cursor) + result = self._convert_to_graph_json(npu_nodes, bench_nodes) + end = time.perf_counter() + print("query_matched_nodes_info_by_config time:", end - start) + return result + except Exception as e: + logger.error(f"Failed to query nodes info: {e}") + return {} + + # DB: 查询所有以当前为父节点的子节点 + def query_sub_nodes(self, node_name, graph_type, rank, step): + try: + start = time.perf_counter() + type = graph_type if graph_type != SINGLE else NPU + query = """ + SELECT + node_name, + up_node, + sub_nodes, + node_type, + micro_step_id, + matched_node_link, + precision_index, + overflow_level, + matched_distributed + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND up_node = ? + ORDER BY + node_order ASC + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, type, node_name)) + rows = cursor.fetchall() + sub_nodes = {} + for row in rows: + dict_row = self._convert_db_to_object(dict(row)) + sub_nodes[row['node_name']] = dict_row + end = time.perf_counter() + print("query_sub_nodes time:", end - start) + return sub_nodes + except Exception as e: + logger.error(f"Failed to query sub nodes: {e}") + return {} + + # DB: 查询当前节点信息 + def query_node_info(self, node_name, graph_type, rank, step): + try: + start = time.perf_counter() + type = graph_type if graph_type != SINGLE else NPU + query = """ + SELECT + * + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND node_name = ? + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, type, node_name)) + rows = cursor.fetchall() + + end = time.perf_counter() + print("query_node_info time:", end - start) + if len(rows) > 0: + return self._convert_db_to_object(dict(rows[0])) + else: + return {} + except Exception as e: + logger.error(f"Failed to query node info: {e}") + return {} + + # DB: 查询单图节点名称列表 + def query_node_name_list(self, rank, step, micro_step): + try: + query = """ + SELECT + node_name + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND (? = -1 OR micro_step_id = ?) + AND data_source = 'NPU' + ORDER BY + node_order ASC + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, micro_step, micro_step)) + rows = cursor.fetchall() + return [row['node_name'] for row in rows] + except Exception as e: + logger.error(f"Failed to query node name list: {e}") + return [] + + # DB: 查询已匹配节点列表,未匹配节点列表,所有的节点列表 + def query_all_node_info_in_one(self, rank, step, micro_step): + try: + # 查找缓存 + all_node_info_cache = GraphState.get_global_value('all_node_info_cache', {}) + cache = f'{rank}_{step}_{micro_step}' + if all_node_info_cache.get(cache) != None: + print("all_node_info_cache hit") + return all_node_info_cache.get(cache) + # 查询数据库 + start = time.perf_counter() + # 单次查询:获取 node_name 和 matched_node_link + query = """ + SELECT + node_name, + data_source, + matched_node_link + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND (? = -1 OR micro_step_id = ?) + ORDER BY + node_order ASC + """ + + with self.conn as conn: + cursor = conn.execute(query, (step, rank, micro_step, micro_step)) + rows = cursor.fetchall() + end = time.perf_counter() + print(f"query_all_node_info_in_one time: {end - start:.4f}s") + + # 初始化结果 + npu_node_list = [] + bench_node_list = [] + npu_match_node = {} # {node_name: last_matched_link} + bench_match_node = {} + npu_unmatch_node = [] + bench_unmatch_node = [] + + # 一次性遍历结果,分类处理 + for row in rows: + node_name = row['node_name'] + matched_link_str = row['matched_node_link'] + if row['data_source'] == NPU: + npu_node_list.append(node_name) + # 解析 matched_node_link + matched_link = GraphUtils.safe_json_loads(matched_link_str) + # 判断是否为有效匹配(非空列表) + if isinstance(matched_link, list) and len(matched_link) > 0: + npu_match_node[node_name] = matched_link[-1] # 取最后一个匹配项 + else: + npu_unmatch_node.append(node_name) + elif row['data_source'] == BENCH: + bench_node_list.append(node_name) + # 解析 matched_node_link + matched_link = GraphUtils.safe_json_loads(matched_link_str) + # 判断是否为有效匹配(非空列表) + if isinstance(matched_link, list) and len(matched_link) > 0: + bench_match_node[node_name] = matched_link[-1] # 取最后一个匹配项 + else: + bench_unmatch_node.append(node_name) + else: + logger.error(f"Invalid data source: {row['data_source']}") + all_node_info = { + 'npu_node_list': npu_node_list, + 'bench_node_list': bench_node_list, + 'npu_match_node': npu_match_node, + 'bench_match_node': bench_match_node, + 'npu_unmatch_node': npu_unmatch_node, + 'bench_unmatch_node': bench_unmatch_node + } + all_node_info_cache = GraphState.get_global_value('all_node_info_cache', {}) + all_node_info_cache[cache] = all_node_info + return all_node_info + + except Exception as e: + logger.error(f"Failed to query all node info: {e}") + return { + 'npu_node_list': [], + 'bench_node_list': [], + 'npu_match_node': {}, + 'bench_match_node': {}, + 'npu_unmatch_node': [], + 'bench_unmatch_node': [] + } + + # # DB:根据step rank modify match_node_link查询已经修改的匹配成功的节点关系 + def query_modify_matched_nodes_list(self, rank, step): + try: + start = time.perf_counter() + query = """ + SELECT + node_name, + matched_node_link + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND modified = 1 + AND matched_node_link IS NOT NULL + AND matched_node_link != '[]' + AND matched_node_link != '' + """ + with self.conn as c: + cursor = c.execute(query, (step, rank)) + rows = cursor.fetchall() + result = {} + for row in rows: + matched_node_link = GraphUtils.safe_json_loads(row['matched_node_link']) + node_name = row['node_name'] + if isinstance(matched_node_link, list) and len(matched_node_link) > 0: + result[node_name] = matched_node_link[-1] # 取最后一个匹配项 + end = time.perf_counter() + print("query_modify_matched_nodes_list time:", end - start) + return result + except Exception as e: + logger.error(f"Failed to query modify matched nodes list: {e}") + return {} + + # DB: 根据精度误差查询节点信息 + def query_node_list_by_precision(self, step, rank, micro_step, values, is_filter_unmatch_nodes): + try: + # 准备占位符 + conditions = [] + placeholders = [] + params = [] + conditions.append("step = ?") + conditions.append("rank = ?") + conditions.append("data_source = 'NPU'") + conditions.append("(? = -1 OR micro_step_id = ?)") + conditions.append("(sub_nodes = '' OR sub_nodes IS NULL OR sub_nodes = '[]')") + for value in values: + placeholder = "(precision_index BETWEEN ? AND ?)" + placeholders.append(placeholder) + params.extend(value) + + if is_filter_unmatch_nodes: + placeholders.append("(matched_node_link = '' OR matched_node_link IS NULL OR matched_node_link = '[]')") + + if len(placeholders) > 0: + conditions.append(f"({'OR'.join(placeholders)})") + start = time.perf_counter() + query = f""" + SELECT + node_name + FROM + tb_nodes + WHERE + {" AND ".join(conditions)} + ORDER BY + node_order ASC + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, micro_step, micro_step, *params)) + rows = cursor.fetchall() + node_list = [row['node_name'] for row in rows] + end = time.perf_counter() + print("query_node_list_by_precision time:", end - start) + return node_list + except Exception as e: + logger.error(f"Failed to query node list by precision: {e}") + return [] + + # DB: 根据溢出查询节点信息 + def query_node_list_by_overflow(self, step, rank, micro_step, values): + try: + # 准备占位符 + conditions = [] + + conditions.append("step = ?") + conditions.append("rank = ?") + conditions.append("data_source = 'NPU'") + conditions.append("(? = -1 OR micro_step_id = ?)") + placeholders = ", ".join(["?"] * len(values)) + start = time.perf_counter() + query = f""" + SELECT + node_name + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = 'NPU' + AND (? = -1 OR micro_step_id = ?) + AND (sub_nodes = '' OR sub_nodes IS NULL OR sub_nodes = '[]') + AND overflow_level IN ({placeholders}) + ORDER BY + node_order ASC + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, micro_step, micro_step, *values)) + rows = cursor.fetchall() + node_list = [row['node_name'] for row in rows] + end = time.perf_counter() + print("query_node_list_by_overflow time:", end - start) + return node_list + except Exception as e: + logger.error(f"Failed to query node list by overflow: {e}") + return [] + + # DB:查询节点信息 + def query_node_info_by_data_source(self, step, rank, data_source): + try: + start = time.perf_counter() + query = """ + SELECT + node_name, + matched_node_link, + output_data, + precision_index, + sub_nodes + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + """ + with self.conn as c: + cursor = c.execute(query, (step, rank, data_source)) + nodes = self._fetch_and_convert_rows(cursor) + end = time.perf_counter() + print("query_node_info time:", end - start) + return nodes + except Exception as e: + logger.error(f"Failed to query node info: {e}") + return {} + + # DB:更新config的colors + def update_config_colors(self, colors): + try: + query = """ + UPDATE + tb_config + SET + node_colors = ? + WHERE + id=1 + """ + with self.conn as c: + c.execute(query, (json.dumps(colors),)) + return True + except Exception as e: + logger.error(f"Failed to update config colors: {e}") + return False + + # DB:批量更新节点信息 + def update_nodes_info(self, nodes_info, rank, step): + # 取消匹配和匹配都要走这个逻辑 + try: + start = time.perf_counter() + data = [ + ( + json.dumps(node['matched_node_link']), + json.dumps(node['input_data']), + json.dumps(node['output_data']), + node['precision_index'], + step, + rank, + node['graph_type'], + node['node_name'] # WHERE 条件 + ) + for node in nodes_info + ] + query = """ + UPDATE tb_nodes + SET + matched_node_link = ?, + input_data = ?, + output_data = ?, + precision_index = ?, + modified= 1 + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND node_name = ? + """ + with self.conn as c: + c.executemany(query, data) + end = time.perf_counter() + print("update_nodes_info time:", end - start) + return True + except Exception as e: + logger.error(f"Failed to update nodes info: {e}") + return False + + def update_nodes_precision_error(self, update_data): + + try: + start = time.perf_counter() + query = """ + UPDATE + tb_nodes + SET + precision_index = ? + WHERE + step = ? + AND rank = ? + AND data_source = 'NPU' + AND node_name = ? + """ + self.conn.executemany(query, update_data) + self.conn.commit() + end = time.perf_counter() + print("update_precision_error time:", end - start) + return True + except Exception as e: + logger.error(f"Failed to update precision error: {e}") + return False + + def _fetch_and_convert_rows(self, cursor): + """ + Helper function to fetch rows from cursor and convert them. + :param cursor: SQLite cursor object + :return: Dictionary of nodes keyed by node_name + """ + nodes = {} + for row in cursor.fetchall(): + dict_row = self._convert_db_to_object(dict(row)) + nodes[row['node_name']] = dict_row + return nodes + + def _convert_to_graph_json(self, npu_nodes, bench_nodes): + graph_data = { + "NPU":{ + "node": npu_nodes, + }, + "Bench":{ + "node": bench_nodes, + } + } + return graph_data + + def _convert_db_to_object(self, data): + object = { + "id": data.get('node_name'), + "node_name": data.get('node_name'), + "node_type": int(data.get('node_type')) if data.get('node_type') is not None else 0, + "output_data": GraphUtils.safe_json_loads(data.get('output_data') or "{}"), + "input_data": GraphUtils.safe_json_loads(data.get('input_data') or "{}"), + "upnode":data.get('up_node'), + "subnodes":GraphUtils.safe_json_loads(data.get('sub_nodes') or "[]"), + "matched_node_link":GraphUtils.safe_json_loads(data.get('matched_node_link') or "[]"), + "stack_info":GraphUtils.safe_json_loads(data.get('stack_info') or "[]"), + "micro_step_id": int(data.get('micro_step_id')) if data.get('micro_step_id') is not None else -1, + "data":{ + "precision_index": data.get('precision_index'), + 'overflow_level': data.get('overflow_level'), + }, + "parallel_merge_info": GraphUtils.safe_json_loads(data.get('parallel_merge_info') or "[]"), + "matched_distributed": GraphUtils.safe_json_loads(data.get('matched_distributed') or "[]"), + "modified":int(data.get('modified')) if data.get('modified') is not None else 0, + } + return object + diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py new file mode 100644 index 000000000..998901e56 --- /dev/null +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, Huawei Technologies. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from ..utils.global_state import SINGLE, DataType +from .graph_repo_base import GraphRepo + +JSON_TYPE = DataType.JSON.value + + +class GraphRepoVis(GraphRepo): + + def __init__(self, graph): + self.graph = graph + self.repo_type = JSON_TYPE + + # 查询根节点信息 + def query_root_nodes(self, graph_type, rank, step): + root_node = {} + if graph_type == SINGLE: + root_node_name = self.graph.get('root') + root_node = self.graph.get('node', {}).get(root_node_name, {}) + else: + root_node_name = self.graph.get(graph_type, {}).get('root') + root_node = self.graph.get(graph_type, {}).get('node', {}).get(root_node_name, {}) + root_node['node_name'] = root_node_name + return root_node + + # 查询所有以当前为父节点的子节点 + def query_sub_nodes(self, node_name, graph_type, rank, step): + sub_nodes = {} + graph_nodes = {} + if graph_type == SINGLE: + graph_nodes = self.graph.get('node', {}) + else: + graph_nodes = self.graph.get(graph_type, {}).get('node', {}) + target_node = graph_nodes.get(node_name, {}) + + target_node_children = target_node.get("subnodes", []) + for subnode_name in target_node_children: + node_info = graph_nodes.get(subnode_name, {}) + sub_nodes[subnode_name] = node_info + return sub_nodes + + # 查询当前节点的父节点信息 + def query_up_nodes(self, node_name, graph_type, rank, step): + graph_nodes = {} + if graph_type == SINGLE: + graph_nodes = self.graph.get('node', {}) + else: + graph_nodes = self.graph.get(graph_type, {}).get('node', {}) + + # 查询当前节点及其的所有父节点,一直到没有父节点位置{} + up_nodes = {} + up_nodes[node_name] = graph_nodes.get(node_name, {}) + parent_node_name = graph_nodes.get(node_name, {}).get("upnode") + while graph_nodes.get(parent_node_name, None) != None: + parent_node = graph_nodes.get(parent_node_name, {}) + up_nodes[parent_node_name] = parent_node + parent_node_name = parent_node.get("upnode") + return up_nodes + + # 查询当前节点信息 + def query_node_info(self, node_name, graph_type): + graph_nodes = {} + if graph_type == SINGLE: + graph_nodes = self.graph.get('node', {}) + else: + graph_nodes = self.graph.get(graph_type, {}).get('node', {}) + return graph_nodes.get(node_name, {}) + -- Gitee From 2a962b2619129953c88a32e3e0db8b16327b6d7b Mon Sep 17 00:00:00 2001 From: sunchao <1299792067@qq.com> Date: Sat, 30 Aug 2025 14:46:40 +0800 Subject: [PATCH 2/4] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=9F=A5=E8=AF=A2=E6=80=A7=E8=83=BD=E5=92=8C?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/app/repositories/graph_repo_db.py | 298 ++++++++---------- .../server/app/repositories/graph_repo_vis.py | 2 +- 2 files changed, 125 insertions(+), 175 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py index b36dd655e..24b16c8d2 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py @@ -17,10 +17,11 @@ import os import json import time import sqlite3 +from tensorboard.util import tb_logging + from .graph_repo_base import GraphRepo from ..utils.graph_utils import GraphUtils from ..utils.global_state import GraphState, SINGLE, NPU, BENCH, DataType -from tensorboard.util import tb_logging logger = tb_logging.get_logger() DB_TYPE = DataType.DB.value @@ -33,37 +34,13 @@ class GraphRepoDB(GraphRepo): self.repo_type = DB_TYPE self._initialize_db_connection() - def _initialize_db_connection(self): - try: - # 目录安全校验 - dir = str(os.path.dirname(self.db_path)) - success, error = GraphUtils.safe_check_load_file_path(dir, True) - if not success: - raise PermissionError(error) - # 文件安全校验 - success, error = GraphUtils.safe_check_load_file_path(self.db_path) - if not success: - raise PermissionError(error) - self.conn = sqlite3.connect(self.db_path, check_same_thread=False) - self.conn.row_factory = sqlite3.Row - self.is_db_connected = self.conn is not None - # 提升性能的 PRAGMA 设置 - self.conn.execute("PRAGMA journal_mode = WAL;") - self.conn.execute("PRAGMA synchronous = NORMAL;") # 或 OFF(不安全) - self.conn.execute("PRAGMA cache_size = 40000;") - self.conn.execute("PRAGMA wal_autocheckpoint = 0;") - except: - logger.error("Failed to connect to database") - return None - def get_db_connection(self): return self.conn # DB: 查询配置表信息 def query_config_info(self): + query = f"SELECT * FROM tb_config" try: - query = f"SELECT * FROM tb_config" - start = time.perf_counter() with self.conn as c: cursor = c.execute(query) rows = cursor.fetchall() @@ -81,8 +58,6 @@ class GraphRepoDB(GraphRepo): "ranks": GraphUtils.safe_json_loads(record.get('rank_list')), "steps": GraphUtils.safe_json_loads(record.get('step_list')), } - end = time.perf_counter() - print("query_config_info time:", end - start) return config_info except Exception as e: logger.error(f"Failed to query config info: {e}") @@ -91,8 +66,7 @@ class GraphRepoDB(GraphRepo): # DB:查询根节点信息 def query_root_nodes(self, graph_type, rank, step): try: - type = graph_type if graph_type != SINGLE else NPU - start = time.perf_counter() + graph_type = graph_type if graph_type != SINGLE else NPU query = """ SELECT node_name, @@ -112,15 +86,12 @@ class GraphRepoDB(GraphRepo): AND up_node = '' """ with self.conn as c: - cursor = c.execute(query, (step, rank, type)) + cursor = c.execute(query, (step, rank, graph_type)) rows = cursor.fetchall() - - end = time.perf_counter() - print("query_root_nodes time:", end - start) if len(rows) > 0: return self._convert_db_to_object(dict(rows[0])) else: - return None + return [] except Exception as e: logger.error(f"Failed to query root nodes: {e}") return [] @@ -128,8 +99,7 @@ class GraphRepoDB(GraphRepo): # DB:查询当前节点的所有父节点信息 def query_up_nodes(self, node_name, graph_type, rank, step): try: - start = time.perf_counter() - type = graph_type if graph_type != SINGLE else NPU + graph_type = graph_type if graph_type != SINGLE else NPU # 现根据节点名称查询节点信息,根据up_node字段得到父节点名称 # 再根据父节点名称查询父节点信息 # 递归查询父节点,直到根节点 @@ -190,15 +160,12 @@ class GraphRepoDB(GraphRepo): ASC """ with self.conn as c: - cursor = c.execute(query, (step, rank, type, node_name)) + cursor = c.execute(query, (step, rank, graph_type, node_name)) rows = cursor.fetchall() - up_nodes = {} for row in rows: dict_row = self._convert_db_to_object(dict(row)) up_nodes[row['node_name']] = dict_row - end = time.perf_counter() - print("query_up_nodes time:", end - start) return up_nodes except Exception as e: logger.error(f"Failed to query up nodes: {e}") @@ -207,7 +174,6 @@ class GraphRepoDB(GraphRepo): # DB: 查询待匹配节点的信息,构造graph data def query_matched_nodes_info(self, npu_node_name, bench_node_name, rank, step): try: - start = time.perf_counter() query = """ SELECT id, @@ -250,8 +216,6 @@ class GraphRepoDB(GraphRepo): npu_nodes = nodes_dict.get('npu', {}) | nodes_dict.get('npu_opposite', {}) bench_nodes = nodes_dict.get('bench', {}) | nodes_dict.get('bench_opposite', {}) result = self._convert_to_graph_json(npu_nodes, bench_nodes) - end = time.perf_counter() - print("query_matched_nodes_info time:", end - start) return result except Exception as e: logger.error(f"Failed to query matched nodes info: {e}") @@ -260,7 +224,6 @@ class GraphRepoDB(GraphRepo): # DB: 查询待匹配节点及其子节点的信息,递归查询当前节点信息和其所有的子节点信息,一直叶子节点 def query_node_and_sub_nodes(self, npu_node_name, bench_node_name, rank, step): try: - start = time.perf_counter() query = """ WITH RECURSIVE descendants AS ( -- 初始节点选择 @@ -334,8 +297,6 @@ class GraphRepoDB(GraphRepo): npu_nodes = nodes_dict.get('npu', {}) | nodes_dict.get('npu_opposite', {}) bench_nodes = nodes_dict.get('bench', {}) | nodes_dict.get('bench_opposite', {}) result = self._convert_to_graph_json(npu_nodes, bench_nodes) - end = time.perf_counter() - print("query_node_and_sub_nodes time:", end - start) return result except Exception as e: logger.error(f"Failed to query node and sub nodes: {e}") @@ -344,7 +305,6 @@ class GraphRepoDB(GraphRepo): # DB:查询配置文件中的待匹配节点信息 def query_matched_nodes_info_by_config(self, match_node_links, rank, step): try: - start = time.perf_counter() query = """ SELECT id, @@ -373,8 +333,6 @@ class GraphRepoDB(GraphRepo): npu_nodes = self._fetch_and_convert_rows(npu_cursor) bench_nodes = self._fetch_and_convert_rows(bench_cursor) result = self._convert_to_graph_json(npu_nodes, bench_nodes) - end = time.perf_counter() - print("query_matched_nodes_info_by_config time:", end - start) return result except Exception as e: logger.error(f"Failed to query nodes info: {e}") @@ -383,8 +341,7 @@ class GraphRepoDB(GraphRepo): # DB: 查询所有以当前为父节点的子节点 def query_sub_nodes(self, node_name, graph_type, rank, step): try: - start = time.perf_counter() - type = graph_type if graph_type != SINGLE else NPU + graph_type = graph_type if graph_type != SINGLE else NPU query = """ SELECT node_name, @@ -407,14 +364,12 @@ class GraphRepoDB(GraphRepo): node_order ASC """ with self.conn as c: - cursor = c.execute(query, (step, rank, type, node_name)) + cursor = c.execute(query, (step, rank, graph_type, node_name)) rows = cursor.fetchall() sub_nodes = {} for row in rows: dict_row = self._convert_db_to_object(dict(row)) sub_nodes[row['node_name']] = dict_row - end = time.perf_counter() - print("query_sub_nodes time:", end - start) return sub_nodes except Exception as e: logger.error(f"Failed to query sub nodes: {e}") @@ -423,8 +378,7 @@ class GraphRepoDB(GraphRepo): # DB: 查询当前节点信息 def query_node_info(self, node_name, graph_type, rank, step): try: - start = time.perf_counter() - type = graph_type if graph_type != SINGLE else NPU + graph_type = graph_type if graph_type != SINGLE else NPU query = """ SELECT * @@ -437,11 +391,8 @@ class GraphRepoDB(GraphRepo): AND node_name = ? """ with self.conn as c: - cursor = c.execute(query, (step, rank, type, node_name)) + cursor = c.execute(query, (step, rank, graph_type, node_name)) rows = cursor.fetchall() - - end = time.perf_counter() - print("query_node_info time:", end - start) if len(rows) > 0: return self._convert_db_to_object(dict(rows[0])) else: @@ -452,20 +403,20 @@ class GraphRepoDB(GraphRepo): # DB: 查询单图节点名称列表 def query_node_name_list(self, rank, step, micro_step): + query = """ + SELECT + node_name + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND (? = -1 OR micro_step_id = ?) + AND data_source = 'NPU' + ORDER BY + node_order ASC + """ try: - query = """ - SELECT - node_name - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND (? = -1 OR micro_step_id = ?) - AND data_source = 'NPU' - ORDER BY - node_order ASC - """ with self.conn as c: cursor = c.execute(query, (step, rank, micro_step, micro_step)) rows = cursor.fetchall() @@ -480,11 +431,9 @@ class GraphRepoDB(GraphRepo): # 查找缓存 all_node_info_cache = GraphState.get_global_value('all_node_info_cache', {}) cache = f'{rank}_{step}_{micro_step}' - if all_node_info_cache.get(cache) != None: - print("all_node_info_cache hit") + if all_node_info_cache.get(cache) is not None: return all_node_info_cache.get(cache) # 查询数据库 - start = time.perf_counter() # 单次查询:获取 node_name 和 matched_node_link query = """ SELECT @@ -504,9 +453,6 @@ class GraphRepoDB(GraphRepo): with self.conn as conn: cursor = conn.execute(query, (step, rank, micro_step, micro_step)) rows = cursor.fetchall() - end = time.perf_counter() - print(f"query_all_node_info_in_one time: {end - start:.4f}s") - # 初始化结果 npu_node_list = [] bench_node_list = [] @@ -565,7 +511,6 @@ class GraphRepoDB(GraphRepo): # # DB:根据step rank modify match_node_link查询已经修改的匹配成功的节点关系 def query_modify_matched_nodes_list(self, rank, step): try: - start = time.perf_counter() query = """ SELECT node_name, @@ -589,8 +534,6 @@ class GraphRepoDB(GraphRepo): node_name = row['node_name'] if isinstance(matched_node_link, list) and len(matched_node_link) > 0: result[node_name] = matched_node_link[-1] # 取最后一个匹配项 - end = time.perf_counter() - print("query_modify_matched_nodes_list time:", end - start) return result except Exception as e: logger.error(f"Failed to query modify matched nodes list: {e}") @@ -598,43 +541,41 @@ class GraphRepoDB(GraphRepo): # DB: 根据精度误差查询节点信息 def query_node_list_by_precision(self, step, rank, micro_step, values, is_filter_unmatch_nodes): - try: - # 准备占位符 - conditions = [] - placeholders = [] - params = [] - conditions.append("step = ?") - conditions.append("rank = ?") - conditions.append("data_source = 'NPU'") - conditions.append("(? = -1 OR micro_step_id = ?)") - conditions.append("(sub_nodes = '' OR sub_nodes IS NULL OR sub_nodes = '[]')") - for value in values: - placeholder = "(precision_index BETWEEN ? AND ?)" - placeholders.append(placeholder) - params.extend(value) - if is_filter_unmatch_nodes: - placeholders.append("(matched_node_link = '' OR matched_node_link IS NULL OR matched_node_link = '[]')") + # 准备占位符 + conditions = [] + placeholders = [] + params = [] + conditions.append("step = ?") + conditions.append("rank = ?") + conditions.append("data_source = 'NPU'") + conditions.append("(? = -1 OR micro_step_id = ?)") + conditions.append("(sub_nodes = '' OR sub_nodes IS NULL OR sub_nodes = '[]')") + for value in values: + placeholder = "(precision_index BETWEEN ? AND ?)" + placeholders.append(placeholder) + params.extend(value) + + if is_filter_unmatch_nodes: + placeholders.append("(matched_node_link = '' OR matched_node_link IS NULL OR matched_node_link = '[]')") - if len(placeholders) > 0: - conditions.append(f"({'OR'.join(placeholders)})") - start = time.perf_counter() - query = f""" - SELECT - node_name - FROM - tb_nodes - WHERE - {" AND ".join(conditions)} - ORDER BY - node_order ASC - """ + if len(placeholders) > 0: + conditions.append(f"({'OR'.join(placeholders)})") + query = f""" + SELECT + node_name + FROM + tb_nodes + WHERE + {" AND ".join(conditions)} + ORDER BY + node_order ASC + """ + try: with self.conn as c: cursor = c.execute(query, (step, rank, micro_step, micro_step, *params)) rows = cursor.fetchall() node_list = [row['node_name'] for row in rows] - end = time.perf_counter() - print("query_node_list_by_precision time:", end - start) return node_list except Exception as e: logger.error(f"Failed to query node list by precision: {e}") @@ -642,37 +583,33 @@ class GraphRepoDB(GraphRepo): # DB: 根据溢出查询节点信息 def query_node_list_by_overflow(self, step, rank, micro_step, values): + # 准备占位符 + conditions = [] + conditions.append("step = ?") + conditions.append("rank = ?") + conditions.append("data_source = 'NPU'") + conditions.append("(? = -1 OR micro_step_id = ?)") + placeholders = ", ".join(["?"] * len(values)) + query = f""" + SELECT + node_name + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = 'NPU' + AND (? = -1 OR micro_step_id = ?) + AND (sub_nodes = '' OR sub_nodes IS NULL OR sub_nodes = '[]') + AND overflow_level IN ({placeholders}) + ORDER BY + node_order ASC + """ try: - # 准备占位符 - conditions = [] - - conditions.append("step = ?") - conditions.append("rank = ?") - conditions.append("data_source = 'NPU'") - conditions.append("(? = -1 OR micro_step_id = ?)") - placeholders = ", ".join(["?"] * len(values)) - start = time.perf_counter() - query = f""" - SELECT - node_name - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND data_source = 'NPU' - AND (? = -1 OR micro_step_id = ?) - AND (sub_nodes = '' OR sub_nodes IS NULL OR sub_nodes = '[]') - AND overflow_level IN ({placeholders}) - ORDER BY - node_order ASC - """ with self.conn as c: cursor = c.execute(query, (step, rank, micro_step, micro_step, *values)) rows = cursor.fetchall() node_list = [row['node_name'] for row in rows] - end = time.perf_counter() - print("query_node_list_by_overflow time:", end - start) return node_list except Exception as e: logger.error(f"Failed to query node list by overflow: {e}") @@ -680,27 +617,24 @@ class GraphRepoDB(GraphRepo): # DB:查询节点信息 def query_node_info_by_data_source(self, step, rank, data_source): + query = """ + SELECT + node_name, + matched_node_link, + output_data, + precision_index, + sub_nodes + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + """ try: - start = time.perf_counter() - query = """ - SELECT - node_name, - matched_node_link, - output_data, - precision_index, - sub_nodes - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND data_source = ? - """ with self.conn as c: cursor = c.execute(query, (step, rank, data_source)) nodes = self._fetch_and_convert_rows(cursor) - end = time.perf_counter() - print("query_node_info time:", end - start) return nodes except Exception as e: logger.error(f"Failed to query node info: {e}") @@ -728,7 +662,6 @@ class GraphRepoDB(GraphRepo): def update_nodes_info(self, nodes_info, rank, step): # 取消匹配和匹配都要走这个逻辑 try: - start = time.perf_counter() data = [ ( json.dumps(node['matched_node_link']), @@ -758,17 +691,13 @@ class GraphRepoDB(GraphRepo): """ with self.conn as c: c.executemany(query, data) - end = time.perf_counter() - print("update_nodes_info time:", end - start) return True except Exception as e: logger.error(f"Failed to update nodes info: {e}") return False def update_nodes_precision_error(self, update_data): - try: - start = time.perf_counter() query = """ UPDATE tb_nodes @@ -782,12 +711,33 @@ class GraphRepoDB(GraphRepo): """ self.conn.executemany(query, update_data) self.conn.commit() - end = time.perf_counter() - print("update_precision_error time:", end - start) return True except Exception as e: logger.error(f"Failed to update precision error: {e}") return False + + def _initialize_db_connection(self): + try: + # 目录安全校验 + dir_path = str(os.path.dirname(self.db_path)) + success, error = GraphUtils.safe_check_load_file_path(dir_path, True) + if not success: + raise PermissionError(error) + # 文件安全校验 + success, error = GraphUtils.safe_check_load_file_path(self.db_path) + if not success: + raise PermissionError(error) + self.conn = sqlite3.connect(self.db_path, check_same_thread=False) + self.conn.row_factory = sqlite3.Row + self.is_db_connected = self.conn is not None + # 提升性能的 PRAGMA 设置 + self.conn.execute("PRAGMA journal_mode = WAL;") + self.conn.execute("PRAGMA synchronous = NORMAL;") # 或 OFF(不安全) + self.conn.execute("PRAGMA cache_size = 40000;") + self.conn.execute("PRAGMA wal_autocheckpoint = 0;") + except Exception as e: + logger.error(f"Failed to connect to database: {e}") + self.conn = None def _fetch_and_convert_rows(self, cursor): """ @@ -803,34 +753,34 @@ class GraphRepoDB(GraphRepo): def _convert_to_graph_json(self, npu_nodes, bench_nodes): graph_data = { - "NPU":{ + "NPU": { "node": npu_nodes, }, - "Bench":{ + "Bench": { "node": bench_nodes, } } return graph_data def _convert_db_to_object(self, data): - object = { + object_res = { "id": data.get('node_name'), "node_name": data.get('node_name'), "node_type": int(data.get('node_type')) if data.get('node_type') is not None else 0, "output_data": GraphUtils.safe_json_loads(data.get('output_data') or "{}"), "input_data": GraphUtils.safe_json_loads(data.get('input_data') or "{}"), - "upnode":data.get('up_node'), - "subnodes":GraphUtils.safe_json_loads(data.get('sub_nodes') or "[]"), - "matched_node_link":GraphUtils.safe_json_loads(data.get('matched_node_link') or "[]"), - "stack_info":GraphUtils.safe_json_loads(data.get('stack_info') or "[]"), + "upnode": data.get('up_node'), + "subnodes": GraphUtils.safe_json_loads(data.get('sub_nodes') or "[]"), + "matched_node_link": GraphUtils.safe_json_loads(data.get('matched_node_link') or "[]"), + "stack_info": GraphUtils.safe_json_loads(data.get('stack_info') or "[]"), "micro_step_id": int(data.get('micro_step_id')) if data.get('micro_step_id') is not None else -1, - "data":{ + "data": { "precision_index": data.get('precision_index'), 'overflow_level': data.get('overflow_level'), }, "parallel_merge_info": GraphUtils.safe_json_loads(data.get('parallel_merge_info') or "[]"), "matched_distributed": GraphUtils.safe_json_loads(data.get('matched_distributed') or "[]"), - "modified":int(data.get('modified')) if data.get('modified') is not None else 0, + "modified": int(data.get('modified')) if data.get('modified') is not None else 0, } - return object + return object_res diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py index 998901e56..650266582 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_vis.py @@ -66,7 +66,7 @@ class GraphRepoVis(GraphRepo): up_nodes = {} up_nodes[node_name] = graph_nodes.get(node_name, {}) parent_node_name = graph_nodes.get(node_name, {}).get("upnode") - while graph_nodes.get(parent_node_name, None) != None: + while graph_nodes.get(parent_node_name, None) is not None: parent_node = graph_nodes.get(parent_node_name, {}) up_nodes[parent_node_name] = parent_node parent_node_name = parent_node.get("upnode") -- Gitee From 469226a5cefcc2dd7303c0c5b8cbe8c690e57d86 Mon Sep 17 00:00:00 2001 From: sunchao <1299792067@qq.com> Date: Sat, 30 Aug 2025 15:02:38 +0800 Subject: [PATCH 3/4] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E6=9F=A5=E8=AF=A2=E4=BB=A3=E7=A0=81=E7=BB=93?= =?UTF-8?q?=E6=9E=84=EF=BC=8C=E5=B0=86=E5=8F=98=E9=87=8F=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=E7=A7=BB=E8=87=B3try=E5=9D=97=E5=A4=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/app/repositories/graph_repo_db.py | 506 +++++++++--------- 1 file changed, 252 insertions(+), 254 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py index 24b16c8d2..f716db170 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py @@ -65,26 +65,27 @@ class GraphRepoDB(GraphRepo): # DB:查询根节点信息 def query_root_nodes(self, graph_type, rank, step): + + graph_type = graph_type if graph_type != SINGLE else NPU + query = """ + SELECT + node_name, + up_node, + sub_nodes, + node_type, + matched_node_link, + precision_index, + overflow_level, + matched_distributed + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND up_node = '' + """ try: - graph_type = graph_type if graph_type != SINGLE else NPU - query = """ - SELECT - node_name, - up_node, - sub_nodes, - node_type, - matched_node_link, - precision_index, - overflow_level, - matched_distributed - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND data_source = ? - AND up_node = '' - """ with self.conn as c: cursor = c.execute(query, (step, rank, graph_type)) rows = cursor.fetchall() @@ -98,67 +99,67 @@ class GraphRepoDB(GraphRepo): # DB:查询当前节点的所有父节点信息 def query_up_nodes(self, node_name, graph_type, rank, step): - try: - graph_type = graph_type if graph_type != SINGLE else NPU - # 现根据节点名称查询节点信息,根据up_node字段得到父节点名称 - # 再根据父节点名称查询父节点信息 - # 递归查询父节点,直到根节点 - query = """ - WITH RECURSIVE parent_chain AS ( - SELECT child.id, child.node_name, child.up_node, child.data_source, child.rank, child.step, 0 AS level - FROM - tb_nodes child - WHERE - child.step = ? - AND child.rank = ? - AND child.data_source = ? - AND child.node_name = ? + graph_type = graph_type if graph_type != SINGLE else NPU + # 现根据节点名称查询节点信息,根据up_node字段得到父节点名称 + # 再根据父节点名称查询父节点信息 + # 递归查询父节点,直到根节点 + query = """ + WITH RECURSIVE parent_chain AS ( + SELECT child.id, child.node_name, child.up_node, child.data_source, child.rank, child.step, 0 AS level + FROM + tb_nodes child + WHERE + child.step = ? + AND child.rank = ? + AND child.data_source = ? + AND child.node_name = ? - UNION ALL + UNION ALL - SELECT - parent.id, - parent.node_name, - parent.up_node, - parent.data_source, - parent.rank, - parent.step, - pc.level + 1 - FROM - tb_nodes parent - INNER JOIN parent_chain pc - ON parent.data_source = pc.data_source - AND parent.node_name = pc.up_node - AND parent.rank = pc.rank - AND parent.step = pc.step - WHERE - pc.up_node IS NOT NULL - AND pc.up_node != '' - ) SELECT - tb_nodes.id, - tb_nodes.data_source, - tb_nodes.node_name, - tb_nodes.up_node, - tb_nodes.sub_nodes, - tb_nodes.node_type, - tb_nodes.matched_node_link, - tb_nodes.precision_index, - tb_nodes.overflow_level, - tb_nodes.matched_distributed + parent.id, + parent.node_name, + parent.up_node, + parent.data_source, + parent.rank, + parent.step, + pc.level + 1 FROM - tb_nodes + tb_nodes parent + INNER JOIN parent_chain pc + ON parent.data_source = pc.data_source + AND parent.node_name = pc.up_node + AND parent.rank = pc.rank + AND parent.step = pc.step WHERE - id IN (SELECT id FROM parent_chain) - ORDER BY ( - SELECT - level - FROM - parent_chain pc - WHERE - pc.node_name = tb_nodes.node_name) - ASC - """ + pc.up_node IS NOT NULL + AND pc.up_node != '' + ) + SELECT + tb_nodes.id, + tb_nodes.data_source, + tb_nodes.node_name, + tb_nodes.up_node, + tb_nodes.sub_nodes, + tb_nodes.node_type, + tb_nodes.matched_node_link, + tb_nodes.precision_index, + tb_nodes.overflow_level, + tb_nodes.matched_distributed + FROM + tb_nodes + WHERE + id IN (SELECT id FROM parent_chain) + ORDER BY ( + SELECT + level + FROM + parent_chain pc + WHERE + pc.node_name = tb_nodes.node_name) + ASC + """ + try: with self.conn as c: cursor = c.execute(query, (step, rank, graph_type, node_name)) rows = cursor.fetchall() @@ -173,34 +174,34 @@ class GraphRepoDB(GraphRepo): # DB: 查询待匹配节点的信息,构造graph data def query_matched_nodes_info(self, npu_node_name, bench_node_name, rank, step): + query = """ + SELECT + id, + node_name, + node_type, + up_node, + sub_nodes, + data_source, + input_data, + output_data, + matched_node_link + FROM tb_nodes + WHERE step = ? AND rank = ? AND data_source = ? AND node_name = ? + """ + npu_nodes = {} + bench_nodes = {} + opposite_npu_node_name = GraphUtils.get_opposite_node_name(npu_node_name) + opposite_bench_node_name = GraphUtils.get_opposite_node_name(bench_node_name) + # 定义查询参数列表:(graph_type, node_name, target_dict_key) + queries = [ + (NPU, npu_node_name, 'npu'), + (NPU, opposite_npu_node_name, 'npu_opposite'), + (BENCH, bench_node_name, 'bench'), + (BENCH, opposite_bench_node_name, 'bench_opposite'), + ] + # 存储结果的字典 + nodes_dict = {} try: - query = """ - SELECT - id, - node_name, - node_type, - up_node, - sub_nodes, - data_source, - input_data, - output_data, - matched_node_link - FROM tb_nodes - WHERE step = ? AND rank = ? AND data_source = ? AND node_name = ? - """ - npu_nodes = {} - bench_nodes = {} - opposite_npu_node_name = GraphUtils.get_opposite_node_name(npu_node_name) - opposite_bench_node_name = GraphUtils.get_opposite_node_name(bench_node_name) - # 定义查询参数列表:(graph_type, node_name, target_dict_key) - queries = [ - (NPU, npu_node_name, 'npu'), - (NPU, opposite_npu_node_name, 'npu_opposite'), - (BENCH, bench_node_name, 'bench'), - (BENCH, opposite_bench_node_name, 'bench_opposite'), - ] - # 存储结果的字典 - nodes_dict = {} with self.conn as c: for graph_type, node_name, key in queries: if not node_name: # 可选:跳过空 node_name @@ -223,71 +224,70 @@ class GraphRepoDB(GraphRepo): # DB: 查询待匹配节点及其子节点的信息,递归查询当前节点信息和其所有的子节点信息,一直叶子节点 def query_node_and_sub_nodes(self, npu_node_name, bench_node_name, rank, step): - try: - query = """ - WITH RECURSIVE descendants AS ( - -- 初始节点选择 - SELECT - id, - node_name, - node_type, - up_node, - sub_nodes, - data_source, - input_data, - output_data, - matched_node_link, - node_order, - step, - rank - FROM tb_nodes - WHERE step = ? AND rank = ? AND data_source = ? AND node_name = ? - - UNION ALL + query = """ + WITH RECURSIVE descendants AS ( + -- 初始节点选择 + SELECT + id, + node_name, + node_type, + up_node, + sub_nodes, + data_source, + input_data, + output_data, + matched_node_link, + node_order, + step, + rank + FROM tb_nodes + WHERE step = ? AND rank = ? AND data_source = ? AND node_name = ? - -- 递归部分 - SELECT - child.id, - child.node_name, - child.node_type, - child.up_node, - child.sub_nodes, - child.data_source, - child.input_data, - child.output_data, - child.matched_node_link, - child.node_order, - child.step, - child.rank - FROM descendants d - JOIN json_each(d.sub_nodes) AS je -- 将 sub_nodes JSON 数组展开为多行 - JOIN tb_nodes child - ON child.node_name = je.value -- 子节点名称匹配 - AND child.step = d.step - AND child.rank = d.rank - AND child.data_source = d.data_source - WHERE - d.sub_nodes IS NOT NULL -- 父节点的 sub_nodes 不为 NULL - AND d.sub_nodes != '' -- 不是空 - AND d.sub_nodes != '[]' - AND json_type(d.sub_nodes) = 'array' -- 确保是合法 JSON 数组 - ) - SELECT * FROM descendants - """ + UNION ALL - npu_nodes = {} - bench_nodes = {} - opposite_npu_node_name = GraphUtils.get_opposite_node_name(npu_node_name) - opposite_bench_node_name = GraphUtils.get_opposite_node_name(bench_node_name) - # 定义查询参数列表:(graph_type, node_name, target_dict_key) - queries = [ - (NPU, npu_node_name, 'npu'), - (NPU, opposite_npu_node_name, 'npu_opposite'), - (BENCH, bench_node_name, 'bench'), - (BENCH, opposite_bench_node_name, 'bench_opposite'), - ] - # 存储结果的字典 - nodes_dict = {} + -- 递归部分 + SELECT + child.id, + child.node_name, + child.node_type, + child.up_node, + child.sub_nodes, + child.data_source, + child.input_data, + child.output_data, + child.matched_node_link, + child.node_order, + child.step, + child.rank + FROM descendants d + JOIN json_each(d.sub_nodes) AS je -- 将 sub_nodes JSON 数组展开为多行 + JOIN tb_nodes child + ON child.node_name = je.value -- 子节点名称匹配 + AND child.step = d.step + AND child.rank = d.rank + AND child.data_source = d.data_source + WHERE + d.sub_nodes IS NOT NULL -- 父节点的 sub_nodes 不为 NULL + AND d.sub_nodes != '' -- 不是空 + AND d.sub_nodes != '[]' + AND json_type(d.sub_nodes) = 'array' -- 确保是合法 JSON 数组 + ) + SELECT * FROM descendants + """ + npu_nodes = {} + bench_nodes = {} + opposite_npu_node_name = GraphUtils.get_opposite_node_name(npu_node_name) + opposite_bench_node_name = GraphUtils.get_opposite_node_name(bench_node_name) + # 定义查询参数列表:(graph_type, node_name, target_dict_key) + queries = [ + (NPU, npu_node_name, 'npu'), + (NPU, opposite_npu_node_name, 'npu_opposite'), + (BENCH, bench_node_name, 'bench'), + (BENCH, opposite_bench_node_name, 'bench_opposite'), + ] + # 存储结果的字典 + nodes_dict = {} + try: with self.conn as c: for graph_type, node_name, key in queries: if not node_name: # 可选:跳过空 node_name @@ -304,31 +304,30 @@ class GraphRepoDB(GraphRepo): # DB:查询配置文件中的待匹配节点信息 def query_matched_nodes_info_by_config(self, match_node_links, rank, step): + query = """ + SELECT + id, + node_name, + node_type, + up_node, + sub_nodes, + data_source, + input_data, + output_data, + matched_node_link + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND node_name IN ({}) + """.format(','.join(['?'] * len(match_node_links))) try: - query = """ - SELECT - id, - node_name, - node_type, - up_node, - sub_nodes, - data_source, - input_data, - output_data, - matched_node_link - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND data_source = ? - AND node_name IN ({}) - """.format(','.join(['?'] * len(match_node_links))) - with self.conn as c: npu_node_names = list(match_node_links.keys()) bench_node_names = list(match_node_links.values()) - npu_cursor = c.execute(query, (step, rank, NPU , *npu_node_names)) + npu_cursor = c.execute(query, (step, rank, NPU, *npu_node_names)) bench_cursor = c.execute(query, (step, rank, BENCH, *bench_node_names)) npu_nodes = self._fetch_and_convert_rows(npu_cursor) bench_nodes = self._fetch_and_convert_rows(bench_cursor) @@ -340,29 +339,29 @@ class GraphRepoDB(GraphRepo): # DB: 查询所有以当前为父节点的子节点 def query_sub_nodes(self, node_name, graph_type, rank, step): + graph_type = graph_type if graph_type != SINGLE else NPU + query = """ + SELECT + node_name, + up_node, + sub_nodes, + node_type, + micro_step_id, + matched_node_link, + precision_index, + overflow_level, + matched_distributed + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND up_node = ? + ORDER BY + node_order ASC + """ try: - graph_type = graph_type if graph_type != SINGLE else NPU - query = """ - SELECT - node_name, - up_node, - sub_nodes, - node_type, - micro_step_id, - matched_node_link, - precision_index, - overflow_level, - matched_distributed - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND data_source = ? - AND up_node = ? - ORDER BY - node_order ASC - """ with self.conn as c: cursor = c.execute(query, (step, rank, graph_type, node_name)) rows = cursor.fetchall() @@ -377,19 +376,19 @@ class GraphRepoDB(GraphRepo): # DB: 查询当前节点信息 def query_node_info(self, node_name, graph_type, rank, step): + graph_type = graph_type if graph_type != SINGLE else NPU + query = """ + SELECT + * + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND data_source = ? + AND node_name = ? + """ try: - graph_type = graph_type if graph_type != SINGLE else NPU - query = """ - SELECT - * - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND data_source = ? - AND node_name = ? - """ with self.conn as c: cursor = c.execute(query, (step, rank, graph_type, node_name)) rows = cursor.fetchall() @@ -510,21 +509,21 @@ class GraphRepoDB(GraphRepo): # # DB:根据step rank modify match_node_link查询已经修改的匹配成功的节点关系 def query_modify_matched_nodes_list(self, rank, step): + query = """ + SELECT + node_name, + matched_node_link + FROM + tb_nodes + WHERE + step = ? + AND rank = ? + AND modified = 1 + AND matched_node_link IS NOT NULL + AND matched_node_link != '[]' + AND matched_node_link != '' + """ try: - query = """ - SELECT - node_name, - matched_node_link - FROM - tb_nodes - WHERE - step = ? - AND rank = ? - AND modified = 1 - AND matched_node_link IS NOT NULL - AND matched_node_link != '[]' - AND matched_node_link != '' - """ with self.conn as c: cursor = c.execute(query, (step, rank)) rows = cursor.fetchall() @@ -541,7 +540,6 @@ class GraphRepoDB(GraphRepo): # DB: 根据精度误差查询节点信息 def query_node_list_by_precision(self, step, rank, micro_step, values, is_filter_unmatch_nodes): - # 准备占位符 conditions = [] placeholders = [] @@ -642,15 +640,15 @@ class GraphRepoDB(GraphRepo): # DB:更新config的colors def update_config_colors(self, colors): + query = """ + UPDATE + tb_config + SET + node_colors = ? + WHERE + id=1 + """ try: - query = """ - UPDATE - tb_config - SET - node_colors = ? - WHERE - id=1 - """ with self.conn as c: c.execute(query, (json.dumps(colors),)) return True @@ -697,18 +695,18 @@ class GraphRepoDB(GraphRepo): return False def update_nodes_precision_error(self, update_data): + query = """ + UPDATE + tb_nodes + SET + precision_index = ? + WHERE + step = ? + AND rank = ? + AND data_source = 'NPU' + AND node_name = ? + """ try: - query = """ - UPDATE - tb_nodes - SET - precision_index = ? - WHERE - step = ? - AND rank = ? - AND data_source = 'NPU' - AND node_name = ? - """ self.conn.executemany(query, update_data) self.conn.commit() return True -- Gitee From 14ef9f5f1483d4443f27604c146827450a93fbc8 Mon Sep 17 00:00:00 2001 From: sunchao <1299792067@qq.com> Date: Mon, 1 Sep 2025 19:36:44 +0800 Subject: [PATCH 4/4] =?UTF-8?q?refactor(graph):=20=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E8=8A=82=E7=82=B9=E6=9F=A5=E8=AF=A2SQL=E5=B9=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=A0=86=E6=A0=88=E4=BF=A1=E6=81=AF=E5=85=B3=E8=81=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../server/app/repositories/graph_repo_db.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py index f716db170..285d53b16 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/repositories/graph_repo_db.py @@ -379,14 +379,16 @@ class GraphRepoDB(GraphRepo): graph_type = graph_type if graph_type != SINGLE else NPU query = """ SELECT - * + n.*, + d.stack_info FROM - tb_nodes + tb_nodes n + JOIN tb_stack d ON n.stack_id = d.id WHERE - step = ? - AND rank = ? - AND data_source = ? - AND node_name = ? + n.step = ? + AND n.rank = ? + AND n.data_source = ? + AND n.node_name = ? """ try: with self.conn as c: -- Gitee