diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 0e9edec8a6d8a2a3e661354e428324f7b3cd7330..d33ed21d8134afc0a8e90453754dd53da38899eb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -13,7 +13,8 @@ class APIInfo: self.rank = os.getpid() self.api_name = api_name self.save_real_data = msCheckerConfig.real_data - + self.torch_object_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} + def analyze_element(self, element): if isinstance(element, (list, tuple)): out = [] @@ -22,7 +23,11 @@ class APIInfo: elif isinstance(element, dict): out = {} for key, value in element.items(): - out[key] = self.analyze_element(value) + if key in self.torch_object_key.keys(): + fun = self.torch_object_key[key] + out[key] = fun(value) + else: + out[key] = self.analyze_element(value) elif isinstance(element, torch.Tensor): out = self.analyze_tensor(element, self.save_real_data) @@ -78,7 +83,26 @@ class APIInfo: if element is None or isinstance(element, (bool,int,float,str,slice)): return True return False - + + def analyze_device_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.device'}) + if not isinstance(element, str): + + if hasattr(element, "index"): + device_value = element.type + ":" + str(element.index) + single_arg.update({'value' : device_value}) + else: + device_value = element.type + else: + single_arg.update({'value' : element}) + return single_arg + + def analyze_dtype_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.dtype'}) + single_arg.update({'value' : str(element)}) + return single_arg def get_tensor_extremum(self, data, operator): if data.dtype is torch.bool: