diff --git a/debug/accuracy_tools/msprobe/visualization/db_utils.py b/debug/accuracy_tools/msprobe/visualization/db_utils.py index 3988ebecc88863b7e8d65724719cf132971f116e..b9397e20beca5ae34ac81f8280e91fd1405f8790 100644 --- a/debug/accuracy_tools/msprobe/visualization/db_utils.py +++ b/debug/accuracy_tools/msprobe/visualization/db_utils.py @@ -41,7 +41,7 @@ node_columns = { 'overflow_level': TEXT, 'micro_step_id': INTEGER_NOT_NULL, 'matched_node_link': TEXT, - 'stack_info': TEXT, + 'stack_id': TEXT, 'parallel_merge_info': TEXT, 'matched_distributed': TEXT, 'modified': INTEGER_NOT_NULL, @@ -65,6 +65,11 @@ config_columns = { 'step_list': TEXT_NOT_NULL } +stack_columns = { + 'id': TEXT_PRIMARY_KEY, + 'stack_info': TEXT +} + indexes = { "index1": ["step", "rank", "data_source", "up_node", "node_order"], "index2": ["step", "rank", "data_source", "node_name"], @@ -197,19 +202,24 @@ def node_to_db(graph, db_name): create_table_sql = create_table_sql_from_dict('tb_nodes', node_columns) insert_sql = create_insert_sql_from_dict('tb_nodes', node_columns) data = [] + stack_dict = {} for i, node in enumerate(graph.get_sorted_nodes()): + stack_info_text = json.dumps(node.stack_info) + if stack_info_text not in stack_dict: + stack_dict[stack_info_text] = get_stack_unique_id(graph, stack_dict) data.append((get_node_unique_id(graph, node), get_graph_unique_id(graph), i, node.id, node.op.value, node.upnode.id if node.upnode else '', json.dumps([node.id for node in node.subnodes]) if node.subnodes else '', node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL), node.micro_step_id if node.micro_step_id is not None else 0, json.dumps(node.matched_node_link), - json.dumps(node.stack_info), + stack_dict.get(stack_info_text), json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '', json.dumps(node.matched_distributed), 0, json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)), json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)), graph.data_source, graph.data_path, graph.step, graph.rank)) to_db(db_name, create_table_sql, insert_sql, data) + stack_to_db(stack_dict, db_name) def config_to_db(config, db_name): @@ -221,9 +231,21 @@ def config_to_db(config, db_name): to_db(db_name, create_table_sql, insert_sql, data) +def stack_to_db(stack_dict, db_name): + create_table_sql = create_table_sql_from_dict('tb_stack', stack_columns) + insert_sql = create_insert_sql_from_dict('tb_stack', stack_columns) + data = [] + for stack_info_text, unique_id in stack_dict.items(): + data.append((unique_id, stack_info_text)) + to_db(db_name, create_table_sql, insert_sql, data) + + def get_graph_unique_id(graph): return f'{graph.data_source}_{graph.step}_{graph.rank}' def get_node_unique_id(graph, node): return f'{get_graph_unique_id(graph)}_{node.id}' + +def get_stack_unique_id(graph, stack_dict): + return f'{get_graph_unique_id(graph)}_{len(stack_dict)}'