From 1f8b0cbb50dbb3cfd3a7d318b41b6906322dc8bc Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Mon, 25 May 2026 00:32:37 -0400 Subject: [PATCH 1/2] finish1 --- python/tvm/tirx/script/__init__.py | 10 +- python/tvm/tirx/script/builder/ir.py | 75 ++++++++- python/tvm/tirx/script/builder/tirx.py | 127 +++++++++++---- python/tvm/tirx/stmt.py | 10 ++ src/tirx/ir/stmt.cc | 6 + src/tirx/script/builder/utils.h | 6 +- tests/python/tirx-base/test_tir_buffer.py | 181 ++++++++++++++++++++++ 7 files changed, 371 insertions(+), 44 deletions(-) diff --git a/python/tvm/tirx/script/__init__.py b/python/tvm/tirx/script/__init__.py index 57877f4e73b8..09191fe32e5a 100644 --- a/python/tvm/tirx/script/__init__.py +++ b/python/tvm/tirx/script/__init__.py @@ -62,14 +62,16 @@ def _fn(*args, workspace=None, config=None, dispatch=None, **kwargs): workspace = {} if config is None: config = kwargs or {} - # Convert Buffer args to BufferRegion (covers full extent) + # Convert buffer-like tile args to BufferRegion. from tvm.tirx import Buffer as _TBuffer + from tvm.tirx.expr import BufferLoad as _TBufferLoad + + from .builder.tirx import _to_region new_args = [] for a in args: - if isinstance(a, _TBuffer): - slices = [slice(None) for _ in range(len(a.shape))] - a = a[slices] + if isinstance(a, _TBuffer | _TBufferLoad): + a = _to_region(a) new_args.append(a) # Insert into the active frame using same FFI hook as registered ops. from .builder.tirx import f_insert as _f_insert diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 4547a864613e..de39ebd59e76 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -34,7 +34,7 @@ from tvm import DataType, ir from tvm import tirx as tir -from tvm.ir import Type +from tvm.ir import PointerType, PrimType, Type from tvm.ir import register_op_attr as _register_op_attr from tvm.ir.base import deprecated from tvm.runtime import convert @@ -1811,6 +1811,78 @@ def decl_buffer( return buf +def _infer_pointer_var_dtype_scope(var): + type_annotation = getattr(var, "type_annotation", None) + if isinstance(type_annotation, PointerType): + dtype = None + if isinstance(type_annotation.element_type, PrimType): + dtype = type_annotation.element_type.dtype + scope = type_annotation.storage_scope or None + return dtype, scope + return None, None + + +def buffer_from_ptr( + ptr, + shape, + dtype=None, + strides=None, + elem_offset=None, + byte_offset=None, + scope=None, + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + layout="default", +) -> Buffer: + """Create a buffer view from a pointer or buffer element. + + When ``ptr`` is a BufferLoad, this helper creates a pointer to the loaded + element with ``T.address_of``. The pointer expression is then bound to a + Var before it is used as ``T.decl_buffer(..., data=...)``, preserving the + invariant that ``Buffer.data`` is a Var. + """ + if isinstance(ptr, Buffer): + raise ValueError( + "buffer_from_ptr expects a pointer or BufferLoad; use T.address_of(buffer) " + "for an explicit base pointer" + ) + if isinstance(ptr, BufferLoad): + if dtype is None: + dtype = ptr.dtype + if scope is None: + scope = ptr.buffer.scope() + ptr = address_of(ptr) + elif isinstance(ptr, Var): + ptr_dtype, ptr_scope = _infer_pointer_var_dtype_scope(ptr) + if dtype is None: + dtype = ptr_dtype + if scope is None: + scope = ptr_scope + + if dtype is None: + raise ValueError("buffer_from_ptr requires dtype when ptr is not a BufferLoad or typed Var") + if scope is None: + scope = "global" + + data = ptr if isinstance(ptr, Var) else Bind(ptr, handle(dtype, scope)) + return decl_buffer( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + byte_offset=byte_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + layout=layout, + ) + + alloc_shared = functools.partial(alloc_buffer, scope="shared") alloc_local = functools.partial(alloc_buffer, scope="local") smem = alloc_shared @@ -3743,6 +3815,7 @@ def visit(ns_obj, dotted_prefix): "Then", "Else", "decl_buffer", + "buffer_from_ptr", "launch_thread", "env_thread", "buffer_store", diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py index efe79e1aa5bc..970181435bcf 100644 --- a/python/tvm/tirx/script/builder/tirx.py +++ b/python/tvm/tirx/script/builder/tirx.py @@ -22,7 +22,7 @@ import tvm.tirx.operator as tirx_op from tvm.ir import Op from tvm.tirx import Buffer, BufferRegion, PrimExpr -from tvm.tirx.expr import FloatImm +from tvm.tirx.expr import BufferLoad, FloatImm from tvm.tirx.lang.alloc_pool import SMEMPool, TMEMPool from tvm.tirx.predicate import Predicate @@ -34,9 +34,23 @@ def _is_buffer_or_region(x): return isinstance(x, Buffer | BufferRegion) -def _to_region(buffer: BufferRegion | Buffer): +def _is_buffer_region_or_load(x): + return isinstance(x, Buffer | BufferRegion | BufferLoad) + + +def _has_tile_call_options(workspace, dispatch, kwargs): + return ( + workspace is not None + or dispatch is not None + or any(key != "dtype" for key in kwargs.keys()) + ) + + +def _to_region(buffer: BufferRegion | Buffer | BufferLoad): + if isinstance(buffer, BufferLoad): + return BufferRegion.from_point(buffer.buffer, buffer.indices) if isinstance(buffer, Buffer): - return buffer[[slice(None, None, None) for _ in range(len(buffer.shape))]] + return BufferRegion.full_region(buffer) assert isinstance(buffer, BufferRegion) return buffer @@ -115,7 +129,13 @@ def sqrt( # Expression-form overload: ``sqrt(value)`` returns the underlying expression. from tvm import tirx as _tirx - if not _is_buffer_or_region(dst): + if ( + src is None + and bias is None + and scale is None + and not _has_tile_call_options(workspace, dispatch, kwargs) + and not _is_buffer_or_region(dst) + ): return _tirx.sqrt(dst) if src is None: src = dst @@ -124,7 +144,7 @@ def sqrt( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - if bias is not None and isinstance(bias, Buffer): + if bias is not None and isinstance(bias, Buffer | BufferLoad): bias = _to_region(bias) return f_insert( tirx_op.Sqrt(dst, src, bias, scale, workspace=workspace, config=config, dispatch=dispatch) @@ -159,9 +179,9 @@ def add( workspace = {} config = kwargs or {} dst = _to_region(dst) - if isinstance(src1, Buffer): + if isinstance(src1, Buffer | BufferLoad): src1 = _to_region(src1) - if isinstance(src2, Buffer): + if isinstance(src2, Buffer | BufferLoad): src2 = _to_region(src2) return f_insert( tirx_op.Add(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) @@ -196,9 +216,9 @@ def sub( workspace = {} config = kwargs or {} dst = _to_region(dst) - if isinstance(src1, Buffer): + if isinstance(src1, Buffer | BufferLoad): src1 = _to_region(src1) - if isinstance(src2, Buffer): + if isinstance(src2, Buffer | BufferLoad): src2 = _to_region(src2) return f_insert( tirx_op.Sub(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) @@ -233,9 +253,9 @@ def mul( workspace = {} config = kwargs or {} dst = _to_region(dst) - if isinstance(src1, Buffer): + if isinstance(src1, Buffer | BufferLoad): src1 = _to_region(src1) - if isinstance(src2, Buffer): + if isinstance(src2, Buffer | BufferLoad): src2 = _to_region(src2) return f_insert( tirx_op.Mul(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) @@ -271,7 +291,7 @@ def fdiv( config = kwargs or {} dst = _to_region(dst) src1 = _to_region(src1) - if isinstance(src2, Buffer): + if isinstance(src2, Buffer | BufferLoad): src2 = _to_region(src2) return f_insert( tirx_op.FDiv(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) @@ -311,9 +331,9 @@ def fma( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - if isinstance(scale, Buffer): + if isinstance(scale, Buffer | BufferLoad): scale = _to_region(scale) - if isinstance(bias, Buffer): + if isinstance(bias, Buffer | BufferLoad): bias = _to_region(bias) return f_insert( tirx_op.FMA(dst, src, scale, bias, workspace=workspace, config=config, dispatch=dispatch) @@ -623,7 +643,16 @@ def max( """ from tvm import tirx as _tirx - if not isinstance(dst, BufferRegion | Buffer) or not isinstance(src, BufferRegion | Buffer): + if not _is_buffer_region_or_load(dst) or not _is_buffer_region_or_load(src): + # Expression-level max + return _tirx.max(dst, src) + if ( + isinstance(dst, BufferLoad) + and isinstance(src, BufferLoad) + and axes == -1 + and not accum + and not _has_tile_call_options(workspace, dispatch, kwargs) + ): # Expression-level max return _tirx.max(dst, src) if workspace is None: @@ -653,7 +682,15 @@ def min( """ from tvm import tirx as _tirx - if not isinstance(dst, BufferRegion | Buffer) or not isinstance(src, BufferRegion | Buffer): + if not _is_buffer_region_or_load(dst) or not _is_buffer_region_or_load(src): + return _tirx.min(dst, src) + if ( + isinstance(dst, BufferLoad) + and isinstance(src, BufferLoad) + and axes == -1 + and not accum + and not _has_tile_call_options(workspace, dispatch, kwargs) + ): return _tirx.min(dst, src) if workspace is None: workspace = {} @@ -690,7 +727,11 @@ def reciprocal( # Expression-form overload: ``reciprocal(value)`` returns the underlying expression. from tvm import tirx as _tirx - if not _is_buffer_or_region(dst): + if ( + src is None + and not _has_tile_call_options(workspace, dispatch, kwargs) + and not _is_buffer_or_region(dst) + ): return _tirx.reciprocal(dst) if src is None: src = dst @@ -706,7 +747,7 @@ def reciprocal( def silu( dst: BufferRegion | Buffer, - src: BufferRegion | Buffer, + src: BufferRegion | Buffer | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, **kwargs, @@ -727,8 +768,14 @@ def silu( # Expression-form overload: ``silu(value)`` returns the underlying expression. from tvm import tirx as _tirx - if not _is_buffer_or_region(dst): + if ( + src is None + and not _has_tile_call_options(workspace, dispatch, kwargs) + and not _is_buffer_or_region(dst) + ): return _tirx.silu(dst) + if src is None: + src = dst if workspace is None: workspace = {} config = kwargs or {} @@ -794,9 +841,9 @@ def maximum( workspace = {} config = kwargs or {} dst = _to_region(dst) - if isinstance(src1, Buffer): + if isinstance(src1, Buffer | BufferLoad): src1 = _to_region(src1) - if isinstance(src2, Buffer): + if isinstance(src2, Buffer | BufferLoad): src2 = _to_region(src2) return f_insert( tirx_op.Maximum(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) @@ -831,9 +878,9 @@ def minimum( workspace = {} config = kwargs or {} dst = _to_region(dst) - if isinstance(src1, Buffer): + if isinstance(src1, Buffer | BufferLoad): src1 = _to_region(src1) - if isinstance(src2, Buffer): + if isinstance(src2, Buffer | BufferLoad): src2 = _to_region(src2) return f_insert( tirx_op.Minimum(dst, src1, src2, workspace=workspace, config=config, dispatch=dispatch) @@ -872,7 +919,13 @@ def exp( # Expression-form overload: ``exp(value)`` returns the underlying expression. from tvm import tirx as _tirx - if not _is_buffer_or_region(dst): + if ( + src is None + and bias is None + and scale is None + and not _has_tile_call_options(workspace, dispatch, kwargs) + and not _is_buffer_or_region(dst) + ): return _tirx.exp(dst) if src is None: src = dst @@ -881,7 +934,7 @@ def exp( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - if bias is not None and isinstance(bias, Buffer): + if bias is not None and isinstance(bias, Buffer | BufferLoad): bias = _to_region(bias) return f_insert( tirx_op.Exp(dst, src, bias, scale, workspace=workspace, config=config, dispatch=dispatch) @@ -920,7 +973,13 @@ def exp2( # Expression-form overload: ``exp2(value)`` returns the underlying expression. from tvm import tirx as _tirx - if not _is_buffer_or_region(dst): + if ( + src is None + and bias is None + and scale is None + and not _has_tile_call_options(workspace, dispatch, kwargs) + and not _is_buffer_or_region(dst) + ): return _tirx.exp2(dst) if src is None: src = dst @@ -929,7 +988,7 @@ def exp2( config = kwargs or {} dst = _to_region(dst) src = _to_region(src) - if bias is not None and isinstance(bias, Buffer): + if bias is not None and isinstance(bias, Buffer | BufferLoad): bias = _to_region(bias) return f_insert( tirx_op.Exp2(dst, src, bias, scale, workspace=workspace, config=config, dispatch=dispatch) @@ -1009,9 +1068,9 @@ def binary_reduce( workspace = {} binary_output = _to_region(binary_output) reduce_output = _to_region(reduce_output) - if isinstance(binary_input1, Buffer): + if isinstance(binary_input1, Buffer | BufferLoad): binary_input1 = _to_region(binary_input1) - if isinstance(binary_input2, Buffer): + if isinstance(binary_input2, Buffer | BufferLoad): binary_input2 = _to_region(binary_input2) reduce_axes = _wrap_elem_in_tuple(reduce_axes) @@ -1090,7 +1149,7 @@ def unary_reduce( reduce_output = _to_region(reduce_output) unary_input = _to_region(unary_input) - if bias is not None and isinstance(bias, Buffer): + if bias is not None and isinstance(bias, Buffer | BufferLoad): bias = _to_region(bias) reduce_axes = _wrap_elem_in_tuple(reduce_axes) @@ -1171,9 +1230,9 @@ def binary_chain( output = _to_region(output) data = _to_region(data) - if isinstance(operand0, Buffer): + if isinstance(operand0, Buffer | BufferLoad): operand0 = _to_region(operand0) - if isinstance(operand1, Buffer): + if isinstance(operand1, Buffer | BufferLoad): operand1 = _to_region(operand1) if isinstance(op0, str): @@ -1280,9 +1339,9 @@ def select( The predicate to evaluate. The callable should take the same number of arguments as the dimensions of the destination buffer. """ # noqa: E501 dst = _to_region(dst) - if isinstance(true_value, Buffer): + if isinstance(true_value, Buffer | BufferLoad): true_value = _to_region(true_value) - if isinstance(false_value, Buffer): + if isinstance(false_value, Buffer | BufferLoad): false_value = _to_region(false_value) if not isinstance(pred, Predicate): pred = Predicate(pred) diff --git a/python/tvm/tirx/stmt.py b/python/tvm/tirx/stmt.py index f1072bf25a07..06a40f25a51e 100644 --- a/python/tvm/tirx/stmt.py +++ b/python/tvm/tirx/stmt.py @@ -630,6 +630,16 @@ class BufferRegion(Object, Scriptable): def __init__(self, buffer: Buffer, region: list[Range]) -> None: self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) # type: ignore + @staticmethod + def full_region(buffer: Buffer) -> "BufferRegion": + """Create a BufferRegion covering the full buffer.""" + return _ffi_api.BufferRegionFullRegion(buffer) # type: ignore + + @staticmethod + def from_point(buffer: Buffer, indices: list[PrimExpr]) -> "BufferRegion": + """Create a single-point BufferRegion from buffer indices.""" + return _ffi_api.BufferRegionFromPoint(buffer, indices) # type: ignore + def __getitem__(self, indices): from ..arith import Analyzer diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index 194313592675..1158cf865618 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -522,6 +522,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("tirx.BufferRegion", [](Buffer buffer, ffi::Array region) { return BufferRegion(buffer, region); }); + refl::GlobalDef().def("tirx.BufferRegionFullRegion", + [](Buffer buffer) { return BufferRegion::FullRegion(buffer); }); + refl::GlobalDef().def("tirx.BufferRegionFromPoint", + [](Buffer buffer, ffi::Array indices) { + return BufferRegion::FromPoint(buffer, indices); + }); } // MatchBufferRegion diff --git a/src/tirx/script/builder/utils.h b/src/tirx/script/builder/utils.h index fc0293fbfca0..0c4667046953 100644 --- a/src/tirx/script/builder/utils.h +++ b/src/tirx/script/builder/utils.h @@ -142,11 +142,7 @@ inline IfFrame FindIfFrame(const ffi::String& method) { * \return The converted BufferRegion. */ inline tvm::tirx::BufferRegion BufferRegionFromLoad(tvm::tirx::BufferLoad buffer_load) { - ffi::Array ranges; - for (const PrimExpr& index : buffer_load->indices) { - ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1))); - } - return tvm::tirx::BufferRegion(buffer_load->buffer, ranges); + return tvm::tirx::BufferRegion::FromPoint(buffer_load->buffer, buffer_load->indices); } } // namespace tirx diff --git a/tests/python/tirx-base/test_tir_buffer.py b/tests/python/tirx-base/test_tir_buffer.py index bcdd0830a7f3..fc352096db39 100644 --- a/tests/python/tirx-base/test_tir_buffer.py +++ b/tests/python/tirx-base/test_tir_buffer.py @@ -86,6 +86,187 @@ def test_buffer_access_ptr_extent(): tvm.ir.assert_structural_equal(aptr.args[3], T.int32(100)) +def test_buffer_from_ptr_buffer_load(): + @T.prim_func(private=True, s_tir=True) + def actual(A: T.Buffer((32,), "float32")): + B = T.buffer_from_ptr(A[16], shape=(16,)) + B[0] = T.float32(1) + + @T.prim_func(private=True, s_tir=True) + def expected(A: T.Buffer((32,), "float32")): + B_data: T.let[T.handle("float32", "global")] = T.address_of(A[16]) + B = T.decl_buffer((16,), "float32", data=B_data) + B[0] = T.float32(1) + + tvm.ir.assert_structural_equal(actual, expected) + + +def test_buffer_from_ptr_raw_pointer(): + @T.prim_func(private=True, s_tir=True) + def actual(A_data: T.handle("float32")): + A = T.buffer_from_ptr(A_data, shape=(16,)) + A[0] = T.float32(1) + + @T.prim_func(private=True, s_tir=True) + def expected(A_data: T.handle("float32")): + A = T.decl_buffer((16,), "float32", data=A_data) + A[0] = T.float32(1) + + tvm.ir.assert_structural_equal(actual, expected) + + +def test_buffer_from_ptr_rejects_buffer_without_explicit_pointer(): + with pytest.raises(tvm.error.DiagnosticError): + + @T.prim_func(private=True, s_tir=True) + def _func(A: T.Buffer((16,), "float32")): + B = T.buffer_from_ptr(A, shape=(16,)) + B[0] = T.float32(1) + + +def test_tile_op_canonicalizes_buffer_load_to_region(): + @T.prim_func(s_tir=True) + def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + with T.kernel(): + T.copy(B[1], A[2]) + + op = func.body.body + dst = op.args[0] + src = op.args[1] + assert isinstance(dst, tvm.tirx.BufferRegion) + assert isinstance(src, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(dst.region[0].min, T.int32(1)) + tvm.ir.assert_structural_equal(dst.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(src.region[0].min, T.int32(2)) + tvm.ir.assert_structural_equal(src.region[0].extent, T.int32(1)) + + +def test_tile_op_canonicalizes_optional_buffer_load_operands_to_region(): + @T.prim_func(s_tir=True) + def func( + A: T.Buffer((4,), "float32"), + B: T.Buffer((4,), "float32"), + C: T.Buffer((4,), "float32"), + ): + with T.kernel(): + T.add(C[0], A[1], B[2]) + + op = func.body.body + dst, src1, src2 = op.args + assert isinstance(dst, tvm.tirx.BufferRegion) + assert isinstance(src1, tvm.tirx.BufferRegion) + assert isinstance(src2, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(dst.region[0].min, T.int32(0)) + tvm.ir.assert_structural_equal(dst.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(src1.region[0].min, T.int32(1)) + tvm.ir.assert_structural_equal(src1.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(src2.region[0].min, T.int32(2)) + tvm.ir.assert_structural_equal(src2.region[0].extent, T.int32(1)) + + +def test_overloaded_tile_op_canonicalizes_buffer_load_operands_to_region(): + @T.prim_func(s_tir=True) + def unary(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")): + with T.kernel(): + T.exp(C[0], A[1]) + + op = unary.body.body + dst, src = op.args[:2] + assert isinstance(dst, tvm.tirx.BufferRegion) + assert isinstance(src, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(dst.region[0].min, T.int32(0)) + tvm.ir.assert_structural_equal(dst.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(src.region[0].min, T.int32(1)) + tvm.ir.assert_structural_equal(src.region[0].extent, T.int32(1)) + + @T.prim_func(s_tir=True) + def reduction(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")): + with T.kernel(): + T.max(C, A[2]) + + op = reduction.body.body + dst, src = op.args[:2] + assert isinstance(dst, tvm.tirx.BufferRegion) + assert isinstance(src, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(dst.region[0].min, T.int32(0)) + tvm.ir.assert_structural_equal(dst.region[0].extent, T.int32(4)) + tvm.ir.assert_structural_equal(src.region[0].min, T.int32(2)) + tvm.ir.assert_structural_equal(src.region[0].extent, T.int32(1)) + + +def test_overloaded_tile_op_with_options_does_not_fall_back_to_expression(): + @T.prim_func(s_tir=True) + def unary(A: T.Buffer((4,), "float32"), Bias: T.Buffer((4,), "float32")): + with T.kernel(): + T.exp(A[0], bias=Bias[1]) + + op = unary.body.body + dst, src, bias = op.args[:3] + assert isinstance(dst, tvm.tirx.BufferRegion) + assert isinstance(src, tvm.tirx.BufferRegion) + assert isinstance(bias, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(dst.region[0].min, T.int32(0)) + tvm.ir.assert_structural_equal(dst.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(src.region[0].min, T.int32(0)) + tvm.ir.assert_structural_equal(src.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(bias.region[0].min, T.int32(1)) + tvm.ir.assert_structural_equal(bias.region[0].extent, T.int32(1)) + + @T.prim_func(s_tir=True) + def reduction(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + with T.kernel(): + T.max(A[0], B[1], axes=0) + + op = reduction.body.body + dst, src = op.args[:2] + assert isinstance(dst, tvm.tirx.BufferRegion) + assert isinstance(src, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(dst.region[0].min, T.int32(0)) + tvm.ir.assert_structural_equal(dst.region[0].extent, T.int32(1)) + tvm.ir.assert_structural_equal(src.region[0].min, T.int32(1)) + tvm.ir.assert_structural_equal(src.region[0].extent, T.int32(1)) + + +def test_ambiguous_buffer_load_max_remains_expression(): + @T.prim_func(private=True, s_tir=True) + def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + B[0] = T.max(A[1], B[2]) + + assert isinstance(func.body.value, tvm.tirx.PrimExpr) + + +def test_expression_dtype_kwarg_does_not_force_tile_op(): + @T.prim_func(private=True, s_tir=True) + def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): + B[0] = T.exp(A[1], dtype="float32") + T.max(A[1], B[2], dtype="float32") + + assert isinstance(func.body.value, tvm.tirx.PrimExpr) + + +def test_dynamic_tile_op_canonicalizes_buffer_load_to_region(): + @T.prim_func(s_tir=True) + def func(A: T.Buffer((4,), "float32")): + with T.kernel(): + T.test_buffer_load_region(A[2]) + + arg = func.body.body.args[0] + assert isinstance(arg, tvm.tirx.BufferRegion) + tvm.ir.assert_structural_equal(arg.region[0].min, T.int32(2)) + tvm.ir.assert_structural_equal(arg.region[0].extent, T.int32(1)) + + +def test_buffer_load_to_region_canonicalization_handles_ramp(): + @T.prim_func(s_tir=True) + def func(A: T.Buffer((16,), "float32")): + with T.sblock("read"): + T.reads(A[T.Ramp(4, 1, 4)]) + T.evaluate(0) + + read_region = func.body.block.reads[0] + tvm.ir.assert_structural_equal(read_region.region[0].min, T.int32(4)) + tvm.ir.assert_structural_equal(read_region.region[0].extent, T.int32(4)) + + def test_buffer_vload(): m = tvm.tirx.SizeVar("m", "int32") n = tvm.tirx.SizeVar("n", "int32") From 36a8b88eb8caee43a5f7ff1a795f7bef7a6cfcdd Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Mon, 25 May 2026 00:44:13 -0400 Subject: [PATCH 2/2] finish1 --- python/tvm/tirx/script/builder/tirx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/tirx/script/builder/tirx.py b/python/tvm/tirx/script/builder/tirx.py index 970181435bcf..df094379fad3 100644 --- a/python/tvm/tirx/script/builder/tirx.py +++ b/python/tvm/tirx/script/builder/tirx.py @@ -46,7 +46,7 @@ def _has_tile_call_options(workspace, dispatch, kwargs): ) -def _to_region(buffer: BufferRegion | Buffer | BufferLoad): +def _to_region(buffer: BufferRegion | Buffer | BufferLoad) -> BufferRegion: if isinstance(buffer, BufferLoad): return BufferRegion.from_point(buffer.buffer, buffer.indices) if isinstance(buffer, Buffer): @@ -746,8 +746,8 @@ def reciprocal( def silu( - dst: BufferRegion | Buffer, - src: BufferRegion | Buffer | None = None, + dst: BufferRegion | Buffer | BufferLoad, + src: BufferRegion | Buffer | BufferLoad | None = None, workspace: dict[str, Buffer] | None = None, dispatch: str | None = None, **kwargs,