Description
LegalizeOps lowering of relax.nn.batch_norm computes mean and variance from the input data (training mode) instead of using the provided running_mean and running_var parameters (inference mode). This produces incorrect results when batch_norm is used for inference.
When DecomposeOpsForInference is applied before LegalizeOps, the batch_norm is correctly decomposed using the provided running statistics.
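For reference, the two modes differ only in where the per-channel statistics come from. Assuming the conventional batch normalization definitions (biased batch variance in training mode), with channel index c on axis 1:

```latex
% Inference mode: normalize with the provided running statistics
y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mathrm{running\_mean}_c}{\sqrt{\mathrm{running\_var}_c + \epsilon}} + \beta_c

% Training mode: statistics are computed from the input batch itself
\mu_c = \frac{1}{NHW} \sum_{n,h,w} x_{n,c,h,w}, \qquad
\sigma_c^2 = \frac{1}{NHW} \sum_{n,h,w} \left( x_{n,c,h,w} - \mu_c \right)^2

y_{n,c,h,w} = \gamma_c \, \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c
```

With the reproducer's all-ones input, the training-mode statistics are \mu_c = 1 and \sigma_c^2 = 0 for every channel, so the normalized term is 0 and y = \beta_c, which is why channel 3 comes out as 0.5.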
Reproducer
import numpy as np
import tvm
from tvm import relax
import tvm.relax.op as R
N, C, HW = 1, 4, 4
bb = relax.BlockBuilder()
x = relax.Var('x', relax.TensorStructInfo((N, C, HW, HW), 'float32'))
gamma = relax.Var('gamma', relax.TensorStructInfo((C,), 'float32'))
beta = relax.Var('beta', relax.TensorStructInfo((C,), 'float32'))
mean = relax.Var('mean', relax.TensorStructInfo((C,), 'float32'))
var = relax.Var('var', relax.TensorStructInfo((C,), 'float32'))
with bb.function('main', [x, gamma, beta, mean, var]):
    with bb.dataflow():
        bn = bb.emit(R.nn.batch_norm(x, gamma, beta, mean, var, axis=1, epsilon=1e-5))
        out = bb.emit_output(relax.TupleGetItem(bn, 0))
    bb.emit_func_output(out)
mod = bb.finalize()
x_np = np.ones((N, C, HW, HW), dtype=np.float32)
gamma_np = np.array([1.0, 2.0, 0.5, 3.0], dtype=np.float32)
beta_np = np.array([0.0, 1.0, -1.0, 0.5], dtype=np.float32)
mean_np = np.array([0.0, 0.5, 1.0, -1.0], dtype=np.float32)
var_np = np.array([1.0, 0.25, 4.0, 0.01], dtype=np.float32)
inputs = [x_np, gamma_np, beta_np, mean_np, var_np]
def run(mod, passes, inputs):
    # Apply the given passes, build for CPU, and run "main" on the NumPy inputs.
    pipeline = tvm.ir.transform.Sequential(passes)
    mod_l = pipeline(mod)
    exe = tvm.relax.build(mod_l, target='llvm')
    vm = tvm.relax.VirtualMachine(exe, device=tvm.cpu())
    tvm_inputs = [tvm.runtime.tensor(a, device=tvm.cpu()) for a in inputs]
    return vm['main'](*tvm_inputs).numpy()
out_legalize = run(mod, [relax.transform.LegalizeOps()], inputs)
out_correct = run(mod, [relax.transform.DecomposeOpsForInference(),
                        relax.transform.LegalizeOps()], inputs)
# Expected for channel 3: (1.0 - (-1.0)) / sqrt(0.01 + 1e-5) * 3.0 + 0.5 = 60.47
# LegalizeOps instead normalizes with batch statistics: x is all ones, so the
# batch mean is 1 and the batch variance is 0, and the output collapses to beta.
# Channel 3: legalize=0.5 (wrong), correct=60.47
print(f"LegalizeOps: {out_legalize[0, 3, 0, 0]:.4f}")  # 0.5000 (WRONG)
print(f"Correct: {out_correct[0, 3, 0, 0]:.4f}")  # 60.4700
Expected behavior
LegalizeOps should lower batch_norm using the provided running_mean and running_var (inference mode), producing the same result as DecomposeOpsForInference.
Actual behavior
LegalizeOps computes the mean and variance from the input tensor (training-mode batch normalization), producing incorrect results. The maximum absolute error against the expected inference output is 59.97, at channel 3.
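The per-channel gap can be checked without TVM. The following is a minimal NumPy sketch, assuming NCHW layout with the channel axis at 1 and a biased (ddof=0) batch variance for the training-mode path; bn_inference and bn_training are hypothetical helper names, not TVM functions.

```python
import numpy as np

# Hypothetical helpers (not TVM APIs). Layout is NCHW and the channel axis is 1,
# matching axis=1 in the reproducer.
BCAST = (1, -1, 1, 1)  # broadcast per-channel vectors over N, H, W


def bn_inference(x, gamma, beta, mean, var, eps=1e-5):
    """Normalize with the provided running statistics (inference mode)."""
    m, v = mean.reshape(BCAST), var.reshape(BCAST)
    return (x - m) / np.sqrt(v + eps) * gamma.reshape(BCAST) + beta.reshape(BCAST)


def bn_training(x, gamma, beta, eps=1e-5):
    """Normalize with statistics computed from x itself (training mode)."""
    m = x.mean(axis=(0, 2, 3), keepdims=True)
    v = x.var(axis=(0, 2, 3), keepdims=True)  # biased variance (ddof=0)
    return (x - m) / np.sqrt(v + eps) * gamma.reshape(BCAST) + beta.reshape(BCAST)


# Same inputs as the reproducer.
N, C, HW = 1, 4, 4
x = np.ones((N, C, HW, HW), dtype=np.float32)
gamma = np.array([1.0, 2.0, 0.5, 3.0], dtype=np.float32)
beta = np.array([0.0, 1.0, -1.0, 0.5], dtype=np.float32)
mean = np.array([0.0, 0.5, 1.0, -1.0], dtype=np.float32)
var = np.array([1.0, 0.25, 4.0, 0.01], dtype=np.float32)

expected = bn_inference(x, gamma, beta, mean, var)
observed = bn_training(x, gamma, beta)

# x is constant within each channel, so one spatial position per channel suffices.
err = np.abs(expected - observed)[0, :, 0, 0]
for c in range(C):
    print(f"channel {c}: inference={expected[0, c, 0, 0]:.4f} "
          f"training={observed[0, c, 0, 0]:.4f} abs_err={err[c]:.4f}")
print(f"max abs error: {err.max():.4f} (channel {int(err.argmax())})")
```

The last line reports a maximum absolute error of 59.9700 at channel 3. The training-mode output equals beta in every channel, since the all-ones input has zero batch variance.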
Environment
- TVM version: 0.24.dev0 (commit 0b0afd8, 2026-04-24)
- Target: llvm (CPU)