[Bug] LegalizeOps batch_norm uses training-mode computation, ignoring provided running stats #19574

@wuyii8941

Description

The LegalizeOps lowering of relax.nn.batch_norm computes the mean and variance from the input data (training mode) instead of using the provided running_mean and running_var arguments (inference mode). This produces incorrect results whenever batch_norm is used for inference.

By contrast, when DecomposeOpsForInference is applied before LegalizeOps, batch_norm is correctly decomposed using the provided running statistics.
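
To make the difference between the two modes concrete, here is a minimal NumPy sketch that is independent of TVM. The helper names `batch_norm_inference`, `batch_norm_training`, and `_channel_shape` are hypothetical and exist only for illustration. The sketch assumes a non-negative channel axis and the biased (population) variance in training mode; it is not a copy of either TVM code path.

```python
import numpy as np


def _channel_shape(ndim, axis):
    """Shape that broadcasts a per-channel vector along `axis` (hypothetical helper)."""
    shape = [1] * ndim
    shape[axis] = -1
    return shape


def batch_norm_inference(x, gamma, beta, running_mean, running_var,
                         axis=1, epsilon=1e-5):
    """Inference mode: normalize with the provided running statistics."""
    shape = _channel_shape(x.ndim, axis)
    mean = running_mean.reshape(shape)
    var = running_var.reshape(shape)
    scale = gamma.reshape(shape)
    shift = beta.reshape(shape)
    return (x - mean) / np.sqrt(var + epsilon) * scale + shift


def batch_norm_training(x, gamma, beta, axis=1, epsilon=1e-5):
    """Training mode: normalize with statistics computed from x itself.

    Assumption: biased variance over every axis except the channel axis.
    """
    reduce_axes = tuple(i for i in range(x.ndim) if i != axis)
    mean = x.mean(axis=reduce_axes, keepdims=True)
    var = x.var(axis=reduce_axes, keepdims=True)
    shape = _channel_shape(x.ndim, axis)
    scale = gamma.reshape(shape)
    shift = beta.reshape(shape)
    return (x - mean) / np.sqrt(var + epsilon) * scale + shift
```

With an all-ones input, as in the reproducer below, the training-mode path sees a batch mean of 1 and a batch variance of 0, so every normalized value is 0 and the output collapses to beta. The inference-mode path still depends on running_mean and running_var, which is why the two results diverge.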

Reproducer

import numpy as np
import tvm
from tvm import relax
import tvm.relax.op as R

N, C, HW = 1, 4, 4

bb = relax.BlockBuilder()
x = relax.Var('x', relax.TensorStructInfo((N, C, HW, HW), 'float32'))
gamma = relax.Var('gamma', relax.TensorStructInfo((C,), 'float32'))
beta = relax.Var('beta', relax.TensorStructInfo((C,), 'float32'))
mean = relax.Var('mean', relax.TensorStructInfo((C,), 'float32'))
var = relax.Var('var', relax.TensorStructInfo((C,), 'float32'))
with bb.function('main', [x, gamma, beta, mean, var]):
    with bb.dataflow():
        bn = bb.emit(R.nn.batch_norm(x, gamma, beta, mean, var, axis=1, epsilon=1e-5))
        out = bb.emit_output(relax.TupleGetItem(bn, 0))
    bb.emit_func_output(out)
mod = bb.finalize()

x_np = np.ones((N, C, HW, HW), dtype=np.float32)
gamma_np = np.array([1.0, 2.0, 0.5, 3.0], dtype=np.float32)
beta_np = np.array([0.0, 1.0, -1.0, 0.5], dtype=np.float32)
mean_np = np.array([0.0, 0.5, 1.0, -1.0], dtype=np.float32)
var_np = np.array([1.0, 0.25, 4.0, 0.01], dtype=np.float32)
inputs = [x_np, gamma_np, beta_np, mean_np, var_np]

def run(mod, passes, inputs):
    pipeline = tvm.ir.transform.Sequential(passes)
    mod_l = pipeline(mod)
    exe = tvm.relax.build(mod_l, target='llvm')
    vm = tvm.relax.VirtualMachine(exe, device=tvm.cpu())
    tvm_inputs = [tvm.runtime.tensor(x, device=tvm.cpu()) for x in inputs]
    return vm['main'](*tvm_inputs).numpy()

out_legalize = run(mod, [relax.transform.LegalizeOps()], inputs)
out_correct = run(mod, [relax.transform.DecomposeOpsForInference(),
                         relax.transform.LegalizeOps()], inputs)

# Expected: (1.0 - (-1.0)) / sqrt(0.01 + 1e-5) * 3.0 + 0.5 = 60.47
# Channel 3: legalize=0.5 (wrong), correct=60.47
print(f"LegalizeOps:  {out_legalize[0, 3, 0, 0]:.4f}")   # 0.5000 (WRONG)
print(f"Correct:      {out_correct[0, 3, 0, 0]:.4f}")     # 60.4700

Expected behavior

LegalizeOps batch_norm should use the provided running_mean and running_var for inference, producing the same result as DecomposeOpsForInference.

Actual behavior

LegalizeOps computes the mean and variance from the input tensor (training-mode batch normalization), producing incorrect results. The maximum absolute error against the DecomposeOpsForInference result is 59.97 (channel 3: 0.5 instead of 60.47).
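
The per-channel error can be checked by hand with the reproducer's inputs. The snippet below is a rough sketch in plain NumPy, not TVM output: `expected` applies the running statistics, while `observed` models the training-mode result under the assumption that an all-ones input has batch mean 1 and batch variance 0. All variable names are illustrative.

```python
import numpy as np

eps = 1e-5
x = 1.0  # every element of the reproducer's input tensor is 1.0

gamma = np.array([1.0, 2.0, 0.5, 3.0])
beta = np.array([0.0, 1.0, -1.0, 0.5])
running_mean = np.array([0.0, 0.5, 1.0, -1.0])
running_var = np.array([1.0, 0.25, 4.0, 0.01])

# Inference mode: use the provided running statistics.
expected = (x - running_mean) / np.sqrt(running_var + eps) * gamma + beta

# Training mode (assumed model): batch mean = 1.0, batch variance = 0.0.
observed = (x - 1.0) / np.sqrt(0.0 + eps) * gamma + beta

abs_err = np.abs(expected - observed)
for c in range(len(gamma)):
    print(f"channel {c}: expected={expected[c]:.4f} "
          f"observed={observed[c]:.4f} abs_err={abs_err[c]:.4f}")
print(f"max abs error: {abs_err.max():.2f}")
```

Under these assumptions channel 3 dominates with an error of about 59.97, matching the figure reported above. Channel 2 shows no error because its running mean equals the input value, so both modes normalize to 0 there.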

Environment

  • TVM version: 0.24.dev0 (commit 0b0afd8, 2026-04-24)
  • Target: llvm (CPU)

Labels: needs-triage, type: bug