diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py index a862b4dbc42eff7e733d41ad61c45d98c125050c..4545cf81e08f7af72c2e218fa552f2794dd2aada 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/utils/graph_utils.py @@ -23,6 +23,7 @@ import sys from functools import cmp_to_key from pathlib import Path from tensorboard.util import tb_logging +from werkzeug import Request from .global_state import GraphState, FILE_NAME_REGEX, MAX_FILE_SIZE, PERM_GROUP_WRITE, PERM_OTHER_WRITE logger = tb_logging.get_logger() @@ -140,6 +141,89 @@ class GraphUtils: except Exception as e: logger.error(f"Unexpected error: {e}") return default_value + + @staticmethod + def safe_get_node_info(request: Request, default_value=None): + + node_info_str = request.args.get("nodeInfo") + + # 如果参数不存在,返回None + if node_info_str is None: + logger.error("nodeInfo参数不存在") + return default_value + + # 检查是否为字符串类型(防止其他类型注入) + if not isinstance(node_info_str, str): + logger.error("nodeInfo参数必须是JSON字符串") + return default_value + + # 长度限制 + if len(node_info_str) > MAX_FILE_SIZE: + logger.error(f"Input length exceeds {MAX_FILE_SIZE} characters.") + return default_value + + try: + # 解析JSON + node_info = json.loads(node_info_str) + + # 验证解析结果是否为字典 + if not isinstance(node_info, dict): + logger.error("nodeInfo必须是JSON对象") + return default_value + + # 验证必要字段是否存在 + if "nodeName" not in node_info or "nodeType" not in node_info: + logger.error("nodeInfo必须包含nodeName和nodeType字段") + return default_value + + return node_info + + except json.JSONDecodeError: + logger.error("nodeInfo参数不是有效的JSON格式") + return default_value + except Exception as e: + logger.error(f"解析nodeInfo参数时发生错误: {str(e)}") + return default_value + + @staticmethod + def safe_get_meta_data(request: Request, default_value=None): + + meta_data_str = request.args.get("metaData") + + # 如果参数不存在,返回None + if meta_data_str is None: + logger.error("metaData参数不存在") + return default_value + + # 检查是否为字符串类型(防止其他类型注入) + if not isinstance(meta_data_str, str): + logger.error("metaData参数必须是JSON字符串") + return default_value + + try: + # 解析JSON + meta_data = json.loads(meta_data_str) + + # 验证解析结果是否为字典 + if not isinstance(meta_data, dict): + logger.error("metaData必须是JSON对象") + return default_value + + # 验证必要字段是否存在 + required_fields = ["tag", "microStep", "run"] + for field in required_fields: + if field not in meta_data: + logger.error("metaData必须包含nodeName和nodeType字段") + return default_value + + return meta_data + + except json.JSONDecodeError: + logger.error("metaData参数不是有效的JSON格式") + return default_value + except Exception as e: + logger.error(f"解析metaData参数时发生错误: {str(e)}") + return default_value @staticmethod def remove_prototype_pollution(obj, current_depth=1, max_depth=200): diff --git a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py index 17d85b704ecd4b1b4d3b3800b9bd533bd31bf3cd..aa02ae9ab76ab8653b0bc66f04048c33c8819864 100644 --- a/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py +++ b/plugins/tensorboard-plugins/tb_graph_ascend/server/app/views/graph_views.py @@ -98,8 +98,8 @@ class GraphView: @staticmethod @wrappers.Request.application def change_node_expand_state(request): - node_info = GraphUtils.safe_json_loads(request.args.get("nodeInfo")) - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + node_info = GraphUtils.safe_get_node_info(request) + meta_data = GraphUtils.safe_get_meta_data(request) hierarchy = GraphService.change_node_expand_state(node_info, meta_data) return http_util.Respond(request, json.dumps(hierarchy), "application/json") @@ -115,8 +115,8 @@ class GraphView: @staticmethod @wrappers.Request.application def get_node_info(request): - node_info = GraphUtils.safe_json_loads(request.args.get("nodeInfo")) - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + node_info = GraphUtils.safe_get_node_info(request) + meta_data = GraphUtils.safe_get_meta_data(request) node_detail = GraphService.get_node_info(node_info, meta_data) return http_util.Respond(request, json.dumps(node_detail), "application/json") @@ -125,7 +125,7 @@ class GraphView: @wrappers.Request.application def add_match_nodes_by_config(request): config_file = request.args.get("configFile") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) match_result = GraphService.add_match_nodes_by_config(config_file, meta_data) return http_util.Respond(request, json.dumps(match_result), "application/json") @@ -135,7 +135,7 @@ class GraphView: def add_match_nodes(request): npu_node_name = request.args.get("npuNodeName") bench_node_name = request.args.get("benchNodeName") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) is_match_children = GraphUtils.safe_json_loads(request.args.get("isMatchChildren")) match_result = GraphService.add_match_nodes(npu_node_name, bench_node_name, meta_data, is_match_children) return http_util.Respond(request, json.dumps(match_result), "application/json") @@ -146,7 +146,7 @@ class GraphView: def delete_match_nodes(request): npu_node_name = request.args.get("npuNodeName") bench_node_name = request.args.get("benchNodeName") - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) is_unmatch_children = GraphUtils.safe_json_loads(request.args.get("isUnMatchChildren")) match_result = GraphService.delete_match_nodes(npu_node_name, bench_node_name, meta_data, is_unmatch_children) return http_util.Respond(request, json.dumps(match_result), "application/json") @@ -155,7 +155,7 @@ class GraphView: @staticmethod @wrappers.Request.application def save_data(request): - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) save_result = GraphService.save_data(meta_data) return http_util.Respond(request, json.dumps(save_result), "application/json") @@ -172,6 +172,6 @@ class GraphView: @staticmethod @wrappers.Request.application def save_matched_relations(request): - meta_data = GraphUtils.safe_json_loads(request.args.get("metaData")) + meta_data = GraphUtils.safe_get_meta_data(request) save_result = GraphService.save_matched_relations(meta_data) return http_util.Respond(request, json.dumps(save_result), "application/json")