Skip to content

Commit 0002c69

Browse files
committed
Refactor VisitExpr_
1 parent 25bbab1 commit 0002c69

File tree

1 file changed

+139
-97
lines changed

1 file changed

+139
-97
lines changed

src/relax/transform/canonicalize_shape_expr.cc

Lines changed: 139 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -62,127 +62,129 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) {
6262
* \brief Mutator to canonicalize ShapeExpr in struct info
6363
*
6464
* This pass handles ShapeExpr canonicalization by:
65-
* 1. Detecting compound PrimExpr in ShapeExpr dimensions
66-
* 2. Lifting them into separate ShapeExpr bindings
65+
* 1. Detecting compound PrimExpr in variable struct_info
66+
* 2. Emitting ShapeExpr bindings to compute expressions
6767
* 3. Using MatchCast to extract values into fresh symbolic tir::Var
68-
* 4. Replacing compound expressions with these canonical vars
68+
* 4. Replacing compound expressions with these canonical vars in struct_info
6969
*/
7070
class ShapeExprCanonicalizer : public ExprMutator {
7171
public:
7272
using ExprMutator::VisitExpr_;
7373

7474
Expr VisitExpr_(const FunctionNode* func) override {
7575
// Reset state for each function
76-
auto cached_compound_to_var = compound_expr_to_var_;
77-
auto cached_counter = symbolic_var_counter_;
78-
79-
auto result = ExprMutator::VisitExpr_(func);
80-
81-
compound_expr_to_var_ = cached_compound_to_var;
82-
symbolic_var_counter_ = cached_counter;
83-
84-
return result;
85-
}
86-
87-
/*!
88-
* \brief Override VisitVarDef to canonicalize struct_info
89-
*
90-
* This is where we intercept variable definitions and canonicalize any
91-
* compound PrimExpr in their TensorStructInfo shapes.
92-
*/
93-
Var VisitVarDef(const Var& var) override {
94-
auto sinfo = GetStructInfo(var);
95-
96-
// Check if we need to canonicalize the struct_info
97-
auto canonical_sinfo = CanonicalizeStructInfo(sinfo);
98-
99-
if (canonical_sinfo.same_as(sinfo)) {
100-
// No changes needed
101-
return ExprMutator::VisitVarDef(var);
76+
symbolic_var_counter_ = 0;
77+
compound_expr_to_var_.clear();
78+
emitted_bindings_.clear();
79+
80+
// Visit params to populate var_remap_
81+
ffi::Array<Var> params;
82+
bool all_params_unchanged = true;
83+
for (Var param : func->params) {
84+
Var new_param = this->VisitVarDef(param);
85+
params.push_back(new_param);
86+
if (!param.same_as(new_param)) {
87+
var_remap_[param->vid] = new_param;
88+
all_params_unchanged = false;
89+
}
10290
}
10391

104-
// Create a new var with canonicalized strcut_info
105-
if (var->IsInstance<DataflowVarNode>()) {
106-
return DataflowVar(var->vid, canonical_sinfo, var->span);
107-
}
108-
return Var(var->vid, canonical_sinfo, var->span);
109-
}
92+
// Process the function body with proper scope setup
93+
Expr new_body = this->VisitWithNewScope(func->body, params);
11094

111-
private:
112-
/*!
113-
* \brief Canonicalize struct info by lifting compound shape expressions
114-
*/
115-
StructInfo CanonicalizeStructInfo(const StructInfo& sinfo) {
116-
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
117-
return CanonicalizeTensorStructInfo(ffi::GetRef<TensorStructInfo>(tensor_sinfo));
118-
} else if (auto tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
119-
return CanonicalizeTupleStructInfo(ffi::GetRef<TupleStructInfo>(tuple_sinfo));
95+
if (all_params_unchanged && new_body.same_as(func->body)) {
96+
return ffi::GetRef<Function>(func);
12097
}
121-
return sinfo;
98+
99+
return Function(params, new_body, func->ret_struct_info, func->is_pure, func->attrs,
100+
func->span);
122101
}
123102

124-
/*!
125-
* \brief Canonicalize TensorStructInfo by handling compound shape expressions
126-
*/
127-
TensorStructInfo CanonicalizeTensorStructInfo(const TensorStructInfo& sinfo) {
128-
if (!sinfo->shape.defined()) {
129-
return sinfo;
130-
}
103+
Expr VisitExpr_(const ShapeExprNode* op) override {
104+
// Just cannonicalize ShapeExpr values by replacing compound expression with symbolic vars
105+
// The bindings should have been emitted earlier by EmitBindingsForExpr
131106

132-
auto shape_expr = sinfo->shape.as<ShapeExprNode>();
133-
if (!shape_expr) {
134-
// Shape is Var, not a ShapeExpr - no canonicalization needed
135-
return sinfo;
107+
// Mark a copy of values to avoid any reference issues
108+
std::vector<PrimExpr> original_dims;
109+
for (const PrimExpr& dim : op->values) {
110+
original_dims.push_back(dim);
136111
}
137112

138-
// Canonicalize each dimension
139-
ffi::Array<PrimExpr> canonical_dims;
113+
ffi::Array<PrimExpr> canonical_values;
140114
bool changed = false;
141115

142-
for (const PrimExpr& dim : shape_expr->values) {
143-
PrimExpr canonical_dim = CanonicalizeDimension(dim);
144-
canonical_dims.push_back(canonical_dim);
116+
for (const PrimExpr& dim : original_dims) {
117+
PrimExpr canonical_dim = GetCanonicalDimension(dim);
118+
canonical_values.push_back(canonical_dim);
145119
changed |= !canonical_dim.same_as(dim);
146120
}
147121

148122
if (!changed) {
149-
return sinfo;
123+
return ffi::GetRef<ShapeExpr>(op);
150124
}
151125

152-
// Create new TensorStructInfo with canonicalized shape
153-
return TensorStructInfo(ShapeExpr(canonical_dims), sinfo->dtype, sinfo->vdevice, sinfo->span);
126+
return ShapeExpr(canonical_values, op->span);
154127
}
155128

156129
/*!
157-
* \brief Canonicalize TupleStructInfo recursively
130+
* \brief Scan an expression for ShapeExprs and emit bindings for compound expressions.
131+
* This must be called BEFORE visiting the expression to ensure bindings are emitted first.
158132
*/
159-
TupleStructInfo CanonicalizeTupleStructInfo(const TupleStructInfo& sinfo) {
160-
ffi::Array<StructInfo> canonical_fields;
161-
bool changed = false;
133+
void EmitBindingsForExpr(const Expr& expr) {
134+
// Use a simple visitor to find ShapeExpr nodes
135+
class ShapeExprScanner : public ExprVisitor {
136+
public:
137+
explicit ShapeExprScanner(ShapeExprCanonicalizer* canonicalizer)
138+
: canonicalizer_(canonicalizer) {}
139+
140+
void VisitExpr_(const ShapeExprNode* op) override {
141+
// Make a copy of values to avoid reference issues during emission
142+
std::vector<PrimExpr> dims;
143+
for (const PrimExpr& dim : op->values) {
144+
dims.push_back(dim);
145+
}
146+
for (const PrimExpr& dim : dims) {
147+
if (!IsCanonicalPrimExpr(dim)) {
148+
canonicalizer_->CanonicalizeDimension(dim);
149+
}
150+
}
151+
}
152+
153+
private:
154+
ShapeExprCanonicalizer* canonicalizer_;
155+
};
156+
157+
ShapeExprScanner scanner(this);
158+
scanner.VisitExpr(expr);
159+
}
162160

163-
for (const StructInfo& field : sinfo->fields) {
164-
StructInfo canonical_field = CanonicalizeStructInfo(field);
165-
canonical_fields.push_back(canonical_field);
166-
changed |= !canonical_field.same_as(field);
167-
}
161+
void VisitBinding_(const VarBindingNode* binding) override {
162+
// Emit canonicalization bindings before processing the binding.
163+
// Scan the binding's value for ShapeExprs with compound expressions.
164+
EmitBindingsForExpr(binding->value);
168165

169-
if (!changed) {
170-
return sinfo;
171-
}
166+
// Let the base class handle the rest
167+
ExprMutator::VisitBinding_(binding);
168+
}
172169

173-
return TupleStructInfo(canonical_fields, sinfo->span);
170+
void VisitBinding_(const MatchCastNode* binding) override {
171+
// Scan the binding's value for ShapeExprs with compound expressions
172+
EmitBindingsForExpr(binding->value);
173+
174+
// Delegate to base handling
175+
ExprMutator::VisitBinding_(binding);
176+
}
177+
178+
Var VisitVarDef(const Var& var) override {
179+
// Don't canonicalize struct_info - just delegate to base
180+
return ExprMutator::VisitVarDef(var);
174181
}
175182

183+
private:
176184
/*!
177-
* \brief Canonicalize a single shape dimension
178-
*
179-
* If the dimension is a compound PrimExpr:
180-
* 1. Emit a ShapeExpr binding containing the compound expression
181-
* 2. Create a fresh symbolic tir::Var
182-
* 3. Emit a MatchCast to bind the computed value to the symbolic var
183-
* 4. Return the symbolic var
185+
* \brief Get the canonical form of a dimension (returns the symbolic var if already emitted)
184186
*/
185-
PrimExpr CanonicalizeDimension(const PrimExpr& dim) {
187+
PrimExpr GetCanonicalDimension(const PrimExpr& dim) {
186188
// If already canonical, return as is
187189
if (IsCanonicalPrimExpr(dim)) {
188190
return dim;
@@ -193,25 +195,62 @@ class ShapeExprCanonicalizer : public ExprMutator {
193195
return it->second;
194196
}
195197

196-
// Create a fresh symbolic variable
198+
// Create a fresh symbolic variable, but don't emit yet
197199
tir::Var symbolic_var = CreateFreshSymbolicVar(dim->dtype);
198200

199-
// Emit shape binding: shape_var = R.shape([compound_expr])
200-
ShapeExpr shape_value({dim});
201-
Var shape_var = builder_->Emit(shape_value);
202-
203-
// Emit MatchCast to extract the computed value into the symbolic variable
204-
// match_cast_var: R.Shape([symbolic_var]) = shape_var
205-
ShapeStructInfo match_sinfo(ffi::Array<PrimExpr>{symbolic_var});
206-
Var match_cast_var("_", match_sinfo);
207-
builder_->EmitNormalized(MatchCast(match_cast_var, shape_var, match_sinfo));
208-
209-
// Cache the mapping to avoid duplicate bindings
210201
compound_expr_to_var_[dim] = symbolic_var;
211202

212203
return symbolic_var;
213204
}
214205

206+
/*!
207+
* \brief Emit bindings for a single compound dimension
208+
*
209+
* If the dimension is a compound PrimExpr:
210+
* 1. Create a fresh symbolic tir::Var for the compound expression
211+
* 2. Emit a MatchCast from a PrimValue to define the symbolic var
212+
*/
213+
void CanonicalizeDimension(const PrimExpr& dim) {
214+
// If already canonical, nothing to emit
215+
if (IsCanonicalPrimExpr(dim)) {
216+
return;
217+
}
218+
219+
// Check if we've already emitted the bindings
220+
if (emitted_bindings_.count(dim)) {
221+
return;
222+
}
223+
224+
// Mark as emitted BEFORE emitting to prevent infinite recursion
225+
emitted_bindings_.insert(dim);
226+
227+
// Get or create the symbolic var for this compound expression
228+
tir::Var symbolic_var;
229+
auto it = compound_expr_to_var_.find(dim);
230+
if (it != compound_expr_to_var_.end()) {
231+
symbolic_var = it->second;
232+
} else {
233+
DataType dtype = dim->dtype;
234+
symbolic_var = CreateFreshSymbolicVar(dtype);
235+
compound_expr_to_var_[dim] = symbolic_var;
236+
}
237+
238+
// Emit a PrimValue binding with the compound expression
239+
// This will be processed by VMShapeLower to compute the value
240+
PrimValue prim_value(dim);
241+
PrimStructInfo prim_sinfo(dim->dtype);
242+
std::string prim_var_name = "_prim" + std::to_string(symbolic_var_counter_ - 1);
243+
Var prim_var(prim_var_name, prim_sinfo);
244+
builder_->EmitNormalized(VarBinding(prim_var, prim_value));
245+
246+
// Emit MatchCast to extract the computed value into the symbolic variable
247+
// The pattern uses the symbolic var which will be defined by this MatchCast
248+
PrimStructInfo match_sinfo(symbolic_var);
249+
std::string match_var_name = "_match" + std::to_string(symbolic_var_counter_ - 1);
250+
Var match_cast_var(match_var_name, match_sinfo);
251+
builder_->EmitNormalized(MatchCast(match_cast_var, prim_var, match_sinfo));
252+
}
253+
215254
/*!
216255
* \brief Create a fresh symbolic TIR variable
217256
*/
@@ -223,6 +262,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
223262
// Cache to avoid creating duplicate bindings for the same compound expression
224263
std::unordered_map<PrimExpr, tir::Var, StructuralHash, StructuralEqual> compound_expr_to_var_;
225264

265+
// Track which compound expressions have had their bindings emitted
266+
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> emitted_bindings_;
267+
226268
// Counter for generating unique symbolic variable names
227269
int symbolic_var_counter_ = 0;
228270
};

0 commit comments

Comments
 (0)