1//===- AffineExprVisitor.h - MLIR AffineExpr Visitor Class ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the AffineExpr visitor class.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_AFFINEEXPRVISITOR_H
14#define MLIR_IR_AFFINEEXPRVISITOR_H
15
16#include "mlir/IR/AffineExpr.h"
17#include "mlir/Support/LogicalResult.h"
18#include "llvm/ADT/ArrayRef.h"
19
20namespace mlir {
21
22/// Base class for AffineExpr visitors/walkers.
23///
24/// AffineExpr visitors are used when you want to perform different actions
25/// for different kinds of AffineExprs without having to use lots of casts
26/// and a big switch instruction.
27///
28/// To define your own visitor, inherit from this class, specifying your
29/// new type for the 'SubClass' template parameter, and "override" visitXXX
30/// functions in your class. This class is defined in terms of statically
31/// resolved overloading, not virtual functions.
32///
33/// The visitor is templated on its return type (`RetTy`). With a WalkResult
34/// return type, the visitor supports interrupting walks.
35///
36/// For example, here is a visitor that counts the number of for AffineDimExprs
37/// in an AffineExpr.
38///
39/// /// Declare the class. Note that we derive from AffineExprVisitor
40/// /// instantiated with our new subclasses_ type.
41///
42/// struct DimExprCounter : public AffineExprVisitor<DimExprCounter> {
43/// unsigned numDimExprs;
44/// DimExprCounter() : numDimExprs(0) {}
45/// void visitDimExpr(AffineDimExpr expr) { ++numDimExprs; }
46/// };
47///
48/// And this class would be used like this:
49/// DimExprCounter dec;
50/// dec.visit(affineExpr);
51/// numDimExprs = dec.numDimExprs;
52///
53/// AffineExprVisitor provides visit methods for the following binary affine
54/// op expressions:
55/// AffineBinaryAddOpExpr, AffineBinaryMulOpExpr,
56/// AffineBinaryModOpExpr, AffineBinaryFloorDivOpExpr,
57/// AffineBinaryCeilDivOpExpr. Note that default implementations of these
58/// methods will call the general AffineBinaryOpExpr method.
59///
60/// In addition, visit methods are provided for the following affine
61// expressions: AffineConstantExpr, AffineDimExpr, and
62// AffineSymbolExpr.
63///
64/// Note that if you don't implement visitXXX for some affine expression type,
65/// the visitXXX method for Instruction superclass will be invoked.
66///
67/// Note that this class is specifically designed as a template to avoid
68/// virtual function call overhead. Defining and using a AffineExprVisitor is
69/// just as efficient as having your own switch instruction over the instruction
70/// opcode.
71template <typename SubClass, typename RetTy>
72class AffineExprVisitorBase {
73public:
74 // Function to visit an AffineExpr.
75 RetTy visit(AffineExpr expr) {
76 static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value,
77 "Must instantiate with a derived type of AffineExprVisitor");
78 auto self = static_cast<SubClass *>(this);
79 switch (expr.getKind()) {
80 case AffineExprKind::Add: {
81 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
82 return self->visitAddExpr(binOpExpr);
83 }
84 case AffineExprKind::Mul: {
85 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
86 return self->visitMulExpr(binOpExpr);
87 }
88 case AffineExprKind::Mod: {
89 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
90 return self->visitModExpr(binOpExpr);
91 }
92 case AffineExprKind::FloorDiv: {
93 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
94 return self->visitFloorDivExpr(binOpExpr);
95 }
96 case AffineExprKind::CeilDiv: {
97 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
98 return self->visitCeilDivExpr(binOpExpr);
99 }
100 case AffineExprKind::Constant:
101 return self->visitConstantExpr(cast<AffineConstantExpr>(Val&: expr));
102 case AffineExprKind::DimId:
103 return self->visitDimExpr(cast<AffineDimExpr>(Val&: expr));
104 case AffineExprKind::SymbolId:
105 return self->visitSymbolExpr(cast<AffineSymbolExpr>(Val&: expr));
106 }
107 llvm_unreachable("Unknown AffineExpr");
108 }
109
110 //===--------------------------------------------------------------------===//
111 // Visitation functions... these functions provide default fallbacks in case
112 // the user does not specify what to do for a particular instruction type.
113 // The default behavior is to generalize the instruction type to its subtype
114 // and try visiting the subtype. All of this should be inlined perfectly,
115 // because there are no virtual functions to get in the way.
116 //
117
118 // Default visit methods. Note that the default op-specific binary op visit
119 // methods call the general visitAffineBinaryOpExpr visit method.
120 RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
121 RetTy visitAddExpr(AffineBinaryOpExpr expr) {
122 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
123 }
124 RetTy visitMulExpr(AffineBinaryOpExpr expr) {
125 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
126 }
127 RetTy visitModExpr(AffineBinaryOpExpr expr) {
128 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
129 }
130 RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
131 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
132 }
133 RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
134 return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
135 }
136 RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
137 RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
138 RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
139};
140
141/// See documentation for AffineExprVisitorBase. This visitor supports
142/// interrupting walks when a `WalkResult` is used for `RetTy`.
143template <typename SubClass, typename RetTy = void>
144class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> {
145 //===--------------------------------------------------------------------===//
146 // Interface code - This is the public interface of the AffineExprVisitor
147 // that you use to visit affine expressions...
148public:
149 // Function to walk an AffineExpr (in post order).
150 RetTy walkPostOrder(AffineExpr expr) {
151 static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
152 "Must instantiate with a derived type of AffineExprVisitor");
153 auto self = static_cast<SubClass *>(this);
154 switch (expr.getKind()) {
155 case AffineExprKind::Add: {
156 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
157 if constexpr (std::is_same<RetTy, WalkResult>::value) {
158 if (walkOperandsPostOrder(expr: binOpExpr).wasInterrupted())
159 return WalkResult::interrupt();
160 } else {
161 walkOperandsPostOrder(expr: binOpExpr);
162 }
163 return self->visitAddExpr(binOpExpr);
164 }
165 case AffineExprKind::Mul: {
166 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
167 if constexpr (std::is_same<RetTy, WalkResult>::value) {
168 if (walkOperandsPostOrder(expr: binOpExpr).wasInterrupted())
169 return WalkResult::interrupt();
170 } else {
171 walkOperandsPostOrder(expr: binOpExpr);
172 }
173 return self->visitMulExpr(binOpExpr);
174 }
175 case AffineExprKind::Mod: {
176 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
177 if constexpr (std::is_same<RetTy, WalkResult>::value) {
178 if (walkOperandsPostOrder(expr: binOpExpr).wasInterrupted())
179 return WalkResult::interrupt();
180 } else {
181 walkOperandsPostOrder(expr: binOpExpr);
182 }
183 return self->visitModExpr(binOpExpr);
184 }
185 case AffineExprKind::FloorDiv: {
186 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
187 if constexpr (std::is_same<RetTy, WalkResult>::value) {
188 if (walkOperandsPostOrder(expr: binOpExpr).wasInterrupted())
189 return WalkResult::interrupt();
190 } else {
191 walkOperandsPostOrder(expr: binOpExpr);
192 }
193 return self->visitFloorDivExpr(binOpExpr);
194 }
195 case AffineExprKind::CeilDiv: {
196 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
197 if constexpr (std::is_same<RetTy, WalkResult>::value) {
198 if (walkOperandsPostOrder(expr: binOpExpr).wasInterrupted())
199 return WalkResult::interrupt();
200 } else {
201 walkOperandsPostOrder(expr: binOpExpr);
202 }
203 return self->visitCeilDivExpr(binOpExpr);
204 }
205 case AffineExprKind::Constant:
206 return self->visitConstantExpr(cast<AffineConstantExpr>(Val&: expr));
207 case AffineExprKind::DimId:
208 return self->visitDimExpr(cast<AffineDimExpr>(Val&: expr));
209 case AffineExprKind::SymbolId:
210 return self->visitSymbolExpr(cast<AffineSymbolExpr>(Val&: expr));
211 }
212 llvm_unreachable("Unknown AffineExpr");
213 }
214
215private:
216 // Walk the operands - each operand is itself walked in post order.
217 RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
218 if constexpr (std::is_same<RetTy, WalkResult>::value) {
219 if (walkPostOrder(expr: expr.getLHS()).wasInterrupted())
220 return WalkResult::interrupt();
221 } else {
222 walkPostOrder(expr: expr.getLHS());
223 }
224 if constexpr (std::is_same<RetTy, WalkResult>::value) {
225 if (walkPostOrder(expr: expr.getRHS()).wasInterrupted())
226 return WalkResult::interrupt();
227 return WalkResult::advance();
228 } else {
229 return walkPostOrder(expr: expr.getRHS());
230 }
231 }
232};
233
234template <typename SubClass>
235class AffineExprVisitor<SubClass, LogicalResult>
236 : public AffineExprVisitorBase<SubClass, LogicalResult> {
237 //===--------------------------------------------------------------------===//
238 // Interface code - This is the public interface of the AffineExprVisitor
239 // that you use to visit affine expressions...
240public:
241 // Function to walk an AffineExpr (in post order).
242 LogicalResult walkPostOrder(AffineExpr expr) {
243 static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
244 "Must instantiate with a derived type of AffineExprVisitor");
245 auto self = static_cast<SubClass *>(this);
246 switch (expr.getKind()) {
247 case AffineExprKind::Add: {
248 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
249 if (failed(walkOperandsPostOrder(expr: binOpExpr)))
250 return failure();
251 return self->visitAddExpr(binOpExpr);
252 }
253 case AffineExprKind::Mul: {
254 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
255 if (failed(walkOperandsPostOrder(expr: binOpExpr)))
256 return failure();
257 return self->visitMulExpr(binOpExpr);
258 }
259 case AffineExprKind::Mod: {
260 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
261 if (failed(walkOperandsPostOrder(expr: binOpExpr)))
262 return failure();
263 return self->visitModExpr(binOpExpr);
264 }
265 case AffineExprKind::FloorDiv: {
266 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
267 if (failed(walkOperandsPostOrder(expr: binOpExpr)))
268 return failure();
269 return self->visitFloorDivExpr(binOpExpr);
270 }
271 case AffineExprKind::CeilDiv: {
272 auto binOpExpr = cast<AffineBinaryOpExpr>(Val&: expr);
273 if (failed(walkOperandsPostOrder(expr: binOpExpr)))
274 return failure();
275 return self->visitCeilDivExpr(binOpExpr);
276 }
277 case AffineExprKind::Constant:
278 return self->visitConstantExpr(cast<AffineConstantExpr>(Val&: expr));
279 case AffineExprKind::DimId:
280 return self->visitDimExpr(cast<AffineDimExpr>(Val&: expr));
281 case AffineExprKind::SymbolId:
282 return self->visitSymbolExpr(cast<AffineSymbolExpr>(Val&: expr));
283 }
284 llvm_unreachable("Unknown AffineExpr");
285 }
286
287private:
288 // Walk the operands - each operand is itself walked in post order.
289 LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) {
290 if (failed(walkPostOrder(expr: expr.getLHS())))
291 return failure();
292 if (failed(walkPostOrder(expr: expr.getRHS())))
293 return failure();
294 return success();
295 }
296};
297
298// This class is used to flatten a pure affine expression (AffineExpr,
299// which is in a tree form) into a sum of products (w.r.t constants) when
300// possible, and in that process simplifying the expression. For a modulo,
301// floordiv, or a ceildiv expression, an additional identifier, called a local
302// identifier, is introduced to rewrite the expression as a sum of product
303// affine expression. Each local identifier is always and by construction a
304// floordiv of a pure add/mul affine function of dimensional, symbolic, and
305// other local identifiers, in a non-mutually recursive way. Hence, every local
306// identifier can ultimately always be recovered as an affine function of
307// dimensional and symbolic identifiers (involving floordiv's); note however
308// that by AffineExpr construction, some floordiv combinations are converted to
309// mod's. The result of the flattening is a flattened expression and a set of
310// constraints involving just the local variables.
311//
312// d2 + (d0 + d1) floordiv 4 is flattened to d2 + q where 'q' is the local
313// variable introduced, with localVarCst containing 4*q <= d0 + d1 <= 4*q + 3.
314//
315// The simplification performed includes the accumulation of contributions for
316// each dimensional and symbolic identifier together, the simplification of
317// floordiv/ceildiv/mod expressions and other simplifications that in turn
318// happen as a result. A simplification that this flattening naturally performs
319// is of simplifying the numerator and denominator of floordiv/ceildiv, and
320// folding a modulo expression to a zero, if possible. Three examples are below:
321//
//    (((d0 + 3 * d1) + d0) - 2 * d1) - d0    simplified to    d0 + d1
323// (d0 - d0 mod 4 + 4) mod 4 simplified to 0
324// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
325//
326// The way the flattening works for the second example is as follows: d0 % 4 is
327// replaced by d0 - 4*q with q being introduced: the expression then simplifies
328// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
329// zero. Note that an affine expression may not always be expressible purely as
330// a sum of products involving just the original dimensional and symbolic
331// identifiers due to the presence of modulo/floordiv/ceildiv expressions that
332// may not be eliminated after simplification; in such cases, the final
333// expression can be reconstructed by replacing the local identifiers with their
334// corresponding explicit form stored in 'localExprs' (note that each of the
335// explicit forms itself would have been simplified).
336//
337// The expression walk method here performs a linear time post order walk that
338// performs the above simplifications through visit methods, with partial
339// results being stored in 'operandExprStack'. When a parent expr is visited,
340// the flattened expressions corresponding to its two operands would already be
341// on the stack - the parent expression looks at the two flattened expressions
342// and combines the two. It pops off the operand expressions and pushes the
343// combined result (although this is done in-place on its LHS operand expr).
344// When the walk is completed, the flattened form of the top-level expression
345// would be left on the stack.
346//
347// A flattener can be repeatedly used for multiple affine expressions that bind
348// to the same operands, for example, for all result expressions of an
349// AffineMap or AffineValueMap. In such cases, using it for multiple expressions
350// is more efficient than creating a new flattener for each expression since
351// common identical div and mod expressions appearing across different
352// expressions are mapped to the same local identifier (same column position in
353// 'localVarCst').
354class SimpleAffineExprFlattener
355 : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> {
356public:
357 // Flattend expression layout: [dims, symbols, locals, constant]
358 // Stack that holds the LHS and RHS operands while visiting a binary op expr.
359 // In future, consider adding a prepass to determine how big the SmallVector's
360 // will be, and linearize this to std::vector<int64_t> to prevent
361 // SmallVector moves on re-allocation.
362 std::vector<SmallVector<int64_t, 8>> operandExprStack;
363
364 unsigned numDims;
365 unsigned numSymbols;
366
367 // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv's.
368 unsigned numLocals;
369
370 // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
371 // which new identifiers were introduced; if the latter do not get canceled
372 // out, these expressions can be readily used to reconstruct the AffineExpr
373 // (tree) form. Note that these expressions themselves would have been
374 // simplified (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4
375 // will be simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1)
376 // ceildiv 2 would be the local expression stored for q.
377 SmallVector<AffineExpr, 4> localExprs;
378
379 SimpleAffineExprFlattener(unsigned numDims, unsigned numSymbols);
380
381 virtual ~SimpleAffineExprFlattener() = default;
382
383 // Visitor method overrides.
384 LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
385 LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
386 LogicalResult visitDimExpr(AffineDimExpr expr);
387 LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
388 LogicalResult visitConstantExpr(AffineConstantExpr expr);
389 LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
390 LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
391
392 //
393 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
394 //
395 // A mod expression "expr mod c" is thus flattened by introducing a new local
396 // variable q (= expr floordiv c), such that expr mod c is replaced with
397 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
398 LogicalResult visitModExpr(AffineBinaryOpExpr expr);
399
400protected:
401 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
402 // The local identifier added is always a floordiv of a pure add/mul affine
403 // function of other identifiers, coefficients of which are specified in
404 // dividend and with respect to a positive constant divisor. localExpr is the
405 // simplified tree expression (AffineExpr) corresponding to the quantifier.
406 virtual void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
407 AffineExpr localExpr);
408
409 /// Add a local identifier (needed to flatten a mod, floordiv, ceildiv, mul
410 /// expr) when the rhs is a symbolic expression. The local identifier added
411 /// may be a floordiv, ceildiv, mul or mod of a pure affine/semi-affine
412 /// function of other identifiers, coefficients of which are specified in the
413 /// lhs of the mod, floordiv, ceildiv or mul expression and with respect to a
414 /// symbolic rhs expression. `localExpr` is the simplified tree expression
415 /// (AffineExpr) corresponding to the quantifier.
416 virtual void addLocalIdSemiAffine(AffineExpr localExpr);
417
418private:
419 /// Adds `expr`, which may be mod, ceildiv, floordiv or mod expression
420 /// representing the affine expression corresponding to the quantifier
421 /// introduced as the local variable corresponding to `expr`. If the
422 /// quantifier is already present, we put the coefficient in the proper index
423 /// of `result`, otherwise we add a new local variable and put the coefficient
424 /// there.
425 void addLocalVariableSemiAffine(AffineExpr expr,
426 SmallVectorImpl<int64_t> &result,
427 unsigned long resultSize);
428
429 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
430 // A floordiv is thus flattened by introducing a new local variable q, and
431 // replacing that expression with 'q' while adding the constraints
432 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
433 // IntegerRelation::addLocalFloorDiv).
434 //
435 // A ceildiv is similarly flattened:
436 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
437 LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil);
438
439 int findLocalId(AffineExpr localExpr);
440
441 inline unsigned getNumCols() const {
442 return numDims + numSymbols + numLocals + 1;
443 }
444 inline unsigned getConstantIndex() const { return getNumCols() - 1; }
445 inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
446 inline unsigned getSymbolStartIndex() const { return numDims; }
447 inline unsigned getDimStartIndex() const { return 0; }
448};
449
450} // namespace mlir
451
452#endif // MLIR_IR_AFFINEEXPRVISITOR_H
453

// Source: mlir/include/mlir/IR/AffineExprVisitor.h