From c97962e3eb02979e68a726fc8e520c70b8ae97e3 Mon Sep 17 00:00:00 2001
From: wangchao
Date: Fri, 4 Aug 2023 16:46:41 +0800
Subject: [PATCH] support uint8 compare

---
 .../api_accuracy_checker/compare/algorithm.py | 47 +++++++++++--------
 .../run_ut/data_generate.py                   |  2 +
 2 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py
index a79125d83..b495a5316 100644
--- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py
+++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py
@@ -1,8 +1,8 @@
 # 定义比对算法及比对标准
-import torch
-import numpy as np
-from api_accuracy_checker.compare.compare_utils import CompareConst
+import torch
+import numpy as np
+from api_accuracy_checker.compare.compare_utils import CompareConst
 from api_accuracy_checker.common.utils import print_warn_log, Const
 
 
 def compare_torch_tensor(cpu_output, npu_output, compare_alg):
@@ -13,16 +13,16 @@ def compare_torch_tensor(cpu_output, npu_output, compare_alg):
 
 def compare_bool_tensor(cpu_output, npu_output):
     error_rate = CompareConst.NAN
-    cpu_shape = cpu_output.shape
-    npu_shape = npu_output.shape
+    cpu_shape = cpu_output.shape
+    npu_shape = npu_output.shape
     if cpu_shape != npu_shape:
         return error_rate, False
-    npu_data = npu_output.cpu().detach().numpy()
+    npu_data = npu_output.cpu().detach().numpy()
     bench_data = cpu_output.detach().numpy()
-    data_size = bench_data.size
+    data_size = bench_data.size
     error_nums = (bench_data != npu_data).sum()
     error_rate = float(error_nums / data_size)
-    return error_rate, error_rate < 0.001
+    return error_rate, error_rate < 0.001
 
 
 def get_max_rel_err(n_value, b_value):
@@ -35,8 +35,8 @@ def get_max_rel_err(n_value, b_value):
         rel_err = np.abs((n_value - b_value) / (b_value + np.finfo(b_value.dtype).eps)).max()
         return rel_err, rel_err < 0.001
     if np.all(n_value == b_value):
-        return 0, True
-    return 1, False
+        return 0, True
+    return 1, False
 
 
 def cosine_standard(compare_result):
@@ -47,42 +47,51 @@ def cosine_standard(compare_result):
 def cosine_sim(cpu_output, npu_output):
     n_value = npu_output.cpu().detach().numpy().reshape(-1)
     b_value = cpu_output.detach().numpy().reshape(-1)
-    cos = CompareConst.NA
+    cos = CompareConst.NA
     np.seterr(divide="ignore", invalid="ignore")
     if len(n_value) == 1:
         print_warn_log("All the data in npu dump data is scalar. Compare by relative error.")
         return get_max_rel_err(n_value, b_value)
+    if n_value.dtype == np.uint8:
+        return compare_uint8_data(n_value, b_value)
     num = n_value.dot(b_value)
     a_norm = np.linalg.norm(n_value)
     b_norm = np.linalg.norm(b_value)
     if a_norm <= np.finfo(float).eps and b_norm <= np.finfo(float).eps:
-        return cos, True
-    elif a_norm <= np.finfo(float).eps:
+        return cos, True
+    elif a_norm <= np.finfo(float).eps:
         print_warn_log("All the data is Zero in npu dump data. Compare by relative error.")
         return get_max_rel_err(n_value, b_value)
     elif b_norm <= np.finfo(float).eps:
         print_warn_log("All the data is Zero in bench dump data. Compare by relative error.")
-    else:
+    else:
         cos = num / (a_norm * b_norm)
         if np.isnan(cos):
             print_warn_log("Dump data has NaN when comparing with Cosine Similarity.")
     return cos, cos > 0.99
 
 
+def compare_uint8_data(n_value, b_value):
+    if (n_value == b_value).all():
+        return 1, True
+    else:
+        return 0, False
+
+
 def compare_builtin_type(bench_out, npu_out):
     if bench_out != npu_out:
-        return CompareConst.NAN, False
-    return 1.0, True
+        return CompareConst.NAN, False
+    return 1.0, True
 
 
 def flatten_compare_result(result):
-    flatten_result = []
+    flatten_result = []
     for result_i in result:
         if isinstance(result_i, list):
             flatten_result += flatten_compare_result(result_i)
         else:
             flatten_result.append(result_i)
-    return flatten_result
+    return flatten_result
 
 
 def compare_core(bench_out, npu_out, alg):
@@ -113,4 +122,4 @@ def compare_core(bench_out, npu_out, alg):
         compare_result = flatten_compare_result(compare_result)
 
     return compare_result, test_success
-    
+
diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py
index 65091159a..92c85a16a 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py
@@ -187,6 +187,8 @@ def gen_kwargs(api_info, convert_type=None):
             kwargs_params[key] = gen_list_kwargs(value, convert_type)
         elif value.get('type') in TENSOR_DATA_LIST:
             kwargs_params[key] = gen_data(value, False, convert_type)
+        elif value.get('type') == "torch.device":
+            kwargs_params[key] = torch.device(value.get('value'))
         else:
             kwargs_params[key] = value.get('value')
     return kwargs_params
--
Gitee