From fdad482779252d2639e1ad5d7a5c343304a34a39 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Thu, 31 Jul 2025 20:20:07 +0800 Subject: [PATCH 01/19] add ut --- .../test/core_ut/common/test_db_manager.py | 241 ++++++++++++++++ .../test/core_ut/monitor/test_csv2db.py | 273 ++++++++++++++++++ .../test/core_ut/monitor/test_db_utils.py | 256 ++++++++++++++++ 3 files changed, 770 insertions(+) create mode 100644 debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py create mode 100644 debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py create mode 100644 debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py new file mode 100644 index 000000000..451f9d542 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py @@ -0,0 +1,241 @@ +import unittest +import sqlite3 +import os +import tempfile +from typing import Dict, List +from unittest.mock import patch, MagicMock + +from msprobe.pytorch.common.log import logger +from msprobe.core.common.db_manager import DBManager + +class TestDBManager(unittest.TestCase): + def setUp(self): + # 创建临时数据库文件 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.db_path = self.temp_db.name + self.db_manager = DBManager(self.db_path) + + # 创建测试表 + self.test_table = "test_table" + self.create_test_table() + + def tearDown(self): + # 关闭并删除临时数据库文件 + if hasattr(self, 'temp_db'): + self.temp_db.close() + os.unlink(self.db_path) + + def create_test_table(self): + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute(f""" + CREATE TABLE IF NOT EXISTS {self.test_table} ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + value INTEGER, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + + def test_get_connection_success(self): + """测试成功获取数据库连接""" + conn, curs = self.db_manager._get_connection() + self.assertIsInstance(conn, sqlite3.Connection) + self.assertIsInstance(curs, sqlite3.Cursor) + self.db_manager._release_connection(conn, curs) + + @patch.object(logger, 'error') + def test_get_connection_success_failed(self, mock_logger): + """测试错误日志记录""" + with patch('sqlite3.connect', side_effect=sqlite3.Error("Test error")): + with self.assertRaises(sqlite3.Error): + self.db_manager._get_connection() + mock_logger.assert_called_with("Database connection failed: Test error") + + + def test_insert_data_basic(self): + """测试基本数据插入""" + test_data = [ + (1, "item1", 100), + (2, "item2", 200) + ] + columns = ["id", "name", "value"] + + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=columns + ) + self.assertEqual(inserted, 2) + + # 验证数据是否实际插入 + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["id", "name", "value"] + ) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["name"], "item1") + + def test_insert_data_without_keys(self): + """测试无列名的数据插入""" + test_data = [ + (3, "item3", 300), + (4, "item4", 400) + ] + + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=test_data + ) + self.assertEqual(inserted, 2) + + def test_insert_data_empty(self): + """测试空数据插入""" + inserted = self.db_manager.insert_data( + table_name=self.test_table, + data=[] + ) + self.assertEqual(inserted, 0) + + def test_insert_data_mismatch_keys(self): + """测试列名与数据不匹配的情况""" + test_data = [(5, 
"item5")] + with self.assertRaises(ValueError): + self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=["id", "name", "value"] # 多了一个列 + ) + + def test_select_data_basic(self): + """测试基本数据查询""" + # 先插入测试数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(10, "test10", 1000)], + key_list=["id", "name", "value"] + ) + + results = self.db_manager.select_data( + table_name=self.test_table, + columns=["name", "value"], + where={"id": 10} + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "test10") + self.assertEqual(results[0]["value"], 1000) + + def test_select_data_no_where(self): + """测试无条件查询""" + # 插入多条数据 + test_data = [ + (20, "item20", 2000), + (21, "item21", 2100) + ] + self.db_manager.insert_data( + table_name=self.test_table, + data=test_data, + key_list=["id", "name", "value"] + ) + + results = self.db_manager.select_data( + table_name=self.test_table + ) + self.assertGreaterEqual(len(results), 2) + + def test_update_data_basic(self): + """测试基本数据更新""" + # 先插入测试数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(30, "old_name", 3000)], + key_list=["id", "name", "value"] + ) + + updated = self.db_manager.update_data( + table_name=self.test_table, + updates={"name": "new_name", "value": 3500}, + where={"id": 30} + ) + self.assertEqual(updated, 1) + + # 验证更新结果 + results = self.db_manager.select_data( + table_name=self.test_table, + where={"id": 30} + ) + self.assertEqual(results[0]["name"], "new_name") + self.assertEqual(results[0]["value"], 3500) + + def test_execute_sql_select(self): + """测试执行SELECT SQL语句""" + self.db_manager.insert_data( + table_name=self.test_table, + data=[(50, "sql_item", 5000)], + key_list=["id", "name", "value"] + ) + + results = self.db_manager.execute_sql( + sql=f"SELECT name, value FROM {self.test_table} WHERE id = ?", + params=(50,) + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["name"], "sql_item") + + def test_execute_sql_non_select(self): + """测试执行非SELECT SQL语句""" + # 先插入数据 + self.db_manager.insert_data( + table_name=self.test_table, + data=[(60, "to_delete", 6000)], + key_list=["id", "name", "value"] + ) + + # 执行DELETE语句 + self.db_manager.execute_sql( + sql=f"DELETE FROM {self.test_table} WHERE id = 60" + ) + + # 验证数据已被删除 + results = self.db_manager.select_data( + table_name=self.test_table, + where={"id": 60} + ) + self.assertEqual(len(results), 0) + + def test_table_exists_true(self): + """测试表存在检查(存在的情况)""" + exists = self.db_manager.table_exists(self.test_table) + self.assertTrue(exists) + + def test_table_exists_false(self): + """测试表存在检查(不存在的情况)""" + exists = self.db_manager.table_exists("non_existent_table") + self.assertFalse(exists) + + def test_execute_multi_sql(self): + """测试批量执行多个SQL语句""" + sql_commands = [ + f"INSERT INTO {self.test_table} (id, name, value) VALUES (70, 'multi1', 7000)", + f"INSERT INTO {self.test_table} (id, name, value) VALUES (71, 'multi2', 7100)", + f"SELECT * FROM {self.test_table} WHERE id IN (70, 71)" + ] + + results = self.db_manager.execute_multi_sql(sql_commands) + + # 应该只有最后一个SELECT语句有结果 + self.assertEqual(len(results), 1) + self.assertEqual(len(results[0]), 2) + + @patch.object(logger, 'error') + def test_db_operation_decorator(self, mock_logger): + """测试数据库操作装饰器""" + # 模拟一个会失败的操作 + with patch.object(self.db_manager, '_get_connection', + side_effect=sqlite3.Error("Test error")): + result = self.db_manager.select_data(table_name=self.test_table) + self.assertIsNone(result) # 
装饰器会捕获异常并返回None + mock_logger.assert_called_with("Database operation failed: Test error") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py new file mode 100644 index 000000000..aa2c5c3f0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -0,0 +1,273 @@ +import unittest +import os +import tempfile +import shutil +from unittest.mock import patch, MagicMock +import pandas as pd + +from msprobe.core.monitor.csv2db import ( + CSV2DBConfig, + validate_process_num, + validate_step_partition, + validate_data_type_list, + _pre_scan_single_rank, + _pre_scan, + process_single_rank, + import_data, + csv2db, + all_data_type_list, + MAX_PROCESS_NUM, +) + + +class TestCSV2DBValidations(unittest.TestCase): + def test_validate_process_num_valid(self): + """测试有效的进程数""" + validate_process_num(1) + validate_process_num(MAX_PROCESS_NUM) + + def test_validate_process_num_invalid(self): + """测试无效的进程数""" + with self.assertRaises(ValueError): + validate_process_num(0) + with self.assertRaises(ValueError): + validate_process_num(-1) + with self.assertRaises(ValueError): + validate_process_num(MAX_PROCESS_NUM + 1) + + def test_validate_step_partition_valid(self): + """测试有效的step分区""" + validate_step_partition(1) + validate_step_partition(500) + + def test_validate_step_partition_invalid(self): + """测试无效的step分区""" + with self.assertRaises(ValueError): + validate_step_partition(0) + with self.assertRaises(ValueError): + validate_step_partition(-1) + + def test_validate_data_type_list_valid(self): + """测试有效的数据类型列表""" + validate_data_type_list(["actv", "grad_reduced"]) + validate_data_type_list(all_data_type_list[:2]) + + def test_validate_data_type_list_invalid(self): + """测试无效的数据类型列表""" + with self.assertRaises(ValueError): + validate_data_type_list(["invalid_type"]) + with self.assertRaises(ValueError): + validate_data_type_list(["actv", "invalid_type"]) + + +class TestPreScanFunctions(unittest.TestCase): + def setUp(self): + # 创建临时目录和测试CSV文件 + self.temp_dir = tempfile.mkdtemp() + self.temp_dir_rank2 = tempfile.mkdtemp() + self.test_csv_path = os.path.join(self.temp_dir, "actv_0-100.csv") + self.test_csv_path_rank2 = os.path.join( + self.temp_dir_rank2, "actv_0-100.csv") + + # 创建测试CSV数据 + test_data = { + "name": ["layer1", "layer2"], + "vpp_stage": [0, 0], + "micro_step": [0, 1], + "step": [10, 20], + "min": [0.1, 0.2], + "max": [1.0, 2.0] + } + df = pd.DataFrame(test_data) + df.to_csv(self.test_csv_path, index=False) + + def tearDown(self): + # 清理临时目录 + shutil.rmtree(self.temp_dir) + + def test_pre_scan_single_rank(self): + """测试单个rank的预扫描""" + rank = 0 + files = [self.test_csv_path] + result = _pre_scan_single_rank(rank, files) + self.assertEqual(result["max_rank"], rank) + self.assertEqual(result["metrics"], {"actv"}) + self.assertEqual(result["min_step"], 0) + self.assertEqual(result["max_step"], 100) + self.assertEqual(result["metric_stats"], {"actv": {"avg", "max"}}) + self.assertEqual(len(result["targets"]), 2) + + @patch("msprobe.core.monitor.csv2db.ProcessPoolExecutor") + @patch("msprobe.core.monitor.csv2db._pre_scan_single_rank") + def test_pre_scan(self, mock_single_scan, mock_executor): + """测试完整预扫描流程""" + # 模拟预扫描结果 + mock_result_actv = { + 'max_rank': 0, + 'metrics': {"actv"}, + 'min_step': 0, + 'max_step': 100, + 'metric_stats': {"actv": {"norm", "max"}}, + 'targets': [("layer1", 0, 0), ("layer2", 0, 1)] + } + mock_result_grad = { + 'max_rank': 1, + 'metrics': 
{"grad_reduced", "invaild_metric"}, + 'min_step': 0, + 'max_step': 200, + 'metric_stats': { + "grad_reduced": {"norm", "max"}, + "invaild_metric": {"norm", "max"} + }, + 'targets': [("layer1_weight", 0, 0), ("layer2_weight", 0, 1)] + } + mock_single_scan.side_effect = [mock_result_actv, mock_result_grad] + + # 模拟ProcessPoolExecutor + mock_future = MagicMock() + mock_future.result.side_effect = [mock_result_actv, mock_result_grad] + mock_executor.return_value.__enter__.return_value.submit.return_value = mock_future + + # 模拟MonitorDB + mock_db = MagicMock() + + # 测试数据 + data_dirs = {0: self.temp_dir, 2: self.temp_dir_rank2} + data_type_list = ["actv", "grad_reduced"] + + result = _pre_scan(mock_db, data_dirs, data_type_list) + + self.assertEqual(result, { + 0: [os.path.join(self.temp_dir, "actv_0-100.csv")], + 2: [os.path.join(self.temp_dir_rank2, "actv_0-100.csv")] + }) + + mock_db.insert_dimensions.assert_called_with( + [("layer1", 0, 0), ("layer2", 0, 1), + ("layer1_weight", 0, 0), ("layer2_weight", 0, 1)], + ["actv", "grad_reduced"], + {"actv": {"norm", "max"}, "grad_reduced": {"norm", "max"}}, + min_step=0, max_step=200 + ) + mock_db.update_global_stats.assert_called_with( + max_rank=2, min_step=0, max_step=200 + ) + + +class TestProcessSingleRank(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db.MonitorDB") + @patch("msprobe.core.monitor.csv2db.read_csv") + def test_process_single_rank(self, mock_read_csv, mock_db_class): + """测试处理单个rank的数据""" + # 模拟数据库和映射 + mock_db = MagicMock() + mock_db_class.return_value = mock_db + mock_db.get_metric_table_name.return_value = ( + "metric_1_step_0_99", 0, 99) + mock_db.insert_rows.return_value = 2 + + # 模拟CSV数据 + mock_result = [ + { + "name": "layer1", + "vpp_stage": 0, + "micro_step": 0, + "step": 10, + "norm": 0.1, + "max": 1.0 + }, + { + "name": "layer2", + "vpp_stage": 0, + "micro_step": 1, + "step": 20, + "norm": 0.2, + "max": 2.0 + } + ] + mock_read_csv.return_value = mock_result + + # 测试数据 + task = (0, ["actv_10-20.csv"]) + metric_id_dict = {"actv": (1, ["norm", "max"])} + target_dict = {("layer1", 0, 0): 1, ("layer2", 0, 1): 2} + step_partition_size = 100 + db_path = "dummy.db" + + result = process_single_rank( + task, metric_id_dict, target_dict, step_partition_size, db_path) + + self.assertEqual(result, 2) + mock_db.insert_rows.assert_called_with( + [(0, 10, 1, 0.1, 1.0), (0, 20, 2, 0.2, 2.0)] + ) + + +class TestImportData(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db._pre_scan") + @patch("msprobe.core.monitor.csv2db.ProcessPoolExecutor") + def test_import_data_success(self, mock_executor, mock_pre_scan): + """测试数据导入成功场景""" + # 模拟预扫描结果 + mock_pre_scan.return_value = { + 0: ["actv_10-20.csv"], 1: ["actv_10-20.csv"]} + + # 模拟数据库 + mock_db = MagicMock() + mock_db.get_metric_mapping.return_value = {"actv": (1, ["avg", "max"])} + mock_db.get_target_mapping.return_value = {("layer1", 0, 0): 1} + + # 模拟进程池结果 + mock_future = MagicMock() + mock_future.result.return_value = 10 + mock_executor.return_value.__enter__.return_value.submit.return_value = mock_future + + # 测试数据 + data_dirs = {0: "dir0", 1: "dir1"} + data_type_list = ["actv"] + workers = 2 + + result = import_data(mock_db, data_dirs, data_type_list, workers) + + mock_db.init_schema.assert_called_once() + self.assertTrue(result) + mock_pre_scan.assert_called_once() + + @patch("msprobe.core.monitor.csv2db._pre_scan") + def test_import_data_no_files(self, mock_pre_scan): + """测试没有找到数据文件的情况""" + mock_pre_scan.return_value = {} + + mock_db = MagicMock() + data_dirs = {0: 
"dir0"} + data_type_list = ["actv"] + + result = import_data(mock_db, data_dirs, data_type_list) + + self.assertFalse(result) + mock_pre_scan.assert_called_once() + + +class TestCSV2DBMain(unittest.TestCase): + @patch("msprobe.core.monitor.csv2db.import_data") + @patch("msprobe.core.monitor.csv2db.get_target_output_dir") + @patch("msprobe.core.monitor.csv2db.create_directory") + def test_csv2db(self, mock_chmod, mock_create_dir, mock_get_dirs, mock_import): + """测试主函数csv2db""" + # 模拟配置 + config = CSV2DBConfig( + monitor_path="test_path", + data_type_list=["actv"], + process_num=4, + step_partition=500 + ) + + # 模拟依赖函数 + mock_get_dirs.return_value = {0: "dir0", 1: "dir1"} + mock_import.return_value = True + + csv2db(config) + + mock_get_dirs.assert_called_once() + mock_create_dir.assert_called_once() + mock_import.assert_called_once() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py new file mode 100644 index 000000000..ad24f70df --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py @@ -0,0 +1,256 @@ +import unittest +import sqlite3 +import os +import tempfile +from collections import OrderedDict +from unittest.mock import patch + +from msprobe.core.common.const import MonitorConst +from msprobe.core.monitor.db_utils import MonitorDB, MonitorSql, update_ordered_dict, get_ordered_stats + + +class TestDBUtils(unittest.TestCase): + def test_update_ordered_dict(self): + """测试update_ordered_dict函数""" + main_dict = OrderedDict([('a', 1), ('b', 2)]) + new_list = ['b', 'c', 'd'] + + result = update_ordered_dict(main_dict, new_list) + + self.assertEqual(list(result.keys()), ['a', 'b', 'c', 'd']) + self.assertEqual(result['a'], 1) + self.assertIsNone(result['c']) + + def test_get_ordered_stats(self): + """测试get_ordered_stats函数""" + test_stats = ['stat2', 'stat1', 'stat3'] + supported_stats = ['stat1', 'stat2', 'stat3', 'stat4'] + + with patch.object(MonitorConst, 'OP_MONVIS_SUPPORTED', supported_stats): + result = get_ordered_stats(test_stats) + + self.assertEqual(result, ['stat1', 'stat2', 'stat3']) + + def test_get_ordered_stats_with_non_iterable(self): + """测试get_ordered_stats处理非可迭代对象""" + result = get_ordered_stats(123) + self.assertEqual(result, []) + + +class TestMonitorSql(unittest.TestCase): + def test_get_table_definition_all_tables(self): + """测试获取所有表定义""" + result = MonitorSql.get_table_definition() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 4) + self.assertTrue(all("CREATE TABLE" in sql for sql in result)) + + def test_get_table_definition_single_table(self): + """测试获取单个表定义""" + for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: + result = MonitorSql.get_table_definition(table) + self.assertIn(f"CREATE TABLE {table}" if table != "global_stats" else "CREATE TABLE global_stats", result) + + def test_get_table_definition_invalid_table(self): + """测试获取不存在的表定义""" + with self.assertRaises(ValueError): + MonitorSql.get_table_definition("invalid_table") + + def test_get_metric_table_definition_with_partition(self): + """测试带分区的指标表定义""" + stats = ["norm", "max"] + result = MonitorSql.get_metric_table_definition("test_metric", stats, [100, 200]) + self.assertIn("norm REAL DEFAULT NULL", result) + self.assertIn("max REAL DEFAULT NULL", result) + self.assertIn("step INTEGER NOT NULL CHECK(step BETWEEN 100 AND 200)", result) + + def test_get_metric_mapping_sql(self): + """测试获取指标映射SQL""" + result = 
MonitorSql.get_metric_mapping_sql() + self.assertIn("SELECT m.metric_id, m.metric_name", result) + self.assertIn("GROUP_CONCAT(ms.stat_name)", result) + + +class TestMonitorDB(unittest.TestCase): + def setUp(self): + # 创建临时数据库文件 + self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db') + self.db_path = self.temp_db.name + self.monitor_db = MonitorDB(self.db_path, step_partition_size=100) + + # 初始化数据库schema + self.monitor_db.init_schema() + + def tearDown(self): + # 关闭并删除临时数据库文件 + if hasattr(self, 'temp_db'): + self.temp_db.close() + os.unlink(self.db_path) + + def test_init_schema(self): + """测试初始化数据库schema""" + # 验证表是否创建成功 + for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: + self.assertTrue(self.monitor_db.db_manager.table_exists(table)) + + # 验证全局统计初始值 + results = self.monitor_db.db_manager.select_data("global_stats") + self.assertEqual(len(results), 4) + self.assertEqual(results[0]['stat_value'], 0) # max_rank + + def test_get_metric_table_name(self): + """测试生成指标表名""" + # 测试分区边界 + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 50), + ("metric_1_step_0_99", 0, 99) + ) + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 100), + ("metric_1_step_100_199", 100, 199) + ) + self.assertEqual( + self.monitor_db.get_metric_table_name(1, 199), + ("metric_1_step_100_199", 100, 199) + ) + + def test_insert_dimensions(self): + """测试插入维度数据""" + targets = OrderedDict() + targets[("layer1", 0, 0)] = None + targets[("layer2", 0, 1)] = None + + metrics = {"metric1", "metric2"} + metric_stats = { + "metric1": {"norm", "max"}, + "metric2": {"min", "max"} + } + + self.monitor_db.insert_dimensions( + targets=targets, + metrics=metrics, + metric_stats=metric_stats, + min_step=0, + max_step=200 + ) + + # 验证目标插入 + target_results = self.monitor_db.db_manager.select_data("monitoring_targets") + self.assertEqual(len(target_results), 2) + + # 验证指标插入 + metric_results = self.monitor_db.db_manager.select_data("monitoring_metrics") + self.assertEqual(len(metric_results), 2) + + # 验证指标统计关系插入 + stat_results = self.monitor_db.db_manager.select_data("metric_stats") + self.assertEqual(len(stat_results), 4) # 2 metrics * 2 stats each + + # 验证指标表创建 + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_1_step_0_99")) + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_1_step_100_199")) + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_2_step_0_99")) + self.assertTrue(self.monitor_db.db_manager.table_exists("metric_2_step_100_199")) + + def test_create_metric_table(self): + """测试创建指标表""" + table_name = self.monitor_db.create_metric_table( + metric_id=1, + step=50, + stats=["norm", "max"] + ) + + self.assertEqual(table_name, "metric_1_step_0_99") + self.assertTrue(self.monitor_db.db_manager.table_exists(table_name)) + + # 验证表结构 + results = self.monitor_db.db_manager.execute_sql(f"PRAGMA table_info({table_name})") + columns = [row['name'] for row in results] + self.assertIn("norm", columns) + self.assertIn("max", columns) + + def test_update_global_stats(self): + """测试更新全局统计""" + self.monitor_db.update_global_stats( + max_rank=8, + min_step=10, + max_step=1000 + ) + + # 验证更新结果 + results = self.monitor_db.db_manager.select_data("global_stats") + stats = {row['stat_name']: row['stat_value'] for row in results} + self.assertEqual(stats['max_rank'], 8) + self.assertEqual(stats['min_step'], 10) + self.assertEqual(stats['max_step'], 1000) + + def test_get_metric_mapping(self): + """测试获取指标映射""" + # 先插入测试数据 + 
self.monitor_db.db_manager.insert_data( + "monitoring_metrics", + [("metric1",), ("metric2",)], + ["metric_name"] + ) + + # 获取metric_id + metric1_id = self.monitor_db._get_metric_id("metric1") + metric2_id = self.monitor_db._get_metric_id("metric2") + + # 插入统计关系 + self.monitor_db.db_manager.insert_data( + "metric_stats", + [(metric1_id, "norm"), (metric1_id, "max"), (metric2_id, "min")], + ["metric_id", "stat_name"] + ) + + # 测试获取映射 + mapping = self.monitor_db.get_metric_mapping() + + self.assertEqual(len(mapping), 2) + self.assertEqual(mapping["metric1"][0], metric1_id) + self.assertEqual(sorted(mapping["metric1"][1]), ["norm", "max"]) + self.assertEqual(mapping["metric2"][1], ["min"]) + + def test_get_target_mapping(self): + """测试获取目标映射""" + # 先插入测试数据 + self.monitor_db.db_manager.insert_data( + "monitoring_targets", + [("target1", 0, 0), ("target2", 0, 1)], + ["target_name", "vpp_stage", "micro_step"] + ) + + # 测试获取映射 + mapping = self.monitor_db.get_target_mapping() + + self.assertEqual(len(mapping), 2) + self.assertIn(("target1", 0, 0), mapping) + self.assertIn(("target2", 0, 1), mapping) + + def test_insert_rows(self): + """测试插入行数据""" + # 先创建测试表 + self.monitor_db.db_manager.execute_sql( + "CREATE TABLE test_table (id INTEGER PRIMARY KEY, name TEXT)" + ) + + # 测试插入 + inserted = self.monitor_db.insert_rows( + "test_table", + [(1, "item1"), (2, "item2")] + ) + + self.assertEqual(inserted, 2) + + # 验证数据 + results = self.monitor_db.db_manager.select_data("test_table") + self.assertEqual(len(results), 2) + + def test_insert_rows_table_not_exists(self): + """测试插入行数据到不存在的表""" + with self.assertRaises(RuntimeError): + self.monitor_db.insert_rows( + "non_existent_table", + [(1, "item1")] + ) -- Gitee From 60347b87e29865415a76b9ee406872f3644b3801 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 10:40:45 +0800 Subject: [PATCH 02/19] bugfix --- .../msprobe/core/common/db_manager.py | 2 +- .../msprobe/core/monitor/db_utils.py | 2 +- .../test/core_ut/common/test_db_manager.py | 4 +- .../test/core_ut/monitor/test_csv2db.py | 101 ++++++------------ .../test/core_ut/monitor/test_db_utils.py | 17 ++- 5 files changed, 47 insertions(+), 79 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index b23fe0143..36b4efeef 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -106,7 +106,7 @@ class DBManager: columns = len(data[0]) if key_list and columns != len(key_list): raise ValueError( - f"When inserting into table {table_name}, the length of key list ({key_name})" + f"When inserting into table {table_name}, the length of key list ({key_list})" f"does not match the data({columns}).") batch_size = self.DEFAULT_INSERT_SIZE diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py index 1096cc209..8f6170e25 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -179,7 +179,7 @@ class MonitorDB: ) # Create metric tables for each partition - if min_step and max_step: + if min_step is not None and max_step is not None: first_partition = min_step // self.step_partition_size last_partition = max_step // self.step_partition_size diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py index 451f9d542..f3efde951 
100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_db_manager.py @@ -80,8 +80,8 @@ class TestDBManager(unittest.TestCase): def test_insert_data_without_keys(self): """测试无列名的数据插入""" test_data = [ - (3, "item3", 300), - (4, "item4", 400) + (3, "item3", 300, 333), + (4, "item4", 400, 333) ] inserted = self.db_manager.insert_data( diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py index aa2c5c3f0..e1f32599d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -65,12 +65,14 @@ class TestPreScanFunctions(unittest.TestCase): # 创建临时目录和测试CSV文件 self.temp_dir = tempfile.mkdtemp() self.temp_dir_rank2 = tempfile.mkdtemp() - self.test_csv_path = os.path.join(self.temp_dir, "actv_0-100.csv") - self.test_csv_path_rank2 = os.path.join( - self.temp_dir_rank2, "actv_0-100.csv") + self.test_csv_path_actv = os.path.join(self.temp_dir, "actv_0-100.csv") + self.test_csv_path_rank2_grad = os.path.join( + self.temp_dir_rank2, "grad_reduced_100-200.csv") + self.test_csv_path_rank_inv = os.path.join( + self.temp_dir_rank2, "invalid_metric_100-200.csv") # 创建测试CSV数据 - test_data = { + test_data_actv = { "name": ["layer1", "layer2"], "vpp_stage": [0, 0], "micro_step": [0, 1], @@ -78,8 +80,20 @@ class TestPreScanFunctions(unittest.TestCase): "min": [0.1, 0.2], "max": [1.0, 2.0] } - df = pd.DataFrame(test_data) - df.to_csv(self.test_csv_path, index=False) + test_data_grad = { + "name": ["layer1_weight", "layer2_weight"], + "vpp_stage": [0, 0], + "micro_step": [0, 1], + "step": [10, 20], + "min": [0.1, 0.2], + "max": [1.0, 2.0] + } + df = pd.DataFrame(test_data_actv) + df.to_csv(self.test_csv_path_actv, index=False) + df = pd.DataFrame(test_data_grad) + df.to_csv(self.test_csv_path_rank2_grad, index=False) + df = pd.DataFrame(test_data_grad) + df.to_csv(self.test_csv_path_rank_inv, index=False) def tearDown(self): # 清理临时目录 @@ -94,40 +108,11 @@ class TestPreScanFunctions(unittest.TestCase): self.assertEqual(result["metrics"], {"actv"}) self.assertEqual(result["min_step"], 0) self.assertEqual(result["max_step"], 100) - self.assertEqual(result["metric_stats"], {"actv": {"avg", "max"}}) + self.assertEqual(result["metric_stats"], {"actv": {"min", "max"}}) self.assertEqual(len(result["targets"]), 2) - @patch("msprobe.core.monitor.csv2db.ProcessPoolExecutor") - @patch("msprobe.core.monitor.csv2db._pre_scan_single_rank") - def test_pre_scan(self, mock_single_scan, mock_executor): + def test_pre_scan(self): """测试完整预扫描流程""" - # 模拟预扫描结果 - mock_result_actv = { - 'max_rank': 0, - 'metrics': {"actv"}, - 'min_step': 0, - 'max_step': 100, - 'metric_stats': {"actv": {"norm", "max"}}, - 'targets': [("layer1", 0, 0), ("layer2", 0, 1)] - } - mock_result_grad = { - 'max_rank': 1, - 'metrics': {"grad_reduced", "invaild_metric"}, - 'min_step': 0, - 'max_step': 200, - 'metric_stats': { - "grad_reduced": {"norm", "max"}, - "invaild_metric": {"norm", "max"} - }, - 'targets': [("layer1_weight", 0, 0), ("layer2_weight", 0, 1)] - } - mock_single_scan.side_effect = [mock_result_actv, mock_result_grad] - - # 模拟ProcessPoolExecutor - mock_future = MagicMock() - mock_future.result.side_effect = [mock_result_actv, mock_result_grad] - mock_executor.return_value.__enter__.return_value.submit.return_value = mock_future - # 模拟MonitorDB mock_db = MagicMock() @@ -139,14 +124,14 @@ 
class TestPreScanFunctions(unittest.TestCase): self.assertEqual(result, { 0: [os.path.join(self.temp_dir, "actv_0-100.csv")], - 2: [os.path.join(self.temp_dir_rank2, "actv_0-100.csv")] + 2: [os.path.join(self.temp_dir_rank2, "reduced_0-100.csv")] }) mock_db.insert_dimensions.assert_called_with( [("layer1", 0, 0), ("layer2", 0, 1), ("layer1_weight", 0, 0), ("layer2_weight", 0, 1)], ["actv", "grad_reduced"], - {"actv": {"norm", "max"}, "grad_reduced": {"norm", "max"}}, + {"actv": {"min", "max"}, "grad_reduced": {"min", "max"}}, min_step=0, max_step=200 ) mock_db.update_global_stats.assert_called_with( @@ -167,24 +152,14 @@ class TestProcessSingleRank(unittest.TestCase): mock_db.insert_rows.return_value = 2 # 模拟CSV数据 - mock_result = [ - { - "name": "layer1", - "vpp_stage": 0, - "micro_step": 0, - "step": 10, - "norm": 0.1, - "max": 1.0 - }, - { - "name": "layer2", - "vpp_stage": 0, - "micro_step": 1, - "step": 20, - "norm": 0.2, - "max": 2.0 - } - ] + mock_result = { + "name": ["layer1", "layer2"], + "vpp_stage": [0, 0], + "micro_step": [0, 1], + "step": [10, 20], + "norm": [0.1, 0.2], + "max": [1.0, 2.0] + } mock_read_csv.return_value = mock_result # 测试数据 @@ -205,8 +180,7 @@ class TestProcessSingleRank(unittest.TestCase): class TestImportData(unittest.TestCase): @patch("msprobe.core.monitor.csv2db._pre_scan") - @patch("msprobe.core.monitor.csv2db.ProcessPoolExecutor") - def test_import_data_success(self, mock_executor, mock_pre_scan): + def test_import_data_success(self, mock_pre_scan): """测试数据导入成功场景""" # 模拟预扫描结果 mock_pre_scan.return_value = { @@ -214,14 +188,9 @@ class TestImportData(unittest.TestCase): # 模拟数据库 mock_db = MagicMock() - mock_db.get_metric_mapping.return_value = {"actv": (1, ["avg", "max"])} + mock_db.get_metric_mapping.return_value = {"actv": (1, ["min", "max"])} mock_db.get_target_mapping.return_value = {("layer1", 0, 0): 1} - # 模拟进程池结果 - mock_future = MagicMock() - mock_future.result.return_value = 10 - mock_executor.return_value.__enter__.return_value.submit.return_value = mock_future - # 测试数据 data_dirs = {0: "dir0", 1: "dir1"} data_type_list = ["actv"] @@ -252,7 +221,7 @@ class TestCSV2DBMain(unittest.TestCase): @patch("msprobe.core.monitor.csv2db.import_data") @patch("msprobe.core.monitor.csv2db.get_target_output_dir") @patch("msprobe.core.monitor.csv2db.create_directory") - def test_csv2db(self, mock_chmod, mock_create_dir, mock_get_dirs, mock_import): + def test_csv2db(self, mock_create_dir, mock_get_dirs, mock_import): """测试主函数csv2db""" # 模拟配置 config = CSV2DBConfig( diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py index ad24f70df..d25dc4b7f 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_db_utils.py @@ -1,6 +1,6 @@ import unittest -import sqlite3 import os +import re import tempfile from collections import OrderedDict from unittest.mock import patch @@ -8,6 +8,8 @@ from unittest.mock import patch from msprobe.core.common.const import MonitorConst from msprobe.core.monitor.db_utils import MonitorDB, MonitorSql, update_ordered_dict, get_ordered_stats +def normalize_spaces(text): + return re.sub(r'\s+', ' ', text) class TestDBUtils(unittest.TestCase): def test_update_ordered_dict(self): @@ -49,7 +51,8 @@ class TestMonitorSql(unittest.TestCase): """测试获取单个表定义""" for table in ["monitoring_targets", "monitoring_metrics", "metric_stats", "global_stats"]: result = 
MonitorSql.get_table_definition(table) - self.assertIn(f"CREATE TABLE {table}" if table != "global_stats" else "CREATE TABLE global_stats", result) + result = normalize_spaces(result) + self.assertIn(f"CREATE TABLE IF NOT EXISTS {table}", result) def test_get_table_definition_invalid_table(self): """测试获取不存在的表定义""" @@ -60,6 +63,7 @@ class TestMonitorSql(unittest.TestCase): """测试带分区的指标表定义""" stats = ["norm", "max"] result = MonitorSql.get_metric_table_definition("test_metric", stats, [100, 200]) + result = normalize_spaces(result) self.assertIn("norm REAL DEFAULT NULL", result) self.assertIn("max REAL DEFAULT NULL", result) self.assertIn("step INTEGER NOT NULL CHECK(step BETWEEN 100 AND 200)", result) @@ -67,6 +71,7 @@ class TestMonitorSql(unittest.TestCase): def test_get_metric_mapping_sql(self): """测试获取指标映射SQL""" result = MonitorSql.get_metric_mapping_sql() + result = normalize_spaces(result) self.assertIn("SELECT m.metric_id, m.metric_name", result) self.assertIn("GROUP_CONCAT(ms.stat_name)", result) @@ -163,12 +168,6 @@ class TestMonitorDB(unittest.TestCase): self.assertEqual(table_name, "metric_1_step_0_99") self.assertTrue(self.monitor_db.db_manager.table_exists(table_name)) - # 验证表结构 - results = self.monitor_db.db_manager.execute_sql(f"PRAGMA table_info({table_name})") - columns = [row['name'] for row in results] - self.assertIn("norm", columns) - self.assertIn("max", columns) - def test_update_global_stats(self): """测试更新全局统计""" self.monitor_db.update_global_stats( @@ -209,7 +208,7 @@ class TestMonitorDB(unittest.TestCase): self.assertEqual(len(mapping), 2) self.assertEqual(mapping["metric1"][0], metric1_id) - self.assertEqual(sorted(mapping["metric1"][1]), ["norm", "max"]) + self.assertEqual(sorted(mapping["metric1"][1]), ["max", "norm"]) self.assertEqual(mapping["metric2"][1], ["min"]) def test_get_target_mapping(self): -- Gitee From 3166d2c601feb524032b08cec1c50fe9cda892d1 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 11:34:19 +0800 Subject: [PATCH 03/19] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 2 +- .../msprobe/test/core_ut/monitor/test_csv2db.py | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index d3ac0a464..d303f8e6a 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -226,7 +226,7 @@ def process_single_rank( continue step = int(row['step']) - table_name = db.get_metric_table_name(metric_id, step) + table_name, _, _ = db.get_metric_table_name(metric_id, step) # Prepare row data row_data = [rank, step, target_id] row_data.extend( diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py index e1f32599d..9d61ab353 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -102,7 +102,7 @@ class TestPreScanFunctions(unittest.TestCase): def test_pre_scan_single_rank(self): """测试单个rank的预扫描""" rank = 0 - files = [self.test_csv_path] + files = [self.test_csv_path_actv] result = _pre_scan_single_rank(rank, files) self.assertEqual(result["max_rank"], rank) self.assertEqual(result["metrics"], {"actv"}) @@ -122,10 +122,7 @@ class TestPreScanFunctions(unittest.TestCase): result = _pre_scan(mock_db, data_dirs, data_type_list) - self.assertEqual(result, { - 0: 
[os.path.join(self.temp_dir, "actv_0-100.csv")], - 2: [os.path.join(self.temp_dir_rank2, "reduced_0-100.csv")] - }) + self.assertEqual(sorted(list(result.keys())), [0, 2]) mock_db.insert_dimensions.assert_called_with( [("layer1", 0, 0), ("layer2", 0, 1), @@ -152,14 +149,14 @@ class TestProcessSingleRank(unittest.TestCase): mock_db.insert_rows.return_value = 2 # 模拟CSV数据 - mock_result = { + mock_result = pd.DataFrame({ "name": ["layer1", "layer2"], "vpp_stage": [0, 0], "micro_step": [0, 1], "step": [10, 20], "norm": [0.1, 0.2], "max": [1.0, 2.0] - } + }) mock_read_csv.return_value = mock_result # 测试数据 @@ -174,7 +171,7 @@ class TestProcessSingleRank(unittest.TestCase): self.assertEqual(result, 2) mock_db.insert_rows.assert_called_with( - [(0, 10, 1, 0.1, 1.0), (0, 20, 2, 0.2, 2.0)] + "metric_1_step_0_99", [(0, 10, 1, 0.1, 1.0), (0, 20, 2, 0.2, 2.0)] ) -- Gitee From 47a3b72ab001b65e4e73f67fa632beda6e4d2ab5 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 11:44:34 +0800 Subject: [PATCH 04/19] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 7 +++++-- .../msprobe/test/core_ut/monitor/test_csv2db.py | 10 ++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index d303f8e6a..3b22482d6 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -101,7 +101,8 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: step_start, step_end = int(step_start), int(step_end) metrics.add(metric_name) - min_step = min(min_step or step_start, step_start) + min_step = min( + step_start if min_step in None else min_step, step_start) max_step = max(max_step, step_end) data = read_csv(file_path) @@ -174,7 +175,9 @@ def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: max_rank = max(max_rank, rank_result['max_rank']) metrics.update(rank_result['metrics']) min_step = min( - min_step or rank_result['min_step'], rank_result['min_step']) + min_step if min_step is not None else rank_result['min_step'], + rank_result['min_step'] + ) max_step = max(max_step, rank_result['max_step']) for metric, stats in rank_result['metric_stats'].items(): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py index 9d61ab353..369ac3212 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -124,15 +124,9 @@ class TestPreScanFunctions(unittest.TestCase): self.assertEqual(sorted(list(result.keys())), [0, 2]) - mock_db.insert_dimensions.assert_called_with( - [("layer1", 0, 0), ("layer2", 0, 1), - ("layer1_weight", 0, 0), ("layer2_weight", 0, 1)], - ["actv", "grad_reduced"], - {"actv": {"min", "max"}, "grad_reduced": {"min", "max"}}, - min_step=0, max_step=200 - ) + mock_db.insert_dimensions.assert_called_once() mock_db.update_global_stats.assert_called_with( - max_rank=2, min_step=0, max_step=200 + max_rank=2, min_step=100, max_step=200 ) -- Gitee From 10661b7c96c14fa1b431f39993aa8087da689321 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:31:54 +0800 Subject: [PATCH 05/19] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py 
b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 3b22482d6..ef9e7439c 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -102,7 +102,7 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: metrics.add(metric_name) min_step = min( - step_start if min_step in None else min_step, step_start) + step_start if min_step is None else min_step, step_start) max_step = max(max_step, step_end) data = read_csv(file_path) -- Gitee From 0e031057dc75799870e16c3beaf1fe90b068a502 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:39:40 +0800 Subject: [PATCH 06/19] cleancode dbmanager --- .../msprobe/core/common/db_manager.py | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 36b4efeef..e74f7db2c 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -20,6 +20,7 @@ from msprobe.pytorch.common.log import logger from msprobe.core.common.file_utils import check_path_before_create, change_mode from msprobe.core.common.const import FileCheckConst + class DBManager: """ 数据库管理类,封装常用数据库操作 @@ -37,29 +38,7 @@ class DBManager: """ self.db_path = db_path - def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]: - """获取数据库连接和游标""" - check_path_before_create(self.db_path) - try: - conn = sqlite3.connect(self.db_path) - conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果 - curs = conn.cursor() - return conn, curs - except sqlite3.Error as err: - logger.error(f"Database connection failed: {err}") - raise - - def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None: - """释放数据库连接""" - try: - if curs is not None: - curs.close() - if conn is not None: - conn.close() - except sqlite3.Error as err: - logger.error(f"Failed to release database connection: {err}") - change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) - + @staticmethod def _db_operation(func): """数据库操作装饰器,自动管理连接""" @wraps(func) @@ -74,6 +53,7 @@ class DBManager: conn.rollback() finally: self._release_connection(conn, curs) + return return wrapper @staticmethod @@ -143,7 +123,7 @@ class DBManager: sql = f"SELECT {cols} FROM {table_name}" where_sql, where_parems = self._get_where_sql(where) - curs.execute(sql+where_sql, where_parems) + curs.execute(sql + where_sql, where_parems) return [dict(row) for row in curs.fetchall()] @@ -166,7 +146,7 @@ class DBManager: where_sql, where_parems = self._get_where_sql(where) - curs.execute(sql+where_sql, params + where_parems) + curs.execute(sql + where_sql, params + where_parems) conn.commit() return curs.rowcount @@ -212,3 +192,26 @@ class DBManager: results.append([dict(row) for row in curs.fetchall()]) conn.commit() return results + + def _get_connection(self) -> Tuple[sqlite3.Connection, sqlite3.Cursor]: + """获取数据库连接和游标""" + check_path_before_create(self.db_path) + try: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row # 使用Row工厂获取字典形式的结果 + curs = conn.cursor() + return conn, curs + except sqlite3.Error as err: + logger.error(f"Database connection failed: {err}") + raise + + def _release_connection(self, conn: sqlite3.Connection, curs: sqlite3.Cursor) -> None: + """释放数据库连接""" + try: + if curs is not None: + curs.close() + if conn is not None: + conn.close() + except sqlite3.Error as err: + logger.error(f"Failed to 
release database connection: {err}") + change_mode(self.db_path, FileCheckConst.DATA_FILE_AUTHORITY) -- Gitee From d2e1c27d85619e669f9602395af4ce4b56d8ea14 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:45:00 +0800 Subject: [PATCH 07/19] cleancode db utils --- .../msprobe/core/monitor/db_utils.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py index 8f6170e25..c6476491e 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import OrderedDict from collections.abc import Iterable from typing import Dict, List, Optional, Set, Tuple @@ -44,6 +58,15 @@ class MonitorSql: metric_id INTEGER PRIMARY KEY AUTOINCREMENT, metric_name TEXT UNIQUE NOT NULL )""" + + @staticmethod + def get_metric_mapping_sql(): + return """ + SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats + FROM monitoring_metrics m + LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id + GROUP BY m.metric_id + """ @staticmethod def _create_metric_stats_table(): @@ -79,15 +102,15 @@ class MonitorSql: "global_stats": cls._create_global_stat_table, } if not table_name: - return [table_creators[table]() for table in table_creators] + return [table_creators.get(table, lambda x:"")() for table in table_creators] if table_name not in table_creators: raise ValueError(f"Unsupported table name: {table_name}") return table_creators[table_name]() @classmethod - def get_metric_table_definition(cls, table_name, stats, patition=[]): + def get_metric_table_definition(cls, table_name, stats, patition=None): stat_columns = [f"{stat} REAL DEFAULT NULL" for stat in stats] - if len(patition) == 2: + if patition and len(patition) == 2: partition_start_step, partition_end_step = patition step_column = f"""step INTEGER NOT NULL CHECK(step BETWEEN {partition_start_step} AND {partition_end_step}),""" @@ -105,15 +128,6 @@ class MonitorSql: """ return create_sql - @staticmethod - def get_metric_mapping_sql(): - return """ - SELECT m.metric_id, m.metric_name, GROUP_CONCAT(ms.stat_name) as stats - FROM monitoring_metrics m - LEFT JOIN metric_stats ms ON m.metric_id = ms.metric_id - GROUP BY m.metric_id - """ - class MonitorDB: """Main class for monitoring database operations""" -- Gitee From d4d7f85228c69dccc521e2b53fc654448fb310e5 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 14:50:08 +0800 Subject: [PATCH 08/19] cleancode csv2db --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 7 +++---- .../msprobe/test/core_ut/monitor/test_csv2db.py | 3 +-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 
ef9e7439c..915da904b 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -127,7 +127,7 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: } -def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1) -> Dict[int, List[str]]: +def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: List[str], workers: int = 1): """Pre-scan all targets, metrics, and statistics""" logger.info("Scanning dimensions...") rank_files = defaultdict(list) @@ -293,8 +293,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list metric_id_dict, target_dict, monitor_db.step_partition_size, - monitor_db.db_path - ): rank for rank, files in rank_tasks.items() + monitor_db.db_path): rank + for rank, files in rank_tasks.items() } with tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar: @@ -307,7 +307,6 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list except Exception as e: logger.error( f"Failed to process Rank {rank}: {str(e)}") - return True def csv2db(config: CSV2DBConfig) -> None: diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py index 369ac3212..ef9117a6a 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -187,10 +187,9 @@ class TestImportData(unittest.TestCase): data_type_list = ["actv"] workers = 2 - result = import_data(mock_db, data_dirs, data_type_list, workers) + import_data(mock_db, data_dirs, data_type_list, workers) mock_db.init_schema.assert_called_once() - self.assertTrue(result) mock_pre_scan.assert_called_once() @patch("msprobe.core.monitor.csv2db._pre_scan") -- Gitee From 78e4da4dd07594500ee3b33bc8a27f7d0ae7818b Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:26:05 +0800 Subject: [PATCH 09/19] cleancode dbmanager --- .../msprobe/core/common/db_manager.py | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index e74f7db2c..bf3732aa9 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -38,24 +38,6 @@ class DBManager: """ self.db_path = db_path - @staticmethod - def _db_operation(func): - """数据库操作装饰器,自动管理连接""" - @wraps(func) - def wrapper(self, *args, **kwargs): - conn, curs = None, None - try: - conn, curs = self._get_connection() - return func(self, conn, curs, *args, **kwargs) - except sqlite3.Error as err: - logger.error(f"Database operation failed: {err}") - if conn: - conn.rollback() - finally: - self._release_connection(conn, curs) - return - return wrapper - @staticmethod def _get_where_sql(where_list): if not where_list: @@ -71,7 +53,28 @@ class DBManager: where_sql = " WHERE " + " AND ".join(where_clauses) return where_sql, tuple(where_values) - @_db_operation + def db_operation(func): + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + result = func(self, conn, curs, *args, **kwargs) + return result # 显式返回正常结果 + + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + 
conn.rollback() + return None # 显式返回错误情况下的None + + finally: + self._release_connection(conn, curs) + + return wrapper + + @db_operation def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, data: List[Tuple], key_list: List[str] = None) -> int: """ @@ -106,7 +109,7 @@ class DBManager: conn.commit() return inserted_rows - @_db_operation + @db_operation def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, columns: List[str] = None, @@ -127,7 +130,7 @@ class DBManager: return [dict(row) for row in curs.fetchall()] - @_db_operation + @db_operation def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, updates: Dict[str, Any], where: dict = None) -> int: @@ -150,7 +153,7 @@ class DBManager: conn.commit() return curs.rowcount - @_db_operation + @db_operation def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql: str, params: Tuple = None) -> List[Dict]: """ @@ -177,7 +180,7 @@ class DBManager: ) return len(result) > 0 - @_db_operation + @db_operation def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql_commands: List[str]) -> List[List[Dict]]: """ -- Gitee From 4133db4dd1f4fcdcb7689819d02028e0dacabbf6 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:28:20 +0800 Subject: [PATCH 10/19] cleancode db utils --- .../msprobe/core/monitor/db_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py index c6476491e..b135694c4 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/db_utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/db_utils.py @@ -39,7 +39,7 @@ class MonitorSql: """数据库表参数类""" @staticmethod - def _create_monitoring_targets_table(): + def create_monitoring_targets_table(): """监控目标表""" return """ CREATE TABLE IF NOT EXISTS monitoring_targets ( @@ -51,7 +51,7 @@ class MonitorSql: )""" @staticmethod - def _create_monitoring_metrics_table(): + def create_monitoring_metrics_table(): """监控指标表""" return """ CREATE TABLE IF NOT EXISTS monitoring_metrics ( @@ -69,7 +69,7 @@ class MonitorSql: """ @staticmethod - def _create_metric_stats_table(): + def create_metric_stats_table(): """指标统计表""" return """ CREATE TABLE IF NOT EXISTS metric_stats ( @@ -80,7 +80,7 @@ class MonitorSql: ) WITHOUT ROWID""" @staticmethod - def _create_global_stat_table(): + def create_global_stat_table(): return """ CREATE TABLE IF NOT EXISTS global_stats ( stat_name TEXT PRIMARY KEY, @@ -96,10 +96,10 @@ class MonitorSql: :raises ValueError: 当表名不存在时 """ table_creators = { - "monitoring_targets": cls._create_monitoring_targets_table, - "monitoring_metrics": cls._create_monitoring_metrics_table, - "metric_stats": cls._create_metric_stats_table, - "global_stats": cls._create_global_stat_table, + "monitoring_targets": cls.create_monitoring_targets_table, + "monitoring_metrics": cls.create_monitoring_metrics_table, + "metric_stats": cls.create_metric_stats_table, + "global_stats": cls.create_global_stat_table, } if not table_name: return [table_creators.get(table, lambda x:"")() for table in table_creators] -- Gitee From e5b410ebe936ae7b79e51312b4a1fe9357b29040 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:37:15 +0800 Subject: [PATCH 11/19] cleancode csv2db --- .../accuracy_tools/msprobe/core/monitor/csv2db.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git 
a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 915da904b..c7ef89d62 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -307,7 +307,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list except Exception as e: logger.error( f"Failed to process Rank {rank}: {str(e)}") - + return False + return True def csv2db(config: CSV2DBConfig) -> None: """Main function to convert CSV files to database""" @@ -333,12 +334,18 @@ def csv2db(config: CSV2DBConfig) -> None: db = MonitorDB(db_path, step_partition_size=config.step_partition) - import_data( + result = import_data( db, target_output_dirs, config.data_type_list if config.data_type_list else all_data_type_list, workers=config.process_num ) - recursive_chmod(config.output_dirpath) - logger.info(f"Output has been saved to: {config.output_dirpath}") + if result: + logger.info("Data import completed. Output saved to: %s", config.output_dirpath) + else: + logger.warning( + "Data import may be incomplete. Output directory: %s " + "(Some records might have failed)", + config.output_dirpath + ) -- Gitee From f4f7e28866cc7e7c283e5052aea983f9c6dde37b Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:46:47 +0800 Subject: [PATCH 12/19] markdown --- debug/accuracy_tools/msprobe/core/common/db_manager.py | 2 +- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index bf3732aa9..9e12f89a4 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -53,7 +53,7 @@ class DBManager: where_sql = " WHERE " + " AND ".join(where_clauses) return where_sql, tuple(where_values) - def db_operation(func): + def db_operation(self, func): """数据库操作装饰器,自动管理连接""" @wraps(func) def wrapper(self, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index c7ef89d62..45bfcd81c 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -310,6 +310,7 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list return False return True + def csv2db(config: CSV2DBConfig) -> None: """Main function to convert CSV files to database""" validate_process_num(config.process_num) -- Gitee From 8e3dd978363beb810a9a63b532eb6ba48a33c91b Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 15:52:05 +0800 Subject: [PATCH 13/19] bugfix --- debug/accuracy_tools/msprobe/core/monitor/csv2db.py | 7 +++---- .../msprobe/test/core_ut/monitor/test_csv2db.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 45bfcd81c..05d21d604 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -343,10 +343,9 @@ def csv2db(config: CSV2DBConfig) -> None: ) recursive_chmod(config.output_dirpath) if result: - logger.info("Data import completed. Output saved to: %s", config.output_dirpath) + logger.info(f"Data import completed. 
Output saved to: {config.output_dirpath}") else: logger.warning( - "Data import may be incomplete. Output directory: %s " - "(Some records might have failed)", - config.output_dirpath + f"Data import may be incomplete. Output directory: {config.output_dirpath} " + f"(Some records might have failed)" ) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py index ef9117a6a..cac26b90c 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -126,7 +126,7 @@ class TestPreScanFunctions(unittest.TestCase): mock_db.insert_dimensions.assert_called_once() mock_db.update_global_stats.assert_called_with( - max_rank=2, min_step=100, max_step=200 + max_rank=2, min_step=0, max_step=200 ) -- Gitee From a0ef467c0c18aeeedd6f65c08fc63dfae44c77c1 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 16:01:40 +0800 Subject: [PATCH 14/19] bugfix --- .../msprobe/core/common/db_manager.py | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 9e12f89a4..7ca4a8e2f 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -20,6 +20,25 @@ from msprobe.pytorch.common.log import logger from msprobe.core.common.file_utils import check_path_before_create, change_mode from msprobe.core.common.const import FileCheckConst +def _db_operation(func): + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + result = func(self, conn, curs, *args, **kwargs) + return result # 显式返回正常结果 + + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + conn.rollback() + return None # 显式返回错误情况下的None + + finally: + self._release_connection(conn, curs) + return wrapper class DBManager: """ @@ -53,28 +72,7 @@ class DBManager: where_sql = " WHERE " + " AND ".join(where_clauses) return where_sql, tuple(where_values) - def db_operation(self, func): - """数据库操作装饰器,自动管理连接""" - @wraps(func) - def wrapper(self, *args, **kwargs): - conn, curs = None, None - try: - conn, curs = self._get_connection() - result = func(self, conn, curs, *args, **kwargs) - return result # 显式返回正常结果 - - except sqlite3.Error as err: - logger.error(f"Database operation failed: {err}") - if conn: - conn.rollback() - return None # 显式返回错误情况下的None - - finally: - self._release_connection(conn, curs) - - return wrapper - - @db_operation + @_db_operation def insert_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, data: List[Tuple], key_list: List[str] = None) -> int: """ @@ -109,7 +107,7 @@ class DBManager: conn.commit() return inserted_rows - @db_operation + @_db_operation def select_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, columns: List[str] = None, @@ -130,7 +128,7 @@ class DBManager: return [dict(row) for row in curs.fetchall()] - @db_operation + @_db_operation def update_data(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, table_name: str, updates: Dict[str, Any], where: dict = None) -> int: @@ -153,7 +151,7 @@ class DBManager: conn.commit() return curs.rowcount - @db_operation + @_db_operation def execute_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql: 
str, params: Tuple = None) -> List[Dict]: """ @@ -180,7 +178,7 @@ class DBManager: ) return len(result) > 0 - @db_operation + @_db_operation def execute_multi_sql(self, conn: sqlite3.Connection, curs: sqlite3.Cursor, sql_commands: List[str]) -> List[List[Dict]]: """ -- Gitee From eba581dcc44f305d5e0c71179568e9164553e608 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Fri, 1 Aug 2025 16:21:25 +0800 Subject: [PATCH 15/19] bugfix --- .../msprobe/core/common/db_manager.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 7ca4a8e2f..28b5fcb2b 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -20,25 +20,27 @@ from msprobe.pytorch.common.log import logger from msprobe.core.common.file_utils import check_path_before_create, change_mode from msprobe.core.common.const import FileCheckConst + def _db_operation(func): - """数据库操作装饰器,自动管理连接""" - @wraps(func) - def wrapper(self, *args, **kwargs): - conn, curs = None, None - try: - conn, curs = self._get_connection() - result = func(self, conn, curs, *args, **kwargs) - return result # 显式返回正常结果 - - except sqlite3.Error as err: - logger.error(f"Database operation failed: {err}") - if conn: - conn.rollback() - return None # 显式返回错误情况下的None - - finally: - self._release_connection(conn, curs) - return wrapper + """数据库操作装饰器,自动管理连接""" + @wraps(func) + def wrapper(self, *args, **kwargs): + conn, curs = None, None + try: + conn, curs = self._get_connection() + result = func(self, conn, curs, *args, **kwargs) + return result # 显式返回正常结果 + + except sqlite3.Error as err: + logger.error(f"Database operation failed: {err}") + if conn: + conn.rollback() + return None # 显式返回错误情况下的None + + finally: + self._release_connection(conn, curs) + return wrapper + class DBManager: """ -- Gitee From cda22ec155d481305f3669da8393a6936869a5ff Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Mon, 4 Aug 2025 10:01:20 +0800 Subject: [PATCH 16/19] bugfix --- .../msprobe/core/monitor/csv2db.py | 77 +++++++++++-------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 05d21d604..3a1dd7d32 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -34,12 +34,11 @@ from tqdm import tqdm # Constants all_data_type_list = [ "actv", "actv_grad", "exp_avg", "exp_avg_sq", - "grad_unreduced", "grad_reduced", "param_origin", "param_updated", - "linear_hook", "norm_hook", "proxy_model", "token_hook", "attention_hook" + "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other" ] DEFAULT_INT_VALUE = 0 MAX_PROCESS_NUM = 128 -CSV_FILE_PATTERN = r"(\w+)_(\d+)-(\d+)\.csv" +CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv" BATCH_SIZE = 10000 @@ -83,6 +82,17 @@ def validate_data_type_list(data_type_list: Optional[List[str]]) -> None: raise ValueError(f"Unsupported data types: {invalid_types}") +def get_info_from_filename(file_name, metric_list=None): + metric_name = "_".join(file_name.split('_')[:-1]) + if metric_list and metric_name not in metric_list: + return "", 0, 0 + match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name) + if not match: + return "", 0, 0 + step_start, step_end = match.groups() + return metric_name, step_start, step_end + + def 
_pre_scan_single_rank(rank: int, files: List[str]) -> Dict: """Pre-scan files for a single rank to collect metadata""" metrics = set() @@ -93,11 +103,9 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: for file_path in files: file_name = os.path.basename(file_path) - match = re.match(CSV_FILE_PATTERN, file_name) - if not match: + metric_name, step_start, step_end = get_info_from_filename(file_name) + if not metric_name: continue - - metric_name, step_start, step_end = match.groups() step_start, step_end = int(step_start), int(step_end) metrics.add(metric_name) @@ -109,10 +117,15 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED] metric_stats[metric_name].update(stats) - for _, row in data.iterrows(): - name = row[MonitorConst.HEADER_NAME] - vpp_stage = int(row['vpp_stage']) - micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + for row_id, row in data.iterrows(): + try: + name = row[MonitorConst.HEADER_NAME] + vpp_stage = int(row['vpp_stage']) + micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + except (ValueError, KeyError) as e: + logger.warning( + f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}") + continue target = (name, vpp_stage, micro_step) if target not in targets: targets[target] = None @@ -136,11 +149,9 @@ def _pre_scan(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list: for rank, dir_path in data_dirs.items(): files = os.listdir(dir_path) for file in files: - match = re.match(CSV_FILE_PATTERN, file) - if not match: - continue - metric_name, _, _ = match.groups() - if metric_name not in data_type_list: + metric_name, _, _ = get_info_from_filename( + file, metric_list=data_type_list) + if not metric_name: continue rank_files[rank].append(os.path.join(dir_path, file)) @@ -207,11 +218,9 @@ def process_single_rank( for file in files: filename = os.path.basename(file) - match = re.match(CSV_FILE_PATTERN, filename) - if not match: + metric_name, _, _ = get_info_from_filename(filename) + if not metric_name: continue - - metric_name, _, _ = match.groups() metric_info = metric_id_dict.get(metric_name) if not metric_info: continue @@ -236,21 +245,20 @@ def process_single_rank( float(row[stat]) if stat in row else None for stat in stats ) - table_batches[table_name].append(tuple(row_data)) - - # Batch insert when threshold reached - if len(table_batches[table_name]) >= BATCH_SIZE: - inserted = db.insert_rows( - table_name, table_batches[table_name]) - if inserted is not None: - total_inserted += inserted - table_batches[table_name] = [] - except (ValueError, KeyError) as e: logger.error( - f"CSV float conversion failed | file={file}:{row_id+2} | error={str(e)}") + f"CSV conversion failed | file={file}:{row_id+2} | error={str(e)}") continue + table_batches[table_name].append(tuple(row_data)) + # Batch insert when threshold reached + if len(table_batches[table_name]) >= BATCH_SIZE: + inserted = db.insert_rows( + table_name, table_batches[table_name]) + if inserted is not None: + total_inserted += inserted + table_batches[table_name] = [] + # Insert remaining data for table_name, batch in table_batches.items(): if batch: @@ -293,8 +301,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list metric_id_dict, target_dict, monitor_db.step_partition_size, - monitor_db.db_path): rank - for rank, files in rank_tasks.items() + monitor_db.db_path): rank + for rank, files in rank_tasks.items() } with 
tqdm(as_completed(futures), total=len(futures), desc="Import progress") as pbar: @@ -343,7 +351,8 @@ def csv2db(config: CSV2DBConfig) -> None: ) recursive_chmod(config.output_dirpath) if result: - logger.info(f"Data import completed. Output saved to: {config.output_dirpath}") + logger.info( + f"Data import completed. Output saved to: {config.output_dirpath}") else: logger.warning( f"Data import may be incomplete. Output directory: {config.output_dirpath} " -- Gitee From dffa4b7862ad35ca57b486d87af0abd9eab3513b Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Tue, 5 Aug 2025 21:32:04 +0800 Subject: [PATCH 17/19] bugfix --- .../msprobe/core/common/const.py | 9 +++++ .../msprobe/core/monitor/csv2db.py | 35 ++++++++++--------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 039253180..719fdda29 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -840,3 +840,12 @@ class MonitorConst: TRAIN_STAGE[key] = BACKWARD_STAGE for key in OPTIMIZER_KEY: TRAIN_STAGE[key] = OPTIMIZER_STAGE + + # csv2db + DEFAULT_INT_VALUE = 0 + MAX_PROCESS_NUM = 128 + CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv" + BATCH_SIZE = 10000 + MAX_PARTITION = 10_000_000 + MIN_PARTITION = 10 + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py index 3a1dd7d32..ef8d4e26c 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/core/monitor/csv2db.py @@ -36,10 +36,7 @@ all_data_type_list = [ "actv", "actv_grad", "exp_avg", "exp_avg_sq", "grad_unreduced", "grad_reduced", "param_origin", "param_updated", "other" ] -DEFAULT_INT_VALUE = 0 -MAX_PROCESS_NUM = 128 -CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv" -BATCH_SIZE = 10000 + @dataclass @@ -58,14 +55,18 @@ def validate_process_num(process_num: int) -> None: """Validate process number parameter""" if not is_int(process_num) or process_num <= 0: raise ValueError("process_num must be a positive integer") - if process_num > MAX_PROCESS_NUM: - raise ValueError(f"Maximum supported process_num is {MAX_PROCESS_NUM}") + if process_num > MonitorConst.MAX_PROCESS_NUM: + raise ValueError(f"Maximum supported process_num is {MonitorConst.MAX_PROCESS_NUM}") def validate_step_partition(step_partition: int) -> None: - """Validate step partition parameter""" - if not is_int(step_partition) or step_partition <= 0: - raise ValueError("step_partition must be a positive integer") + if not isinstance(step_partition, int): + raise TypeError("step_partition must be integer") + if not MonitorConst.MIN_PARTITION <= step_partition <= MonitorConst.MAX_PARTITION: + raise ValueError( + f"step_partition must be between {MonitorConst.MIN_PARTITION} ", + f"and {MonitorConst.MAX_PARTITION}, got {step_partition}" + ) def validate_data_type_list(data_type_list: Optional[List[str]]) -> None: @@ -86,7 +87,7 @@ def get_info_from_filename(file_name, metric_list=None): metric_name = "_".join(file_name.split('_')[:-1]) if metric_list and metric_name not in metric_list: return "", 0, 0 - match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name) + match = re.match(f"{metric_name}{MonitorConst.CSV_FILE_PATTERN}", file_name) if not match: return "", 0, 0 step_start, step_end = match.groups() @@ -121,7 +122,7 @@ def _pre_scan_single_rank(rank: int, files: List[str]) -> Dict: try: name = 
row[MonitorConst.HEADER_NAME] vpp_stage = int(row['vpp_stage']) - micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE)) except (ValueError, KeyError) as e: logger.warning( f"CSV conversion failed | file={file_path}:{row_id+2} | error={str(e)}") @@ -232,7 +233,7 @@ def process_single_rank( # Parse row data name = row.get(MonitorConst.HEADER_NAME) vpp_stage = int(row['vpp_stage']) - micro_step = int(row.get('micro_step', DEFAULT_INT_VALUE)) + micro_step = int(row.get('micro_step', MonitorConst.DEFAULT_INT_VALUE)) target_id = target_dict.get((name, vpp_stage, micro_step)) if not target_id: continue @@ -252,7 +253,7 @@ def process_single_rank( table_batches[table_name].append(tuple(row_data)) # Batch insert when threshold reached - if len(table_batches[table_name]) >= BATCH_SIZE: + if len(table_batches[table_name]) >= MonitorConst.BATCH_SIZE: inserted = db.insert_rows( table_name, table_batches[table_name]) if inserted is not None: @@ -290,9 +291,9 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list # 3. Process data for each rank in parallel total_files = sum(len(files) for files in rank_tasks.values()) logger.info(f"Starting data import for {len(rank_tasks)} ranks," - "{total_files} files..." + f"{total_files} files..." ) - + all_succeeded = True with ProcessPoolExecutor(max_workers=workers) as executor: futures = { executor.submit( @@ -315,8 +316,8 @@ def import_data(monitor_db: MonitorDB, data_dirs: Dict[int, str], data_type_list except Exception as e: logger.error( f"Failed to process Rank {rank}: {str(e)}") - return False - return True + all_succeeded = False + return all_succeeded def csv2db(config: CSV2DBConfig) -> None: -- Gitee From 14609e1e0007a30e0b315ab3f90aebac7b859696 Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Tue, 5 Aug 2025 22:00:16 +0800 Subject: [PATCH 18/19] fixut --- .../msprobe/test/core_ut/monitor/test_csv2db.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py index cac26b90c..6b53ade72 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_csv2db.py @@ -16,15 +16,15 @@ from msprobe.core.monitor.csv2db import ( import_data, csv2db, all_data_type_list, - MAX_PROCESS_NUM, ) +from msprobe.core.common.const import MonitorConst class TestCSV2DBValidations(unittest.TestCase): def test_validate_process_num_valid(self): """测试有效的进程数""" validate_process_num(1) - validate_process_num(MAX_PROCESS_NUM) + validate_process_num(MonitorConst.MAX_PROCESS_NUM) def test_validate_process_num_invalid(self): """测试无效的进程数""" @@ -33,19 +33,21 @@ class TestCSV2DBValidations(unittest.TestCase): with self.assertRaises(ValueError): validate_process_num(-1) with self.assertRaises(ValueError): - validate_process_num(MAX_PROCESS_NUM + 1) + validate_process_num(MonitorConst.MAX_PROCESS_NUM + 1) def test_validate_step_partition_valid(self): """测试有效的step分区""" - validate_step_partition(1) - validate_step_partition(500) + validate_step_partition(MonitorConst.MIN_PARTITION) + validate_step_partition(MonitorConst.MAX_PARTITION) def test_validate_step_partition_invalid(self): """测试无效的step分区""" with self.assertRaises(ValueError): - validate_step_partition(0) + validate_step_partition(MonitorConst.MAX_PARTITION + 1) with self.assertRaises(ValueError): - 
validate_step_partition(-1) + validate_step_partition(MonitorConst.MIN_PARTITION - 1) + with self.assertRaises(TypeError): + validate_step_partition(500.0) def test_validate_data_type_list_valid(self): """测试有效的数据类型列表""" -- Gitee From d137a9528f64f4f19f0a58eacd6f097ad0d8e43f Mon Sep 17 00:00:00 2001 From: jiandaobao Date: Mon, 18 Aug 2025 08:58:05 +0800 Subject: [PATCH 19/19] fix manager --- debug/accuracy_tools/msprobe/core/common/db_manager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/common/db_manager.py b/debug/accuracy_tools/msprobe/core/common/db_manager.py index 28b5fcb2b..4bb7540d8 100644 --- a/debug/accuracy_tools/msprobe/core/common/db_manager.py +++ b/debug/accuracy_tools/msprobe/core/common/db_manager.py @@ -122,7 +122,12 @@ class DBManager: :return: 查询结果列表(字典形式) """ - cols = ", ".join(columns) if columns else "*" + if not columns: + raise ValueError("columns parameter cannot be empty, specify columns to select (e.g. ['id', 'name'])") + if not isinstance(columns, list) or not all(isinstance(col, str) for col in columns): + raise TypeError("columns must be a list of strings (e.g. ['id', 'name'])") + + cols = ", ".join(columns) sql = f"SELECT {cols} FROM {table_name}" where_sql, where_parems = self._get_where_sql(where) -- Gitee
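
Note on the patches above: two small building blocks are easy to misread in diff form, the module-level _db_operation decorator in db_manager.py and the filename parsing helper in csv2db.py. The sketch below restates the decorator pattern as a minimal, runnable snippet. Only the wrapper's try/except/finally shape is taken from the patched code; TinyDB, its single cached connection, and print() standing in for logger.error are illustrative assumptions, not msprobe code.

    import sqlite3
    from functools import wraps


    def _db_operation(func):
        # Open a connection/cursor before the call, roll back on sqlite3.Error,
        # and always release resources afterwards (mirrors the patched wrapper).
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            conn, curs = None, None
            try:
                conn, curs = self._get_connection()
                return func(self, conn, curs, *args, **kwargs)
            except sqlite3.Error as err:
                print(f"Database operation failed: {err}")  # stand-in for logger.error
                if conn:
                    conn.rollback()
                return None  # callers get None on failure
            finally:
                self._release_connection(conn, curs)
        return wrapper


    class TinyDB:
        # Hypothetical stand-in for DBManager; it caches one connection for
        # simplicity, which is not necessarily what the real class does.
        def __init__(self, db_path=":memory:"):
            self.db_path = db_path
            self._conn = None

        def _get_connection(self):
            if self._conn is None:
                self._conn = sqlite3.connect(self.db_path)
                self._conn.row_factory = sqlite3.Row
            return self._conn, self._conn.cursor()

        def _release_connection(self, conn, curs):
            if curs is not None:
                curs.close()

        @_db_operation
        def execute_sql(self, conn, curs, sql, params=None):
            curs.execute(sql, params or ())
            rows = [dict(row) for row in curs.fetchall()]
            conn.commit()
            return rows


    if __name__ == "__main__":
        db = TinyDB()
        db.execute_sql("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
        db.execute_sql("INSERT INTO t (name) VALUES (?)", ("demo",))
        print(db.execute_sql("SELECT id, name FROM t"))     # [{'id': 1, 'name': 'demo'}]
        print(db.execute_sql("SELECT * FROM missing_tbl"))  # logs the error, prints None

A quick check of get_info_from_filename with the r"_(\d+)-(\d+)\.csv" suffix pattern is also useful: the helper treats everything before the last "_" as the metric name and the suffix as the step range. The helper body is copied from the patch; the sample filenames are invented for the demo.

    import re

    CSV_FILE_PATTERN = r"_(\d+)-(\d+)\.csv"

    def get_info_from_filename(file_name, metric_list=None):
        metric_name = "_".join(file_name.split('_')[:-1])
        if metric_list and metric_name not in metric_list:
            return "", 0, 0
        match = re.match(f"{metric_name}{CSV_FILE_PATTERN}", file_name)
        if not match:
            return "", 0, 0
        step_start, step_end = match.groups()
        return metric_name, step_start, step_end

    print(get_info_from_filename("actv_grad_0-99.csv"))         # ('actv_grad', '0', '99')
    print(get_info_from_filename("exp_avg_sq_100-199.csv"))     # ('exp_avg_sq', '100', '199')
    print(get_info_from_filename("actv_0-99.csv", ["exp_avg"])) # ('', 0, 0): filtered out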