@@ -62,127 +62,129 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) {
6262 * \brief Mutator to canonicalize ShapeExpr in struct info
6363 *
6464 * This pass handles ShapeExpr canonicalization by:
65- * 1. Detecting compound PrimExpr in ShapeExpr dimensions
66- * 2. Lifting them into separate ShapeExpr bindings
65+ * 1. Detecting compound PrimExpr in variable struct_info
66+ * 2. Emitting ShapeExpr bindings to compute expressions
6767 * 3. Using MatchCast to extract values into fresh symbolic tir::Var
68- * 4. Replacing compound expressions with these canonical vars
68+ * 4. Replacing compound expressions with these canonical vars in struct_info
6969 */
7070class ShapeExprCanonicalizer : public ExprMutator {
7171 public:
7272 using ExprMutator::VisitExpr_;
7373
7474 Expr VisitExpr_ (const FunctionNode* func) override {
7575 // Reset state for each function
76- auto cached_compound_to_var = compound_expr_to_var_;
77- auto cached_counter = symbolic_var_counter_;
78-
79- auto result = ExprMutator::VisitExpr_ (func);
80-
81- compound_expr_to_var_ = cached_compound_to_var;
82- symbolic_var_counter_ = cached_counter;
83-
84- return result;
85- }
86-
87- /* !
88- * \brief Override VisitVarDef to canonicalize struct_info
89- *
90- * This is where we intercept variable definitions and canonicalize any
91- * compound PrimExpr in their TensorStructInfo shapes.
92- */
93- Var VisitVarDef (const Var& var) override {
94- auto sinfo = GetStructInfo (var);
95-
96- // Check if we need to canonicalize the struct_info
97- auto canonical_sinfo = CanonicalizeStructInfo (sinfo);
98-
99- if (canonical_sinfo.same_as (sinfo)) {
100- // No changes needed
101- return ExprMutator::VisitVarDef (var);
76+ symbolic_var_counter_ = 0 ;
77+ compound_expr_to_var_.clear ();
78+ emitted_bindings_.clear ();
79+
80+ // Visit params to populate var_remap_
81+ ffi::Array<Var> params;
82+ bool all_params_unchanged = true ;
83+ for (Var param : func->params ) {
84+ Var new_param = this ->VisitVarDef (param);
85+ params.push_back (new_param);
86+ if (!param.same_as (new_param)) {
87+ var_remap_[param->vid ] = new_param;
88+ all_params_unchanged = false ;
89+ }
10290 }
10391
104- // Create a new var with canonicalized strcut_info
105- if (var->IsInstance <DataflowVarNode>()) {
106- return DataflowVar (var->vid , canonical_sinfo, var->span );
107- }
108- return Var (var->vid , canonical_sinfo, var->span );
109- }
92+ // Process the function body with proper scope setup
93+ Expr new_body = this ->VisitWithNewScope (func->body , params);
11094
111- private:
112- /* !
113- * \brief Canonicalize struct info by lifting compound shape expressions
114- */
115- StructInfo CanonicalizeStructInfo (const StructInfo& sinfo) {
116- if (auto tensor_sinfo = sinfo.as <TensorStructInfoNode>()) {
117- return CanonicalizeTensorStructInfo (ffi::GetRef<TensorStructInfo>(tensor_sinfo));
118- } else if (auto tuple_sinfo = sinfo.as <TupleStructInfoNode>()) {
119- return CanonicalizeTupleStructInfo (ffi::GetRef<TupleStructInfo>(tuple_sinfo));
95+ if (all_params_unchanged && new_body.same_as (func->body )) {
96+ return ffi::GetRef<Function>(func);
12097 }
121- return sinfo;
98+
99+ return Function (params, new_body, func->ret_struct_info , func->is_pure , func->attrs ,
100+ func->span );
122101 }
123102
124- /* !
125- * \brief Canonicalize TensorStructInfo by handling compound shape expressions
126- */
127- TensorStructInfo CanonicalizeTensorStructInfo (const TensorStructInfo& sinfo) {
128- if (!sinfo->shape .defined ()) {
129- return sinfo;
130- }
103+ Expr VisitExpr_ (const ShapeExprNode* op) override {
104+ // Just cannonicalize ShapeExpr values by replacing compound expression with symbolic vars
105+ // The bindings should have been emitted earlier by EmitBindingsForExpr
131106
132- auto shape_expr = sinfo-> shape . as <ShapeExprNode>();
133- if (!shape_expr) {
134- // Shape is Var, not a ShapeExpr - no canonicalization needed
135- return sinfo ;
107+ // Mark a copy of values to avoid any reference issues
108+ std::vector<PrimExpr> original_dims;
109+ for ( const PrimExpr& dim : op-> values ) {
110+ original_dims. push_back (dim) ;
136111 }
137112
138- // Canonicalize each dimension
139- ffi::Array<PrimExpr> canonical_dims;
113+ ffi::Array<PrimExpr> canonical_values;
140114 bool changed = false ;
141115
142- for (const PrimExpr& dim : shape_expr-> values ) {
143- PrimExpr canonical_dim = CanonicalizeDimension (dim);
144- canonical_dims .push_back (canonical_dim);
116+ for (const PrimExpr& dim : original_dims ) {
117+ PrimExpr canonical_dim = GetCanonicalDimension (dim);
118+ canonical_values .push_back (canonical_dim);
145119 changed |= !canonical_dim.same_as (dim);
146120 }
147121
148122 if (!changed) {
149- return sinfo ;
123+ return ffi::GetRef<ShapeExpr>(op) ;
150124 }
151125
152- // Create new TensorStructInfo with canonicalized shape
153- return TensorStructInfo (ShapeExpr (canonical_dims), sinfo->dtype , sinfo->vdevice , sinfo->span );
126+ return ShapeExpr (canonical_values, op->span );
154127 }
155128
156129 /* !
157- * \brief Canonicalize TupleStructInfo recursively
130+ * \brief Scan an expression for ShapeExprs and emit bindings for compound expressions.
131+ * This must be called BEFORE visiting the expression to ensure bindings are emitted first.
158132 */
159- TupleStructInfo CanonicalizeTupleStructInfo (const TupleStructInfo& sinfo) {
160- ffi::Array<StructInfo> canonical_fields;
161- bool changed = false ;
133+ void EmitBindingsForExpr (const Expr& expr) {
134+ // Use a simple visitor to find ShapeExpr nodes
135+ class ShapeExprScanner : public ExprVisitor {
136+ public:
137+ explicit ShapeExprScanner (ShapeExprCanonicalizer* canonicalizer)
138+ : canonicalizer_(canonicalizer) {}
139+
140+ void VisitExpr_ (const ShapeExprNode* op) override {
141+ // Make a copy of values to avoid reference issues during emission
142+ std::vector<PrimExpr> dims;
143+ for (const PrimExpr& dim : op->values ) {
144+ dims.push_back (dim);
145+ }
146+ for (const PrimExpr& dim : dims) {
147+ if (!IsCanonicalPrimExpr (dim)) {
148+ canonicalizer_->CanonicalizeDimension (dim);
149+ }
150+ }
151+ }
152+
153+ private:
154+ ShapeExprCanonicalizer* canonicalizer_;
155+ };
156+
157+ ShapeExprScanner scanner (this );
158+ scanner.VisitExpr (expr);
159+ }
162160
163- for (const StructInfo& field : sinfo->fields ) {
164- StructInfo canonical_field = CanonicalizeStructInfo (field);
165- canonical_fields.push_back (canonical_field);
166- changed |= !canonical_field.same_as (field);
167- }
161+ void VisitBinding_ (const VarBindingNode* binding) override {
162+ // Emit canonicalization bindings before processing the binding.
163+ // Scan the binding's value for ShapeExprs with compound expressions.
164+ EmitBindingsForExpr (binding->value );
168165
169- if (!changed) {
170- return sinfo ;
171- }
166+ // Let the base class handle the rest
167+ ExprMutator::VisitBinding_ (binding) ;
168+ }
172169
173- return TupleStructInfo (canonical_fields, sinfo->span );
170+ void VisitBinding_ (const MatchCastNode* binding) override {
171+ // Scan the binding's value for ShapeExprs with compound expressions
172+ EmitBindingsForExpr (binding->value );
173+
174+ // Delegate to base handling
175+ ExprMutator::VisitBinding_ (binding);
176+ }
177+
178+ Var VisitVarDef (const Var& var) override {
179+ // Don't canonicalize struct_info - just delegate to base
180+ return ExprMutator::VisitVarDef (var);
174181 }
175182
183+ private:
176184 /* !
177- * \brief Canonicalize a single shape dimension
178- *
179- * If the dimension is a compound PrimExpr:
180- * 1. Emit a ShapeExpr binding containing the compound expression
181- * 2. Create a fresh symbolic tir::Var
182- * 3. Emit a MatchCast to bind the computed value to the symbolic var
183- * 4. Return the symbolic var
185+ * \brief Get the canonical form of a dimension (returns the symbolic var if already emitted)
184186 */
185- PrimExpr CanonicalizeDimension (const PrimExpr& dim) {
187+ PrimExpr GetCanonicalDimension (const PrimExpr& dim) {
186188 // If already canonical, return as is
187189 if (IsCanonicalPrimExpr (dim)) {
188190 return dim;
@@ -193,25 +195,62 @@ class ShapeExprCanonicalizer : public ExprMutator {
193195 return it->second ;
194196 }
195197
196- // Create a fresh symbolic variable
198+ // Create a fresh symbolic variable, but don't emit yet
197199 tir::Var symbolic_var = CreateFreshSymbolicVar (dim->dtype );
198200
199- // Emit shape binding: shape_var = R.shape([compound_expr])
200- ShapeExpr shape_value ({dim});
201- Var shape_var = builder_->Emit (shape_value);
202-
203- // Emit MatchCast to extract the computed value into the symbolic variable
204- // match_cast_var: R.Shape([symbolic_var]) = shape_var
205- ShapeStructInfo match_sinfo (ffi::Array<PrimExpr>{symbolic_var});
206- Var match_cast_var (" _" , match_sinfo);
207- builder_->EmitNormalized (MatchCast (match_cast_var, shape_var, match_sinfo));
208-
209- // Cache the mapping to avoid duplicate bindings
210201 compound_expr_to_var_[dim] = symbolic_var;
211202
212203 return symbolic_var;
213204 }
214205
206+ /* !
207+ * \brief Emit bindings for a single compound dimension
208+ *
209+ * If the dimension is a compound PrimExpr:
210+ * 1. Create a fresh symbolic tir::Var for the compound expression
211+ * 2. Emit a MatchCast from a PrimValue to define the symbolic var
212+ */
213+ void CanonicalizeDimension (const PrimExpr& dim) {
214+ // If already canonical, nothing to emit
215+ if (IsCanonicalPrimExpr (dim)) {
216+ return ;
217+ }
218+
219+ // Check if we've already emitted the bindings
220+ if (emitted_bindings_.count (dim)) {
221+ return ;
222+ }
223+
224+ // Mark as emitted BEFORE emitting to prevent infinite recursion
225+ emitted_bindings_.insert (dim);
226+
227+ // Get or create the symbolic var for this compound expression
228+ tir::Var symbolic_var;
229+ auto it = compound_expr_to_var_.find (dim);
230+ if (it != compound_expr_to_var_.end ()) {
231+ symbolic_var = it->second ;
232+ } else {
233+ DataType dtype = dim->dtype ;
234+ symbolic_var = CreateFreshSymbolicVar (dtype);
235+ compound_expr_to_var_[dim] = symbolic_var;
236+ }
237+
238+ // Emit a PrimValue binding with the compound expression
239+ // This will be processed by VMShapeLower to compute the value
240+ PrimValue prim_value (dim);
241+ PrimStructInfo prim_sinfo (dim->dtype );
242+ std::string prim_var_name = " _prim" + std::to_string (symbolic_var_counter_ - 1 );
243+ Var prim_var (prim_var_name, prim_sinfo);
244+ builder_->EmitNormalized (VarBinding (prim_var, prim_value));
245+
246+ // Emit MatchCast to extract the computed value into the symbolic variable
247+ // The pattern uses the symbolic var which will be defined by this MatchCast
248+ PrimStructInfo match_sinfo (symbolic_var);
249+ std::string match_var_name = " _match" + std::to_string (symbolic_var_counter_ - 1 );
250+ Var match_cast_var (match_var_name, match_sinfo);
251+ builder_->EmitNormalized (MatchCast (match_cast_var, prim_var, match_sinfo));
252+ }
253+
215254 /* !
216255 * \brief Create a fresh symbolic TIR variable
217256 */
@@ -223,6 +262,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
223262 // Cache to avoid creating duplicate bindings for the same compound expression
224263 std::unordered_map<PrimExpr, tir::Var, StructuralHash, StructuralEqual> compound_expr_to_var_;
225264
265+ // Track which compound expressions have had their bindings emitted
266+ std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> emitted_bindings_;
267+
226268 // Counter for generating unique symbolic variable names
227269 int symbolic_var_counter_ = 0 ;
228270};
0 commit comments