diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index c6fbc51484f..f69e3019114 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -462,26 +462,29 @@ def _replace_inf(self, graph_module: GraphModule) -> GraphModule: node.args = tuple(arg_list) elif node.op == "get_attr": constant_tensor = attrgetter(node.target)(graph_module) - if ( + if not ( torch.is_tensor(constant_tensor) and constant_tensor.is_floating_point() + and torch.isinf(constant_tensor).any() ): - # Anything smaller than float16.min, which covers float32.min and float(-inf) - min_value = torch.finfo(torch.float16).min - # Anything larger than float16.max, which covers float32.max and float(inf) - max_value = torch.finfo(torch.float16).max - - quant_min, quant_max = float("inf"), float("-inf") - for source_node in node.users: - if quant_range := self._get_quant_range(source_node): - quant_min = min(quant_min, -quant_range) - quant_max = max(quant_max, quant_range) - - if quant_min != float("inf") and quant_max != float("-inf"): - # Inplace update - with torch.no_grad(): - constant_tensor[constant_tensor <= min_value] = quant_min - constant_tensor[constant_tensor >= max_value] = quant_max + continue + + # Anything smaller than float16.min, which covers float32.min and float(-inf) + min_value = torch.finfo(torch.float16).min + # Anything larger than float16.max, which covers float32.max and float(inf) + max_value = torch.finfo(torch.float16).max + + quant_min, quant_max = float("inf"), float("-inf") + for source_node in node.users: + if quant_range := self._get_quant_range(source_node): + quant_min = min(quant_min, -quant_range) + quant_max = max(quant_max, quant_range) + + if quant_min != float("inf") and quant_max != float("-inf"): + # Inplace update + with torch.no_grad(): + constant_tensor[constant_tensor <= min_value] = quant_min + constant_tensor[constant_tensor >= max_value] = quant_max graph_module.recompile()