Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions backends/qualcomm/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,26 +462,29 @@ def _replace_inf(self, graph_module: GraphModule) -> GraphModule:
node.args = tuple(arg_list)
elif node.op == "get_attr":
constant_tensor = attrgetter(node.target)(graph_module)
if (
if not (
torch.is_tensor(constant_tensor)
and constant_tensor.is_floating_point()
and torch.isinf(constant_tensor).any()
):
# Anything smaller than float16.min, which covers float32.min and float(-inf)
min_value = torch.finfo(torch.float16).min
# Anything larger than float16.max, which covers float32.max and float(inf)
max_value = torch.finfo(torch.float16).max

quant_min, quant_max = float("inf"), float("-inf")
for source_node in node.users:
if quant_range := self._get_quant_range(source_node):
quant_min = min(quant_min, -quant_range)
quant_max = max(quant_max, quant_range)

if quant_min != float("inf") and quant_max != float("-inf"):
# Inplace update
with torch.no_grad():
constant_tensor[constant_tensor <= min_value] = quant_min
constant_tensor[constant_tensor >= max_value] = quant_max
continue

# Anything smaller than float16.min, which covers float32.min and float(-inf)
min_value = torch.finfo(torch.float16).min
# Anything larger than float16.max, which covers float32.max and float(inf)
max_value = torch.finfo(torch.float16).max

quant_min, quant_max = float("inf"), float("-inf")
for source_node in node.users:
if quant_range := self._get_quant_range(source_node):
quant_min = min(quant_min, -quant_range)
quant_max = max(quant_max, quant_range)

if quant_min != float("inf") and quant_max != float("-inf"):
# Inplace update
with torch.no_grad():
constant_tensor[constant_tensor <= min_value] = quant_min
constant_tensor[constant_tensor >= max_value] = quant_max

graph_module.recompile()

Expand Down
Loading