From 979af92a5f3c95c7a275e5f6a16aa25b2cf731c1 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Mon, 1 Sep 2025 19:27:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E7=BA=A7=E5=8F=AF=E8=A7=86=E5=8C=96?= =?UTF-8?q?=E8=BD=ACdb=E4=BC=98=E5=8C=96stack=E4=BF=A1=E6=81=AF=E5=AD=98?= =?UTF-8?q?=E5=82=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/visualization/db_utils.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/visualization/db_utils.py b/debug/accuracy_tools/msprobe/visualization/db_utils.py index 3988ebecc..b9397e20b 100644 --- a/debug/accuracy_tools/msprobe/visualization/db_utils.py +++ b/debug/accuracy_tools/msprobe/visualization/db_utils.py @@ -41,7 +41,7 @@ node_columns = { 'overflow_level': TEXT, 'micro_step_id': INTEGER_NOT_NULL, 'matched_node_link': TEXT, - 'stack_info': TEXT, + 'stack_id': TEXT, 'parallel_merge_info': TEXT, 'matched_distributed': TEXT, 'modified': INTEGER_NOT_NULL, @@ -65,6 +65,11 @@ config_columns = { 'step_list': TEXT_NOT_NULL } +stack_columns = { + 'id': TEXT_PRIMARY_KEY, + 'stack_info': TEXT +} + indexes = { "index1": ["step", "rank", "data_source", "up_node", "node_order"], "index2": ["step", "rank", "data_source", "node_name"], @@ -197,19 +202,24 @@ def node_to_db(graph, db_name): create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns) insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns) data = [] + stack_dict = {} for i, node in enumerate(graph.get_sorted_nodes()): + stack_info_text = json.dumps(node.stack_info) + if stack_info_text not in stack_dict: + stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict) data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value, node.upnode.id if node.upnode else '', json.dumps([node.id for node in node.subnodes]) if node.subnodes else '', node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL), node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link), - json.dumps(node.stack_info), + stack_dict.get(stack_info_text), json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '', json.dumps(node.matched_distributed), 0, json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)), json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)), graph.data_source, graph.data_path, graph.step, graph.rank)) to_db(db_name, create_table_sql, insert_sql, data) + stack_to_db(stack_dict, db_name) def config_to_db(config, db_name): @@ -221,9 +231,21 @@ def config_to_db(config, db_name): to_db(db_name, create_table_sql, insert_sql, data) +def stack_to_db(stack_dict, db_name): + create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns) + insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns) + data = [] + for stack_info_text, unique_id in stack_dict.items(): + data.append((unique_id, stack_info_text)) + to_db(db_name, create_table_sql, insert_sql, data) + + def get_graph_unique_id(graph): return f'{graph.data_source}_{graph.step}_{graph.rank}' def get_node_unique_id(graph, node): return f'{get_graph_unique_id(graph)}_{node.id}' + +def get_stack_unique_id(graph, stack_dict): + return f'{get_graph_unique_id(graph)}_{len(stack_dict)}' -- Gitee