Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 16 additions & 38 deletions python/tvm/topi/gpu/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,25 +142,15 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
),
),
T.attr(by, "thread_extent", nthread_by),
T.allocate([1], "int32", scope="local"),
T.allocate([1], "int32", scope="local"),
T.allocate([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
]
) as (_, _, _, start_ptr, middle_ptr, end_ptr):
) as (_, _, _, start_buf, middle_buf, end_buf):
tid = bx * nthread_tx + tx
start = T.buffer_proxy(
tvm.tir.decl_buffer(
[1], "int32", "start", data=start_ptr, scope="local"
)
)
middle = T.buffer_proxy(
tvm.tir.decl_buffer(
[1], "int32", "middle", data=middle_ptr, scope="local"
)
)
end = T.buffer_proxy(
tvm.tir.decl_buffer([1], "int32", "end", data=end_ptr, scope="local")
)
start = T.buffer_proxy(start_buf)
middle = T.buffer_proxy(middle_buf)
end = T.buffer_proxy(end_buf)
start[0] = width * tid
with T.If(start[0] < scan_axis_size):
with T.Then():
Expand Down Expand Up @@ -199,29 +189,17 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
),
),
T.attr(by, "thread_extent", nthread_by),
T.allocate([1], "int32", scope="local"),
T.allocate([1], "int32", scope="local"),
T.allocate([1], "int32", scope="local"),
T.allocate([1], out_dtype, scope="local"),
T.decl_buffer([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
T.decl_buffer([1], out_dtype, scope="local"),
]
) as (_, _, _, start_ptr, middle_ptr, end_ptr, tmp_ptr):
) as (_, _, _, start_buf, middle_buf, end_buf, tmp_buf):
tid = bx * nthread_tx + tx
start = T.buffer_proxy(
tvm.tir.decl_buffer(
[1], "int32", "start", data=start_ptr, scope="local"
)
)
middle = T.buffer_proxy(
tvm.tir.decl_buffer(
[1], "int32", "middle", data=middle_ptr, scope="local"
)
)
end = T.buffer_proxy(
tvm.tir.decl_buffer([1], "int32", "end", data=end_ptr, scope="local")
)
tmp = T.buffer_proxy(
tvm.tir.decl_buffer([1], out_dtype, "tmp", data=tmp_ptr, scope="local")
)
start = T.buffer_proxy(start_buf)
middle = T.buffer_proxy(middle_buf)
end = T.buffer_proxy(end_buf)
tmp = T.buffer_proxy(tmp_buf)
start[0] = width * tid
with T.If(tvm.tir.all(start[0] < scan_axis_size)):
with T.Then():
Expand Down
147 changes: 44 additions & 103 deletions python/tvm/topi/gpu/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,71 +110,36 @@ def _odd_even_sort(
tid = 2 * tx
start = bx * block_size

# Build list of allocations
alloc_frames = [
T.allocate([block_size], keys_swap.dtype, scope="shared"), # tmp_keys_swap
T.allocate([1], keys_swap.dtype, scope="local"), # temp_keys
T.allocate([1], keys_swap.dtype, scope="local"), # temp_cond1
T.allocate([1], keys_swap.dtype, scope="local"), # temp_cond2
# Build list of buffer declarations (DeclBuffer generates both Allocate + DeclBuffer nodes)
decl_frames = [
T.decl_buffer([block_size], keys_swap.dtype, scope="shared"), # tmp_keys_swap
T.decl_buffer([1], keys_swap.dtype, scope="local"), # temp_keys
T.decl_buffer([1], keys_swap.dtype, scope="local"), # temp_cond1
T.decl_buffer([1], keys_swap.dtype, scope="local"), # temp_cond2
]
if values_swap is not None:
alloc_frames.append(
T.allocate([block_size], values_swap.dtype, scope="shared")
decl_frames.append(
T.decl_buffer([block_size], values_swap.dtype, scope="shared")
) # tmp_values_swap
alloc_frames.append(T.allocate([1], values_swap.dtype, scope="local")) # temp_values
decl_frames.append(T.decl_buffer([1], values_swap.dtype, scope="local")) # temp_values

with T.frame_scope(alloc_frames) as allocs:
with T.frame_scope(decl_frames) as bufs:
if values_swap is not None:
(
tmp_keys_swap_ptr,
temp_keys_ptr,
temp_cond1_ptr,
temp_cond2_ptr,
tmp_values_swap_ptr,
temp_values_ptr,
) = allocs
tmp_keys_swap,
temp_keys,
temp_cond1,
temp_cond2,
tmp_values_swap,
temp_values,
) = bufs
else:
(
tmp_keys_swap_ptr,
temp_keys_ptr,
temp_cond1_ptr,
temp_cond2_ptr,
) = allocs
tmp_values_swap_ptr = None
temp_values_ptr = None

# Create buffer views
tmp_keys_swap = tvm.tir.decl_buffer(
[block_size],
keys_swap.dtype,
"tmp_keys_swap",
data=tmp_keys_swap_ptr,
scope="shared",
)
temp_keys = tvm.tir.decl_buffer(
[1], keys_swap.dtype, "temp_keys", data=temp_keys_ptr, scope="local"
)
temp_cond1 = tvm.tir.decl_buffer(
[1], keys_swap.dtype, "temp_cond1", data=temp_cond1_ptr, scope="local"
)
temp_cond2 = tvm.tir.decl_buffer(
[1], keys_swap.dtype, "temp_cond2", data=temp_cond2_ptr, scope="local"
)
if values_swap is not None:
tmp_values_swap = tvm.tir.decl_buffer(
[block_size],
values_swap.dtype,
"tmp_values_swap",
data=tmp_values_swap_ptr,
scope="shared",
)
temp_values = tvm.tir.decl_buffer(
[1],
values_swap.dtype,
"temp_values",
data=temp_values_ptr,
scope="local",
)
tmp_keys_swap,
temp_keys,
temp_cond1,
temp_cond2,
) = bufs

# Copy data to scratch space
base_idx = by_val * size * axis_mul_after + bz
Expand Down Expand Up @@ -386,24 +351,16 @@ def mergepath(
):
with T.frame_scope(
[
T.allocate([1], target_dtype, scope="local"), # first
T.allocate([1], target_dtype, scope="local"), # last
T.allocate([1], target_dtype, scope="local"), # i_buf
T.allocate([1], target_dtype, scope="local"), # j_buf
T.decl_buffer([1], target_dtype, scope="local"), # first
T.decl_buffer([1], target_dtype, scope="local"), # last
T.decl_buffer([1], target_dtype, scope="local"), # i_buf
T.decl_buffer([1], target_dtype, scope="local"), # j_buf
]
) as (first_ptr, last_ptr, i_ptr, j_ptr):
first = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "first", data=first_ptr, scope="local")
)
last = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "last", data=last_ptr, scope="local")
)
i_buf = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "i", data=i_ptr, scope="local")
)
j_buf = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "j", data=j_ptr, scope="local")
)
) as (first_buf, last_buf, i_buf_buf, j_buf_buf):
first = T.buffer_proxy(first_buf)
last = T.buffer_proxy(last_buf)
i_buf = T.buffer_proxy(i_buf_buf)
j_buf = T.buffer_proxy(j_buf_buf)

diag = tx * step_count
with T.If(even):
Expand Down Expand Up @@ -469,36 +426,20 @@ def dual_mergepath(
):
with T.frame_scope(
[
T.allocate([1], target_dtype, scope="local"), # outer_first
T.allocate([1], target_dtype, scope="local"), # outer_last
T.allocate([1], target_dtype, scope="local"), # first
T.allocate([1], target_dtype, scope="local"), # last
T.allocate([1], target_dtype, scope="local"), # i_buf
T.allocate([1], target_dtype, scope="local"), # j_buf
T.decl_buffer([1], target_dtype, scope="local"), # outer_first
T.decl_buffer([1], target_dtype, scope="local"), # outer_last
T.decl_buffer([1], target_dtype, scope="local"), # first
T.decl_buffer([1], target_dtype, scope="local"), # last
T.decl_buffer([1], target_dtype, scope="local"), # i_buf
T.decl_buffer([1], target_dtype, scope="local"), # j_buf
]
) as (outer_first_ptr, outer_last_ptr, first_ptr, last_ptr, i_ptr, j_ptr):
outer_first = T.buffer_proxy(
tvm.tir.decl_buffer(
[1], target_dtype, "outer_first", data=outer_first_ptr, scope="local"
)
)
outer_last = T.buffer_proxy(
tvm.tir.decl_buffer(
[1], target_dtype, "outer_last", data=outer_last_ptr, scope="local"
)
)
first = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "first", data=first_ptr, scope="local")
)
last = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "last", data=last_ptr, scope="local")
)
i_buf = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "i", data=i_ptr, scope="local")
)
j_buf = T.buffer_proxy(
tvm.tir.decl_buffer([1], target_dtype, "j", data=j_ptr, scope="local")
)
) as (outer_first_buf, outer_last_buf, first_buf, last_buf, i_buf_buf, j_buf_buf):
outer_first = T.buffer_proxy(outer_first_buf)
outer_last = T.buffer_proxy(outer_last_buf)
first = T.buffer_proxy(first_buf)
last = T.buffer_proxy(last_buf)
i_buf = T.buffer_proxy(i_buf_buf)
j_buf = T.buffer_proxy(j_buf_buf)

diag = bx * step_count
with T.If(even):
Expand Down
11 changes: 5 additions & 6 deletions python/tvm/topi/searchsorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# pylint: disable=invalid-name
"""searchsorted operator"""

import tvm
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T

Expand All @@ -41,12 +40,12 @@ def binary_search(sequence_offset, search_range, sorted_sequence, value, right,
"""
with T.frame_scope(
[
T.allocate([1], out_dtype, scope="local"),
T.allocate([1], out_dtype, scope="local"),
T.decl_buffer([1], out_dtype, scope="local"),
T.decl_buffer([1], out_dtype, scope="local"),
]
) as (lo_ptr, hi_ptr):
lo = T.buffer_proxy(tvm.tir.decl_buffer([1], out_dtype, "lo", data=lo_ptr, scope="local"))
hi = T.buffer_proxy(tvm.tir.decl_buffer([1], out_dtype, "hi", data=hi_ptr, scope="local"))
) as (lo_buf, hi_buf):
lo = T.buffer_proxy(lo_buf)
hi = T.buffer_proxy(hi_buf)

lo[0] = cast(0, out_dtype)
hi[0] = cast(search_range, out_dtype)
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/topi/vision/nms_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def binary_search(y, num_boxes, scores, score_threshold, out):
out = T.buffer_proxy(out)
with T.frame_scope(
[
T.allocate([1], "int32", scope="local"),
T.allocate([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
T.decl_buffer([1], "int32", scope="local"),
]
) as (lo_ptr, hi_ptr):
lo = T.buffer_proxy(tvm.tir.decl_buffer([1], "int32", "lo", data=lo_ptr, scope="local"))
hi = T.buffer_proxy(tvm.tir.decl_buffer([1], "int32", "hi", data=hi_ptr, scope="local"))
) as (lo_buf, hi_buf):
lo = T.buffer_proxy(lo_buf)
hi = T.buffer_proxy(hi_buf)
lo[0] = T.int32(0)
hi[0] = tvm.tir.Cast("int32", num_boxes)
with T.While(lo[0] < hi[0]):
Expand Down
21 changes: 20 additions & 1 deletion src/s_tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,14 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) {
if (buffer_region->buffer.same_as(old_buffer_)) {
TVM_FFI_ICHECK(infered_access_regions.size() == 1);
return infered_access_regions[0];
BufferRegion result = infered_access_regions[0];
// The inferred region may reference old_buffer_ (e.g. when resolved
// through match_buffer source). Ensure we use new_buffer_ instead.
if (result->buffer.same_as(old_buffer_)) {
auto* n = result.CopyOnWrite();
n->buffer = new_buffer_;
}
return result;
}
return buffer_region;
};
Expand All @@ -887,6 +894,18 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
auto* n = block.CopyOnWrite();
RewriteAccessRegion(&n->reads, infered_access_regions[0]);
RewriteAccessRegion(&n->writes, infered_access_regions[1]);
// Update match_buffers whose source references old_buffer_
n->match_buffers.MutateByApply([this](const MatchBufferRegion& match_buf) {
if (match_buf->source->buffer.same_as(old_buffer_)) {
auto new_source = match_buf->source;
auto* source_n = new_source.CopyOnWrite();
source_n->buffer = new_buffer_;
auto new_match = match_buf;
new_match.CopyOnWrite()->source = new_source;
return new_match;
Comment thread
tqchen marked this conversation as resolved.
}
return match_buf;
});
n->alloc_buffers.MutateByApply([this](const Buffer& buffer) {
if (buffer.same_as(old_buffer_)) {
return new_buffer_;
Expand Down
13 changes: 13 additions & 0 deletions src/s_tir/transform/renew_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@ class RenewDefMutator : public StmtExprMutator {
STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var);
STMT_REGENERATE_VAR_DEF(ForNode, loop_var);

Stmt VisitStmt_(const DeclBufferNode* op) final {
Buffer new_buffer = VisitBuffer(op->buffer, /*define=*/true);
Stmt body = this->VisitStmt(op->body);
if (new_buffer.same_as(op->buffer) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = ffi::make_object<DeclBufferNode>(*op);
n->buffer = std::move(new_buffer);
n->body = std::move(body);
return Stmt(n);
}
}

Stmt VisitStmt_(const SBlockNode* op) final {
// Step 0. Re-define Itervars
ffi::Array<IterVar> iter_vars =
Expand Down
Loading
Loading