1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "AffineExprDetail.h"
12#include "mlir/IR/AffineExpr.h"
13#include "mlir/IR/AffineExprVisitor.h"
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/Support/MathExtras.h"
17#include "mlir/Support/TypeID.h"
18#include "llvm/ADT/STLExtras.h"
19#include <numeric>
20#include <optional>
21
22using namespace mlir;
23using namespace mlir::detail;
24
25MLIRContext *AffineExpr::getContext() const { return expr->context; }
26
27AffineExprKind AffineExpr::getKind() const { return expr->kind; }
28
29/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
30/// method to help handle lambda walk functions. Users should use the regular
31/// (non-static) `walk` method.
32template <typename WalkRetTy>
33WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
34 function_ref<WalkRetTy(AffineExpr)> callback) {
35 struct AffineExprWalker
36 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
37 function_ref<WalkRetTy(AffineExpr)> callback;
38
39 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
40 : callback(callback) {}
41
42 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
43 return callback(expr);
44 }
45 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
46 return callback(expr);
47 }
48 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
49 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
50 };
51
52 return AffineExprWalker(callback).walkPostOrder(e);
53}
54// Explicitly instantiate for the two supported return types.
55template void mlir::AffineExpr::walk(AffineExpr e,
56 function_ref<void(AffineExpr)> callback);
57template WalkResult
58mlir::AffineExpr::walk(AffineExpr e,
59 function_ref<WalkResult(AffineExpr)> callback);
60
61// Dispatch affine expression construction based on kind.
62AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
63 AffineExpr rhs) {
64 if (kind == AffineExprKind::Add)
65 return lhs + rhs;
66 if (kind == AffineExprKind::Mul)
67 return lhs * rhs;
68 if (kind == AffineExprKind::FloorDiv)
69 return lhs.floorDiv(other: rhs);
70 if (kind == AffineExprKind::CeilDiv)
71 return lhs.ceilDiv(other: rhs);
72 if (kind == AffineExprKind::Mod)
73 return lhs % rhs;
74
75 llvm_unreachable("unknown binary operation on affine expressions");
76}
77
78/// This method substitutes any uses of dimensions and symbols (e.g.
79/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
80AffineExpr
81AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
82 ArrayRef<AffineExpr> symReplacements) const {
83 switch (getKind()) {
84 case AffineExprKind::Constant:
85 return *this;
86 case AffineExprKind::DimId: {
87 unsigned dimId = llvm::cast<AffineDimExpr>(Val: *this).getPosition();
88 if (dimId >= dimReplacements.size())
89 return *this;
90 return dimReplacements[dimId];
91 }
92 case AffineExprKind::SymbolId: {
93 unsigned symId = llvm::cast<AffineSymbolExpr>(Val: *this).getPosition();
94 if (symId >= symReplacements.size())
95 return *this;
96 return symReplacements[symId];
97 }
98 case AffineExprKind::Add:
99 case AffineExprKind::Mul:
100 case AffineExprKind::FloorDiv:
101 case AffineExprKind::CeilDiv:
102 case AffineExprKind::Mod:
103 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
104 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
105 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
106 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
107 if (newLHS == lhs && newRHS == rhs)
108 return *this;
109 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
110 }
111 llvm_unreachable("Unknown AffineExpr");
112}
113
114AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
115 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
116}
117
118AffineExpr
119AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
120 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements);
121}
122
123/// Replace dims[offset ... numDims)
124/// by dims[offset + shift ... shift + numDims).
125AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
126 unsigned offset) const {
127 SmallVector<AffineExpr, 4> dims;
128 for (unsigned idx = 0; idx < offset; ++idx)
129 dims.push_back(Elt: getAffineDimExpr(position: idx, context: getContext()));
130 for (unsigned idx = offset; idx < numDims; ++idx)
131 dims.push_back(Elt: getAffineDimExpr(position: idx + shift, context: getContext()));
132 return replaceDimsAndSymbols(dimReplacements: dims, symReplacements: {});
133}
134
135/// Replace symbols[offset ... numSymbols)
136/// by symbols[offset + shift ... shift + numSymbols).
137AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
138 unsigned offset) const {
139 SmallVector<AffineExpr, 4> symbols;
140 for (unsigned idx = 0; idx < offset; ++idx)
141 symbols.push_back(Elt: getAffineSymbolExpr(position: idx, context: getContext()));
142 for (unsigned idx = offset; idx < numSymbols; ++idx)
143 symbols.push_back(Elt: getAffineSymbolExpr(position: idx + shift, context: getContext()));
144 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements: symbols);
145}
146
147/// Sparse replace method. Return the modified expression tree.
148AffineExpr
149AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
150 auto it = map.find(Val: *this);
151 if (it != map.end())
152 return it->second;
153 switch (getKind()) {
154 default:
155 return *this;
156 case AffineExprKind::Add:
157 case AffineExprKind::Mul:
158 case AffineExprKind::FloorDiv:
159 case AffineExprKind::CeilDiv:
160 case AffineExprKind::Mod:
161 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
162 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
163 auto newLHS = lhs.replace(map);
164 auto newRHS = rhs.replace(map);
165 if (newLHS == lhs && newRHS == rhs)
166 return *this;
167 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
168 }
169 llvm_unreachable("Unknown AffineExpr");
170}
171
172/// Sparse replace method. Return the modified expression tree.
173AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
174 DenseMap<AffineExpr, AffineExpr> map;
175 map.insert(KV: std::make_pair(x&: expr, y&: replacement));
176 return replace(map);
177}
178/// Returns true if this expression is made out of only symbols and
179/// constants (no dimensional identifiers).
180bool AffineExpr::isSymbolicOrConstant() const {
181 switch (getKind()) {
182 case AffineExprKind::Constant:
183 return true;
184 case AffineExprKind::DimId:
185 return false;
186 case AffineExprKind::SymbolId:
187 return true;
188
189 case AffineExprKind::Add:
190 case AffineExprKind::Mul:
191 case AffineExprKind::FloorDiv:
192 case AffineExprKind::CeilDiv:
193 case AffineExprKind::Mod: {
194 auto expr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
195 return expr.getLHS().isSymbolicOrConstant() &&
196 expr.getRHS().isSymbolicOrConstant();
197 }
198 }
199 llvm_unreachable("Unknown AffineExpr");
200}
201
202/// Returns true if this is a pure affine expression, i.e., multiplication,
203/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
204bool AffineExpr::isPureAffine() const {
205 switch (getKind()) {
206 case AffineExprKind::SymbolId:
207 case AffineExprKind::DimId:
208 case AffineExprKind::Constant:
209 return true;
210 case AffineExprKind::Add: {
211 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
212 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
213 }
214
215 case AffineExprKind::Mul: {
216 // TODO: Canonicalize the constants in binary operators to the RHS when
217 // possible, allowing this to merge into the next case.
218 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
219 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
220 (llvm::isa<AffineConstantExpr>(Val: op.getLHS()) ||
221 llvm::isa<AffineConstantExpr>(Val: op.getRHS()));
222 }
223 case AffineExprKind::FloorDiv:
224 case AffineExprKind::CeilDiv:
225 case AffineExprKind::Mod: {
226 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
227 return op.getLHS().isPureAffine() &&
228 llvm::isa<AffineConstantExpr>(Val: op.getRHS());
229 }
230 }
231 llvm_unreachable("Unknown AffineExpr");
232}
233
234// Returns the greatest known integral divisor of this affine expression.
235int64_t AffineExpr::getLargestKnownDivisor() const {
236 AffineBinaryOpExpr binExpr(nullptr);
237 switch (getKind()) {
238 case AffineExprKind::DimId:
239 [[fallthrough]];
240 case AffineExprKind::SymbolId:
241 return 1;
242 case AffineExprKind::CeilDiv:
243 [[fallthrough]];
244 case AffineExprKind::FloorDiv: {
245 // If the RHS is a constant and divides the known divisor on the LHS, the
246 // quotient is a known divisor of the expression.
247 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
248 auto rhs = llvm::dyn_cast<AffineConstantExpr>(Val: binExpr.getRHS());
249 // Leave alone undefined expressions.
250 if (rhs && rhs.getValue() != 0) {
251 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
252 if (lhsDiv % rhs.getValue() == 0)
253 return lhsDiv / rhs.getValue();
254 }
255 return 1;
256 }
257 case AffineExprKind::Constant:
258 return std::abs(i: llvm::cast<AffineConstantExpr>(Val: *this).getValue());
259 case AffineExprKind::Mul: {
260 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
261 return binExpr.getLHS().getLargestKnownDivisor() *
262 binExpr.getRHS().getLargestKnownDivisor();
263 }
264 case AffineExprKind::Add:
265 [[fallthrough]];
266 case AffineExprKind::Mod: {
267 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
268 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
269 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
270 }
271 }
272 llvm_unreachable("Unknown AffineExpr");
273}
274
275bool AffineExpr::isMultipleOf(int64_t factor) const {
276 AffineBinaryOpExpr binExpr(nullptr);
277 uint64_t l, u;
278 switch (getKind()) {
279 case AffineExprKind::SymbolId:
280 [[fallthrough]];
281 case AffineExprKind::DimId:
282 return factor * factor == 1;
283 case AffineExprKind::Constant:
284 return llvm::cast<AffineConstantExpr>(Val: *this).getValue() % factor == 0;
285 case AffineExprKind::Mul: {
286 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
287 // It's probably not worth optimizing this further (to not traverse the
288 // whole sub-tree under - it that would require a version of isMultipleOf
289 // that on a 'false' return also returns the largest known divisor).
290 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
291 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
292 (l * u) % factor == 0;
293 }
294 case AffineExprKind::Add:
295 case AffineExprKind::FloorDiv:
296 case AffineExprKind::CeilDiv:
297 case AffineExprKind::Mod: {
298 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
299 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
300 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
301 factor ==
302 0;
303 }
304 }
305 llvm_unreachable("Unknown AffineExpr");
306}
307
308bool AffineExpr::isFunctionOfDim(unsigned position) const {
309 if (getKind() == AffineExprKind::DimId) {
310 return *this == mlir::getAffineDimExpr(position, context: getContext());
311 }
312 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
313 return expr.getLHS().isFunctionOfDim(position) ||
314 expr.getRHS().isFunctionOfDim(position);
315 }
316 return false;
317}
318
319bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
320 if (getKind() == AffineExprKind::SymbolId) {
321 return *this == mlir::getAffineSymbolExpr(position, context: getContext());
322 }
323 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
324 return expr.getLHS().isFunctionOfSymbol(position) ||
325 expr.getRHS().isFunctionOfSymbol(position);
326 }
327 return false;
328}
329
330AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
331 : AffineExpr(ptr) {}
332AffineExpr AffineBinaryOpExpr::getLHS() const {
333 return static_cast<ImplType *>(expr)->lhs;
334}
335AffineExpr AffineBinaryOpExpr::getRHS() const {
336 return static_cast<ImplType *>(expr)->rhs;
337}
338
339AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
340unsigned AffineDimExpr::getPosition() const {
341 return static_cast<ImplType *>(expr)->position;
342}
343
344/// Returns true if the expression is divisible by the given symbol with
345/// position `symbolPos`. The argument `opKind` specifies here what kind of
346/// division or mod operation called this division. It helps in implementing the
347/// commutative property of the floordiv and ceildiv operations. If the argument
348///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
349/// operation, then the commutative property can be used otherwise, the floordiv
350/// operation is not divisible. The same argument holds for ceildiv operation.
351static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
352 AffineExprKind opKind) {
353 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
354 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
355 opKind == AffineExprKind::CeilDiv) &&
356 "unexpected opKind");
357 switch (expr.getKind()) {
358 case AffineExprKind::Constant:
359 return cast<AffineConstantExpr>(Val&: expr).getValue() == 0;
360 case AffineExprKind::DimId:
361 return false;
362 case AffineExprKind::SymbolId:
363 return (cast<AffineSymbolExpr>(Val&: expr).getPosition() == symbolPos);
364 // Checks divisibility by the given symbol for both operands.
365 case AffineExprKind::Add: {
366 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
367 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind) &&
368 isDivisibleBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind);
369 }
370 // Checks divisibility by the given symbol for both operands. Consider the
371 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
372 // this is a division by s1 and both the operands of modulo are divisible by
373 // s1 but it is not divisible by s1 always. The third argument is
374 // `AffineExprKind::Mod` for this reason.
375 case AffineExprKind::Mod: {
376 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
377 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos,
378 opKind: AffineExprKind::Mod) &&
379 isDivisibleBySymbol(expr: binaryExpr.getRHS(), symbolPos,
380 opKind: AffineExprKind::Mod);
381 }
382 // Checks if any of the operand divisible by the given symbol.
383 case AffineExprKind::Mul: {
384 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
385 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind) ||
386 isDivisibleBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind);
387 }
388 // Floordiv and ceildiv are divisible by the given symbol when the first
389 // operand is divisible, and the affine expression kind of the argument expr
390 // is same as the argument `opKind`. This can be inferred from commutative
391 // property of floordiv and ceildiv operations and are as follow:
392 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
393 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
394 // It will fail if operations are not same. For example:
395 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
396 case AffineExprKind::FloorDiv:
397 case AffineExprKind::CeilDiv: {
398 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
399 if (opKind != expr.getKind())
400 return false;
401 return isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind());
402 }
403 }
404 llvm_unreachable("Unknown AffineExpr");
405}
406
407/// Divides the given expression by the given symbol at position `symbolPos`. It
408/// considers the divisibility condition is checked before calling itself. A
409/// null expression is returned whenever the divisibility condition fails.
410static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
411 AffineExprKind opKind) {
412 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
413 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
414 opKind == AffineExprKind::CeilDiv) &&
415 "unexpected opKind");
416 switch (expr.getKind()) {
417 case AffineExprKind::Constant:
418 if (cast<AffineConstantExpr>(Val&: expr).getValue() != 0)
419 return nullptr;
420 return getAffineConstantExpr(constant: 0, context: expr.getContext());
421 case AffineExprKind::DimId:
422 return nullptr;
423 case AffineExprKind::SymbolId:
424 return getAffineConstantExpr(constant: 1, context: expr.getContext());
425 // Dividing both operands by the given symbol.
426 case AffineExprKind::Add: {
427 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
428 return getAffineBinaryOpExpr(
429 kind: expr.getKind(), lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind),
430 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind));
431 }
432 // Dividing both operands by the given symbol.
433 case AffineExprKind::Mod: {
434 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
435 return getAffineBinaryOpExpr(
436 kind: expr.getKind(),
437 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
438 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind: expr.getKind()));
439 }
440 // Dividing any of the operand by the given symbol.
441 case AffineExprKind::Mul: {
442 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
443 if (!isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind))
444 return binaryExpr.getLHS() *
445 symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind);
446 return symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind) *
447 binaryExpr.getRHS();
448 }
449 // Dividing first operand only by the given symbol.
450 case AffineExprKind::FloorDiv:
451 case AffineExprKind::CeilDiv: {
452 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
453 return getAffineBinaryOpExpr(
454 kind: expr.getKind(),
455 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
456 rhs: binaryExpr.getRHS());
457 }
458 }
459 llvm_unreachable("Unknown AffineExpr");
460}
461
462/// Populate `result` with all summand operands of given (potentially nested)
463/// addition. If the given expression is not an addition, just populate the
464/// expression itself.
465/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
466static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
467 auto addExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
468 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
469 result.push_back(Elt: expr);
470 return;
471 }
472 getSummandExprs(expr: addExpr.getLHS(), result);
473 getSummandExprs(expr: addExpr.getRHS(), result);
474}
475
476/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
477/// If so, also return the non-negated expression via `expr`.
478static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
479 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: candidate);
480 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
481 return false;
482 if (auto lhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getLHS())) {
483 if (lhs.getValue() == -1) {
484 expr = mulExpr.getRHS();
485 return true;
486 }
487 }
488 if (auto rhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getRHS())) {
489 if (rhs.getValue() == -1) {
490 expr = mulExpr.getLHS();
491 return true;
492 }
493 }
494 return false;
495}
496
497/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
498/// the fact that `lhs` contains another modulo expression that ensures that
499/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
500/// after loop peeling.
501///
502/// Example: lhs = ub - ub % step
503/// rhs = step
504/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
505static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
506 unsigned numDims, unsigned numSymbols) {
507 // TODO: Try to unify this function with `getBoundForAffineExpr`.
508 // Collect all summands in lhs.
509 SmallVector<AffineExpr> summands;
510 getSummandExprs(expr: lhs, result&: summands);
511 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
512 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
513 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
514 AffineExpr current = summands[i];
515 AffineExpr beforeNegation;
516 if (!isNegatedAffineExpr(candidate: current, expr&: beforeNegation))
517 continue;
518 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(Val&: beforeNegation);
519 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
520 continue;
521 if (innerMod.getRHS() != rhs)
522 continue;
523 // Sum all remaining summands and subtract x. If that expression can be
524 // simplified to zero, then the remaining summands and x are equal.
525 AffineExpr diff = getAffineConstantExpr(constant: 0, context: lhs.getContext());
526 for (int64_t j = 0; j < e; ++j)
527 if (i != j)
528 diff = diff + summands[j];
529 diff = diff - innerMod.getLHS();
530 diff = simplifyAffineExpr(expr: diff, numDims, numSymbols);
531 auto constExpr = dyn_cast<AffineConstantExpr>(Val&: diff);
532 if (constExpr && constExpr.getValue() == 0)
533 return true;
534 }
535 return false;
536}
537
538/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
539/// operations when the second operand simplifies to a symbol and the first
540/// operand is divisible by that symbol. It can be applied to any semi-affine
541/// expression. Returned expression can either be a semi-affine or pure affine
542/// expression.
543static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
544 unsigned numSymbols) {
545 switch (expr.getKind()) {
546 case AffineExprKind::Constant:
547 case AffineExprKind::DimId:
548 case AffineExprKind::SymbolId:
549 return expr;
550 case AffineExprKind::Add:
551 case AffineExprKind::Mul: {
552 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
553 return getAffineBinaryOpExpr(
554 kind: expr.getKind(),
555 lhs: simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols),
556 rhs: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
557 }
558 // Check if the simplification of the second operand is a symbol, and the
559 // first operand is divisible by it. If the operation is a modulo, a constant
560 // zero expression is returned. In the case of floordiv and ceildiv, the
561 // symbol from the simplification of the second operand divides the first
562 // operand. Otherwise, simplification is not possible.
563 case AffineExprKind::FloorDiv:
564 case AffineExprKind::CeilDiv:
565 case AffineExprKind::Mod: {
566 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
567 AffineExpr sLHS =
568 simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols);
569 AffineExpr sRHS =
570 simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols);
571 if (isModOfModSubtraction(lhs: sLHS, rhs: sRHS, numDims, numSymbols))
572 return getAffineConstantExpr(constant: 0, context: expr.getContext());
573 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
574 Val: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
575 if (!symbolExpr)
576 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
577 unsigned symbolPos = symbolExpr.getPosition();
578 if (!isDivisibleBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()))
579 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
580 if (expr.getKind() == AffineExprKind::Mod)
581 return getAffineConstantExpr(constant: 0, context: expr.getContext());
582 return symbolicDivide(expr: sLHS, symbolPos, opKind: expr.getKind());
583 }
584 }
585 llvm_unreachable("Unknown AffineExpr");
586}
587
588static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
589 MLIRContext *context) {
590 auto assignCtx = [context](AffineDimExprStorage *storage) {
591 storage->context = context;
592 };
593
594 StorageUniquer &uniquer = context->getAffineUniquer();
595 return uniquer.get<AffineDimExprStorage>(
596 initFn: assignCtx, args: static_cast<unsigned>(kind), args&: position);
597}
598
599AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
600 return getAffineDimOrSymbol(kind: AffineExprKind::DimId, position, context);
601}
602
603AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
604 : AffineExpr(ptr) {}
605unsigned AffineSymbolExpr::getPosition() const {
606 return static_cast<ImplType *>(expr)->position;
607}
608
609AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
610 return getAffineDimOrSymbol(kind: AffineExprKind::SymbolId, position, context);
611}
612
613AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
614 : AffineExpr(ptr) {}
615int64_t AffineConstantExpr::getValue() const {
616 return static_cast<ImplType *>(expr)->constant;
617}
618
619bool AffineExpr::operator==(int64_t v) const {
620 return *this == getAffineConstantExpr(constant: v, context: getContext());
621}
622
623AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
624 auto assignCtx = [context](AffineConstantExprStorage *storage) {
625 storage->context = context;
626 };
627
628 StorageUniquer &uniquer = context->getAffineUniquer();
629 return uniquer.get<AffineConstantExprStorage>(initFn: assignCtx, args&: constant);
630}
631
632SmallVector<AffineExpr>
633mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
634 MLIRContext *context) {
635 return llvm::to_vector(Range: llvm::map_range(C&: constants, F: [&](int64_t constant) {
636 return getAffineConstantExpr(constant, context);
637 }));
638}
639
640/// Simplify add expression. Return nullptr if it can't be simplified.
641static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
642 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
643 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
644 // Fold if both LHS, RHS are a constant.
645 if (lhsConst && rhsConst)
646 return getAffineConstantExpr(constant: lhsConst.getValue() + rhsConst.getValue(),
647 context: lhs.getContext());
648
649 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
650 // If only one of them is a symbolic expressions, make it the RHS.
651 if (isa<AffineConstantExpr>(Val: lhs) ||
652 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
653 return rhs + lhs;
654 }
655
656 // At this point, if there was a constant, it would be on the right.
657
658 // Addition with a zero is a noop, return the other input.
659 if (rhsConst) {
660 if (rhsConst.getValue() == 0)
661 return lhs;
662 }
663 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
664 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
665 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
666 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
667 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
668 }
669
670 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
671 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
672 // respective multiplicands.
673 std::optional<int64_t> rLhsConst, rRhsConst;
674 AffineExpr firstExpr, secondExpr;
675 AffineConstantExpr rLhsConstExpr;
676 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
677 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
678 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(Val: lBinOpExpr.getRHS()))) {
679 rLhsConst = rLhsConstExpr.getValue();
680 firstExpr = lBinOpExpr.getLHS();
681 } else {
682 rLhsConst = 1;
683 firstExpr = lhs;
684 }
685
686 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: rhs);
687 AffineConstantExpr rRhsConstExpr;
688 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
689 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(Val: rBinOpExpr.getRHS()))) {
690 rRhsConst = rRhsConstExpr.getValue();
691 secondExpr = rBinOpExpr.getLHS();
692 } else {
693 rRhsConst = 1;
694 secondExpr = rhs;
695 }
696
697 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
698 return getAffineBinaryOpExpr(
699 kind: AffineExprKind::Mul, lhs: firstExpr,
700 rhs: getAffineConstantExpr(constant: *rLhsConst + *rRhsConst, context: lhs.getContext()));
701
702 // When doing successive additions, bring constant to the right: turn (d0 + 2)
703 // + d1 into (d0 + d1) + 2.
704 if (lBin && lBin.getKind() == AffineExprKind::Add) {
705 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
706 return lBin.getLHS() + rhs + lrhs;
707 }
708 }
709
710 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
711 // q may be a constant or symbolic expression. This leads to a much more
712 // efficient form when 'c' is a power of two, and in general a more compact
713 // and readable form.
714
715 // Process '(expr floordiv c) * (-c)'.
716 if (!rBinOpExpr)
717 return nullptr;
718
719 auto lrhs = rBinOpExpr.getLHS();
720 auto rrhs = rBinOpExpr.getRHS();
721
722 AffineExpr llrhs, rlrhs;
723
724 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
725 // symbolic expression.
726 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
727 // Check rrhsConstOpExpr = -1.
728 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(Val&: rrhs);
729 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
730 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
731 // Check llrhs = expr floordiv q.
732 llrhs = lrhsBinOpExpr.getLHS();
733 // Check rlrhs = q.
734 rlrhs = lrhsBinOpExpr.getRHS();
735 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: llrhs);
736 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
737 return nullptr;
738 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
739 return lhs % rlrhs;
740 }
741
742 // Process lrhs, which is 'expr floordiv c'.
743 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
744 if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
745 return nullptr;
746
747 llrhs = lrBinOpExpr.getLHS();
748 rlrhs = lrBinOpExpr.getRHS();
749
750 if (lhs == llrhs && rlrhs == -rrhs) {
751 return lhs % rlrhs;
752 }
753 return nullptr;
754}
755
756AffineExpr AffineExpr::operator+(int64_t v) const {
757 return *this + getAffineConstantExpr(constant: v, context: getContext());
758}
759AffineExpr AffineExpr::operator+(AffineExpr other) const {
760 if (auto simplified = simplifyAdd(lhs: *this, rhs: other))
761 return simplified;
762
763 StorageUniquer &uniquer = getContext()->getAffineUniquer();
764 return uniquer.get<AffineBinaryOpExprStorage>(
765 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Add), args: *this, args&: other);
766}
767
768/// Simplify a multiply expression. Return nullptr if it can't be simplified.
769static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
770 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
771 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
772
773 if (lhsConst && rhsConst)
774 return getAffineConstantExpr(constant: lhsConst.getValue() * rhsConst.getValue(),
775 context: lhs.getContext());
776
777 assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
778
779 // Canonicalize the mul expression so that the constant/symbolic term is the
780 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
781 // constant. (Note that a constant is trivially symbolic).
782 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(Val: lhs)) {
783 // At least one of them has to be symbolic.
784 return rhs * lhs;
785 }
786
787 // At this point, if there was a constant, it would be on the right.
788
789 // Multiplication with a one is a noop, return the other input.
790 if (rhsConst) {
791 if (rhsConst.getValue() == 1)
792 return lhs;
793 // Multiplication with zero.
794 if (rhsConst.getValue() == 0)
795 return rhsConst;
796 }
797
798 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
799 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
800 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
801 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
802 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
803 }
804
805 // When doing successive multiplication, bring constant to the right: turn (d0
806 // * 2) * d1 into (d0 * d1) * 2.
807 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
808 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
809 return (lBin.getLHS() * rhs) * lrhs;
810 }
811 }
812
813 return nullptr;
814}
815
816AffineExpr AffineExpr::operator*(int64_t v) const {
817 return *this * getAffineConstantExpr(constant: v, context: getContext());
818}
819AffineExpr AffineExpr::operator*(AffineExpr other) const {
820 if (auto simplified = simplifyMul(lhs: *this, rhs: other))
821 return simplified;
822
823 StorageUniquer &uniquer = getContext()->getAffineUniquer();
824 return uniquer.get<AffineBinaryOpExprStorage>(
825 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mul), args: *this, args&: other);
826}
827
828// Unary minus, delegate to operator*.
829AffineExpr AffineExpr::operator-() const {
830 return *this * getAffineConstantExpr(constant: -1, context: getContext());
831}
832
833// Delegate to operator+.
834AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
835AffineExpr AffineExpr::operator-(AffineExpr other) const {
836 return *this + (-other);
837}
838
839static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
840 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
841 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
842
843 // mlir floordiv by zero or negative numbers is undefined and preserved as is.
844 if (!rhsConst || rhsConst.getValue() < 1)
845 return nullptr;
846
847 if (lhsConst)
848 return getAffineConstantExpr(
849 constant: floorDiv(lhs: lhsConst.getValue(), rhs: rhsConst.getValue()), context: lhs.getContext());
850
851 // Fold floordiv of a multiply with a constant that is a multiple of the
852 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
853 if (rhsConst == 1)
854 return lhs;
855
856 // Simplify (expr * const) floordiv divConst when expr is known to be a
857 // multiple of divConst.
858 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
859 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
860 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
861 // rhsConst is known to be a positive constant.
862 if (lrhs.getValue() % rhsConst.getValue() == 0)
863 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
864 }
865 }
866
867 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
868 // known to be a multiple of divConst.
869 if (lBin && lBin.getKind() == AffineExprKind::Add) {
870 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
871 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
872 // rhsConst is known to be a positive constant.
873 if (llhsDiv % rhsConst.getValue() == 0 ||
874 lrhsDiv % rhsConst.getValue() == 0)
875 return lBin.getLHS().floorDiv(v: rhsConst.getValue()) +
876 lBin.getRHS().floorDiv(v: rhsConst.getValue());
877 }
878
879 return nullptr;
880}
881
882AffineExpr AffineExpr::floorDiv(uint64_t v) const {
883 return floorDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
884}
885AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
886 if (auto simplified = simplifyFloorDiv(lhs: *this, rhs: other))
887 return simplified;
888
889 StorageUniquer &uniquer = getContext()->getAffineUniquer();
890 return uniquer.get<AffineBinaryOpExprStorage>(
891 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::FloorDiv), args: *this,
892 args&: other);
893}
894
895static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
896 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
897 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
898
899 if (!rhsConst || rhsConst.getValue() < 1)
900 return nullptr;
901
902 if (lhsConst)
903 return getAffineConstantExpr(
904 constant: ceilDiv(lhs: lhsConst.getValue(), rhs: rhsConst.getValue()), context: lhs.getContext());
905
906 // Fold ceildiv of a multiply with a constant that is a multiple of the
907 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
908 if (rhsConst.getValue() == 1)
909 return lhs;
910
911 // Simplify (expr * const) ceildiv divConst when const is known to be a
912 // multiple of divConst.
913 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
914 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
915 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
916 // rhsConst is known to be a positive constant.
917 if (lrhs.getValue() % rhsConst.getValue() == 0)
918 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
919 }
920 }
921
922 return nullptr;
923}
924
925AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
926 return ceilDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
927}
928AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
929 if (auto simplified = simplifyCeilDiv(lhs: *this, rhs: other))
930 return simplified;
931
932 StorageUniquer &uniquer = getContext()->getAffineUniquer();
933 return uniquer.get<AffineBinaryOpExprStorage>(
934 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::CeilDiv), args: *this,
935 args&: other);
936}
937
938static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
939 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
940 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
941
942 // mod w.r.t zero or negative numbers is undefined and preserved as is.
943 if (!rhsConst || rhsConst.getValue() < 1)
944 return nullptr;
945
946 if (lhsConst)
947 return getAffineConstantExpr(constant: mod(lhs: lhsConst.getValue(), rhs: rhsConst.getValue()),
948 context: lhs.getContext());
949
950 // Fold modulo of an expression that is known to be a multiple of a constant
951 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
952 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
953 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
954 return getAffineConstantExpr(constant: 0, context: lhs.getContext());
955
956 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
957 // known to be a multiple of divConst.
958 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
959 if (lBin && lBin.getKind() == AffineExprKind::Add) {
960 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
961 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
962 // rhsConst is known to be a positive constant.
963 if (llhsDiv % rhsConst.getValue() == 0)
964 return lBin.getRHS() % rhsConst.getValue();
965 if (lrhsDiv % rhsConst.getValue() == 0)
966 return lBin.getLHS() % rhsConst.getValue();
967 }
968
969 // Simplify (e % a) % b to e % b when b evenly divides a
970 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
971 auto intermediate = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS());
972 if (intermediate && intermediate.getValue() >= 1 &&
973 mod(lhs: intermediate.getValue(), rhs: rhsConst.getValue()) == 0) {
974 return lBin.getLHS() % rhsConst.getValue();
975 }
976 }
977
978 return nullptr;
979}
980
981AffineExpr AffineExpr::operator%(uint64_t v) const {
982 return *this % getAffineConstantExpr(constant: v, context: getContext());
983}
984AffineExpr AffineExpr::operator%(AffineExpr other) const {
985 if (auto simplified = simplifyMod(lhs: *this, rhs: other))
986 return simplified;
987
988 StorageUniquer &uniquer = getContext()->getAffineUniquer();
989 return uniquer.get<AffineBinaryOpExprStorage>(
990 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mod), args: *this, args&: other);
991}
992
993AffineExpr AffineExpr::compose(AffineMap map) const {
994 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
995 map.getResults().end());
996 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
997}
998raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
999 expr.print(os);
1000 return os;
1001}
1002
1003/// Constructs an affine expression from a flat ArrayRef. If there are local
1004/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1005/// products expression, `localExprs` is expected to have the AffineExpr
1006/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1007/// in the format [dims, symbols, locals, constant term].
1008AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1009 unsigned numDims,
1010 unsigned numSymbols,
1011 ArrayRef<AffineExpr> localExprs,
1012 MLIRContext *context) {
1013 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1014 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1015 "unexpected number of local expressions");
1016
1017 auto expr = getAffineConstantExpr(constant: 0, context);
1018 // Dimensions and symbols.
1019 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1020 if (flatExprs[j] == 0)
1021 continue;
1022 auto id = j < numDims ? getAffineDimExpr(position: j, context)
1023 : getAffineSymbolExpr(position: j - numDims, context);
1024 expr = expr + id * flatExprs[j];
1025 }
1026
1027 // Local identifiers.
1028 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1029 j++) {
1030 if (flatExprs[j] == 0)
1031 continue;
1032 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1033 expr = expr + term;
1034 }
1035
1036 // Constant term.
1037 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1038 if (constTerm != 0)
1039 expr = expr + constTerm;
1040 return expr;
1041}
1042
1043/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1044/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1045/// of products expression, `localExprs` is expected to have the AffineExprs for
1046/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1047/// the format [dims, symbols, locals, constant term]. The semi-affine
1048/// expression is constructed in the sorted order of dimension and symbol
1049/// position numbers. Note: local expressions/ids are used for mod, div as well
1050/// as symbolic RHS terms for terms that are not pure affine.
1051static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1052 unsigned numDims,
1053 unsigned numSymbols,
1054 ArrayRef<AffineExpr> localExprs,
1055 MLIRContext *context) {
1056 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1057
1058 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1059 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1060 "unexpected number of local expressions");
1061
1062 AffineExpr expr = getAffineConstantExpr(constant: 0, context);
1063
1064 // We design indices as a pair which help us present the semi-affine map as
1065 // sum of product where terms are sorted based on dimension or symbol
1066 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1067 // where keyA is the position number of the dimension and keyB is the
1068 // position number of the symbol. For dimensional expressions we set the index
1069 // as (position number of the dimension, -1), as we want dimensional
1070 // expressions to appear before symbolic and product of dimensional and
1071 // symbolic expressions having the dimension with the same position number.
1072 // For symbolic expression set the index as (position number of the symbol,
1073 // maximum of last dimension and symbol position) number. For example, we want
1074 // the expression we are constructing to look something like: d0 + d0 * s0 +
1075 // s0 + d1*s1 + s1.
1076
1077 // Stores the affine expression corresponding to a given index.
1078 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
1079 // Stores the constant coefficient value corresponding to a given
1080 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1081 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
1082 // Stores the indices as defined above, and later sorted to produce
1083 // the semi-affine expression in the desired form.
1084 SmallVector<std::pair<unsigned, signed>, 8> indices;
1085
1086 // Example: expression = d0 + d0 * s0 + 2 * s0.
1087 // indices = [{0,-1}, {0, 0}, {0, 1}]
1088 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1089 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1090
1091 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1092 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1093 AffineExpr expr) {
1094 assert(!llvm::is_contained(indices, index) &&
1095 "Key is already present in indices vector and overwriting will "
1096 "happen in `indexToExprMap` and `coefficients`!");
1097
1098 indices.push_back(Elt: index);
1099 coefficients.insert(KV: {index, coefficient});
1100 indexToExprMap.insert(KV: {index, expr});
1101 };
1102
1103 // Design indices for dimensional or symbolic terms, and store the indices,
1104 // constant coefficient corresponding to the indices in `coefficients` map,
1105 // and affine expression corresponding to indices in `indexToExprMap` map.
1106
1107 // Ensure we do not have duplicate keys in `indexToExpr` map.
1108 unsigned offsetSym = 0;
1109 signed offsetDim = -1;
1110 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1111 if (flatExprs[j] == 0)
1112 continue;
1113 // For symbolic expression set the index as <position number
1114 // of the symbol, max(dimCount, symCount)> number,
1115 // as we want symbolic expressions with the same positional number to
1116 // appear after dimensional expressions having the same positional number.
1117 std::pair<unsigned, signed> indexEntry(
1118 j - numDims, std::max(a: numDims, b: numSymbols) + offsetSym++);
1119 addEntry(indexEntry, flatExprs[j],
1120 getAffineSymbolExpr(position: j - numDims, context));
1121 }
1122
1123 // Denotes semi-affine product, modulo or division terms, which has been added
1124 // to the `indexToExpr` map.
1125 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1126 false);
1127 unsigned lhsPos, rhsPos;
1128 // Construct indices for product terms involving dimension, symbol or constant
1129 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1130 // the indices in `coefficients` map, and affine expression corresponding to
1131 // in indices in `indexToExprMap` map.
1132 for (const auto &it : llvm::enumerate(First&: localExprs)) {
1133 AffineExpr expr = it.value();
1134 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1135 continue;
1136 AffineExpr lhs = cast<AffineBinaryOpExpr>(Val&: expr).getLHS();
1137 AffineExpr rhs = cast<AffineBinaryOpExpr>(Val&: expr).getRHS();
1138 if (!((isa<AffineDimExpr>(Val: lhs) || isa<AffineSymbolExpr>(Val: lhs)) &&
1139 (isa<AffineDimExpr>(Val: rhs) || isa<AffineSymbolExpr>(Val: rhs) ||
1140 isa<AffineConstantExpr>(Val: rhs)))) {
1141 continue;
1142 }
1143 if (isa<AffineConstantExpr>(Val: rhs)) {
1144 // For product/modulo/division expressions, when rhs of modulo/division
1145 // expression is constant, we put 0 in place of keyB, because we want
1146 // them to appear earlier in the semi-affine expression we are
1147 // constructing. When rhs is constant, we place 0 in place of keyB.
1148 if (isa<AffineDimExpr>(Val: lhs)) {
1149 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1150 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1151 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1152 expr);
1153 } else {
1154 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1155 std::pair<unsigned, signed> indexEntry(
1156 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1157 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1158 expr);
1159 }
1160 } else if (isa<AffineDimExpr>(Val: lhs)) {
1161 // For product/modulo/division expressions having lhs as dimension and rhs
1162 // as symbol, we order the terms in the semi-affine expression based on
1163 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1164 // where keyA is the position number of the dimension and keyB is the
1165 // position number of the symbol.
1166 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1167 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1168 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1169 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1170 } else {
1171 // For product/modulo/division expressions having both lhs and rhs as
1172 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1173 // of the form dimension * symbol, where keyA is the position number of
1174 // the dimension and keyB is the position number of the symbol.
1175 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1176 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1177 std::pair<unsigned, signed> indexEntry(
1178 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1179 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1180 }
1181 addedToMap[it.index()] = true;
1182 }
1183
1184 for (unsigned j = 0; j < numDims; ++j) {
1185 if (flatExprs[j] == 0)
1186 continue;
1187 // For dimensional expressions we set the index as <position number of the
1188 // dimension, 0>, as we want dimensional expressions to appear before
1189 // symbolic ones and products of dimensional and symbolic expressions
1190 // having the dimension with the same position number.
1191 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1192 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(position: j, context));
1193 }
1194
1195 // Constructing the simplified semi-affine sum of product/division/mod
1196 // expression from the flattened form in the desired sorted order of indices
1197 // of the various individual product/division/mod expressions.
1198 llvm::sort(C&: indices);
1199 for (const std::pair<unsigned, unsigned> index : indices) {
1200 assert(indexToExprMap.lookup(index) &&
1201 "cannot find key in `indexToExprMap` map");
1202 expr = expr + indexToExprMap.lookup(Val: index) * coefficients.lookup(Val: index);
1203 }
1204
1205 // Local identifiers.
1206 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1207 j++) {
1208 // If the coefficient of the local expression is 0, continue as we need not
1209 // add it in out final expression.
1210 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1211 continue;
1212 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1213 expr = expr + term;
1214 }
1215
1216 // Constant term.
1217 int64_t constTerm = flatExprs.back();
1218 if (constTerm != 0)
1219 expr = expr + constTerm;
1220 return expr;
1221}
1222
1223SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1224 unsigned numSymbols)
1225 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1226 operandExprStack.reserve(n: 8);
1227}
1228
1229// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1230//
1231// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1232// introduce a local variable p (= expr * symbolic_expr), and the affine
1233// expression expr * symbolic_expr is added to `localExprs`.
1234LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1235 assert(operandExprStack.size() >= 2);
1236 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1237 operandExprStack.pop_back();
1238 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1239
1240 // Flatten semi-affine multiplication expressions by introducing a local
1241 // variable in place of the product; the affine expression
1242 // corresponding to the quantifier is added to `localExprs`.
1243 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1244 MLIRContext *context = expr.getContext();
1245 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1246 localExprs, context);
1247 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1248 localExprs, context);
1249 addLocalVariableSemiAffine(expr: a * b, result&: lhs, resultSize: lhs.size());
1250 return success();
1251 }
1252
1253 // Get the RHS constant.
1254 int64_t rhsConst = rhs[getConstantIndex()];
1255 for (int64_t &lhsElt : lhs)
1256 lhsElt *= rhsConst;
1257
1258 return success();
1259}
1260
1261LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1262 assert(operandExprStack.size() >= 2);
1263 const auto &rhs = operandExprStack.back();
1264 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1265 assert(lhs.size() == rhs.size());
1266 // Update the LHS in place.
1267 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1268 lhs[i] += rhs[i];
1269 }
1270 // Pop off the RHS.
1271 operandExprStack.pop_back();
1272 return success();
1273}
1274
1275//
1276// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1277//
1278// A mod expression "expr mod c" is thus flattened by introducing a new local
1279// variable q (= expr floordiv c), such that expr mod c is replaced with
1280// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1281//
1282// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1283// introduce a local variable m (= expr mod symbolic_expr), and the affine
1284// expression expr mod symbolic_expr is added to `localExprs`.
1285LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1286 assert(operandExprStack.size() >= 2);
1287
1288 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1289 operandExprStack.pop_back();
1290 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1291 MLIRContext *context = expr.getContext();
1292
1293 // Flatten semi affine modulo expressions by introducing a local
1294 // variable in place of the modulo value, and the affine expression
1295 // corresponding to the quantifier is added to `localExprs`.
1296 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1297 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1298 flatExprs: lhs, numDims, numSymbols, localExprs, context);
1299 AffineExpr divisorExpr = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1300 localExprs, context);
1301 AffineExpr modExpr = dividendExpr % divisorExpr;
1302 addLocalVariableSemiAffine(expr: modExpr, result&: lhs, resultSize: lhs.size());
1303 return success();
1304 }
1305
1306 int64_t rhsConst = rhs[getConstantIndex()];
1307 if (rhsConst <= 0)
1308 return failure();
1309
1310 // Check if the LHS expression is a multiple of modulo factor.
1311 unsigned i, e;
1312 for (i = 0, e = lhs.size(); i < e; i++)
1313 if (lhs[i] % rhsConst != 0)
1314 break;
1315 // If yes, modulo expression here simplifies to zero.
1316 if (i == lhs.size()) {
1317 std::fill(first: lhs.begin(), last: lhs.end(), value: 0);
1318 return success();
1319 }
1320
1321 // Add a local variable for the quotient, i.e., expr % c is replaced by
1322 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1323 // the GCD of expr and c.
1324 SmallVector<int64_t, 8> floorDividend(lhs);
1325 uint64_t gcd = rhsConst;
1326 for (int64_t lhsElt : lhs)
1327 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1328 // Simplify the numerator and the denominator.
1329 if (gcd != 1) {
1330 for (int64_t &floorDividendElt : floorDividend)
1331 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1332 }
1333 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1334
1335 // Construct the AffineExpr form of the floordiv to store in localExprs.
1336
1337 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1338 flatExprs: floorDividend, numDims, numSymbols, localExprs, context);
1339 AffineExpr divisorExpr = getAffineConstantExpr(constant: floorDivisor, context);
1340 AffineExpr floorDivExpr = dividendExpr.floorDiv(other: divisorExpr);
1341 int loc;
1342 if ((loc = findLocalId(localExpr: floorDivExpr)) == -1) {
1343 addLocalFloorDivId(dividend: floorDividend, divisor: floorDivisor, localExpr: floorDivExpr);
1344 // Set result at top of stack to "lhs - rhsConst * q".
1345 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1346 } else {
1347 // Reuse the existing local id.
1348 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1349 }
1350 return success();
1351}
1352
1353LogicalResult
1354SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1355 return visitDivExpr(expr, /*isCeil=*/true);
1356}
1357LogicalResult
1358SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1359 return visitDivExpr(expr, /*isCeil=*/false);
1360}
1361
1362LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1363 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1364 auto &eq = operandExprStack.back();
1365 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1366 eq[getDimStartIndex() + expr.getPosition()] = 1;
1367 return success();
1368}
1369
1370LogicalResult
1371SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1372 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1373 auto &eq = operandExprStack.back();
1374 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1375 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1376 return success();
1377}
1378
1379LogicalResult
1380SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1381 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1382 auto &eq = operandExprStack.back();
1383 eq[getConstantIndex()] = expr.getValue();
1384 return success();
1385}
1386
1387void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1388 AffineExpr expr, SmallVectorImpl<int64_t> &result,
1389 unsigned long resultSize) {
1390 assert(result.size() == resultSize &&
1391 "`result` vector passed is not of correct size");
1392 int loc;
1393 if ((loc = findLocalId(localExpr: expr)) == -1)
1394 addLocalIdSemiAffine(localExpr: expr);
1395 std::fill(first: result.begin(), last: result.end(), value: 0);
1396 if (loc == -1)
1397 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1398 else
1399 result[getLocalVarStartIndex() + loc] = 1;
1400}
1401
1402// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1403// A floordiv is thus flattened by introducing a new local variable q, and
1404// replacing that expression with 'q' while adding the constraints
1405// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1406// IntegerRelation::addLocalFloorDiv).
1407//
1408// A ceildiv is similarly flattened:
1409// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1410//
1411// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1412// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1413// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1414// `localExprs`.
1415LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1416 bool isCeil) {
1417 assert(operandExprStack.size() >= 2);
1418
1419 MLIRContext *context = expr.getContext();
1420 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1421 operandExprStack.pop_back();
1422 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1423
1424 // Flatten semi affine division expressions by introducing a local
1425 // variable in place of the quotient, and the affine expression corresponding
1426 // to the quantifier is added to `localExprs`.
1427 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1428 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1429 localExprs, context);
1430 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1431 localExprs, context);
1432 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1433 addLocalVariableSemiAffine(expr: divExpr, result&: lhs, resultSize: lhs.size());
1434 return success();
1435 }
1436
1437 // This is a pure affine expr; the RHS is a positive constant.
1438 int64_t rhsConst = rhs[getConstantIndex()];
1439 if (rhsConst <= 0)
1440 return failure();
1441
1442 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1443 // common divisors of the numerator and denominator.
1444 uint64_t gcd = std::abs(i: rhsConst);
1445 for (int64_t lhsElt : lhs)
1446 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1447 // Simplify the numerator and the denominator.
1448 if (gcd != 1) {
1449 for (int64_t &lhsElt : lhs)
1450 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1451 }
1452 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1453 // If the divisor becomes 1, the updated LHS is the result. (The
1454 // divisor can't be negative since rhsConst is positive).
1455 if (divisor == 1)
1456 return success();
1457
1458 // If the divisor cannot be simplified to one, we will have to retain
1459 // the ceil/floor expr (simplified up until here). Add an existential
1460 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1461 // by a new identifier, q.
1462 AffineExpr a =
1463 getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols, localExprs, context);
1464 AffineExpr b = getAffineConstantExpr(constant: divisor, context);
1465
1466 int loc;
1467 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1468 if ((loc = findLocalId(localExpr: divExpr)) == -1) {
1469 if (!isCeil) {
1470 SmallVector<int64_t, 8> dividend(lhs);
1471 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1472 } else {
1473 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1474 SmallVector<int64_t, 8> dividend(lhs);
1475 dividend.back() += divisor - 1;
1476 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1477 }
1478 }
1479 // Set the expression on stack to the local var introduced to capture the
1480 // result of the division (floor or ceil).
1481 std::fill(first: lhs.begin(), last: lhs.end(), value: 0);
1482 if (loc == -1)
1483 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1484 else
1485 lhs[getLocalVarStartIndex() + loc] = 1;
1486 return success();
1487}
1488
1489// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1490// The local identifier added is always a floordiv of a pure add/mul affine
1491// function of other identifiers, coefficients of which are specified in
1492// dividend and with respect to a positive constant divisor. localExpr is the
1493// simplified tree expression (AffineExpr) corresponding to the quantifier.
1494void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1495 int64_t divisor,
1496 AffineExpr localExpr) {
1497 assert(divisor > 0 && "positive constant divisor expected");
1498 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1499 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1500 localExprs.push_back(Elt: localExpr);
1501 numLocals++;
1502 // dividend and divisor are not used here; an override of this method uses it.
1503}
1504
1505void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
1506 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1507 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1508 localExprs.push_back(Elt: localExpr);
1509 ++numLocals;
1510}
1511
1512int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1513 SmallVectorImpl<AffineExpr>::iterator it;
1514 if ((it = llvm::find(Range&: localExprs, Val: localExpr)) == localExprs.end())
1515 return -1;
1516 return it - localExprs.begin();
1517}
1518
1519/// Simplify the affine expression by flattening it and reconstructing it.
1520AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1521 unsigned numSymbols) {
1522 // Simplify semi-affine expressions separately.
1523 if (!expr.isPureAffine())
1524 expr = simplifySemiAffine(expr, numDims, numSymbols);
1525
1526 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1527 // has poison expression
1528 if (failed(result: flattener.walkPostOrder(expr)))
1529 return expr;
1530 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1531 if (!expr.isPureAffine() &&
1532 expr == getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1533 localExprs: flattener.localExprs,
1534 context: expr.getContext()))
1535 return expr;
1536 AffineExpr simplifiedExpr =
1537 expr.isPureAffine()
1538 ? getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1539 localExprs: flattener.localExprs, context: expr.getContext())
1540 : getSemiAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1541 localExprs: flattener.localExprs,
1542 context: expr.getContext());
1543
1544 flattener.operandExprStack.pop_back();
1545 assert(flattener.operandExprStack.empty());
1546 return simplifiedExpr;
1547}
1548
1549std::optional<int64_t> mlir::getBoundForAffineExpr(
1550 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1551 ArrayRef<std::optional<int64_t>> constLowerBounds,
1552 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1553 // Handle divs and mods.
1554 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) {
1555 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1556 // can compute an upper bound.
1557 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1558 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1559 if (!rhsConst || rhsConst.getValue() < 1)
1560 return std::nullopt;
1561 auto bound =
1562 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1563 constLowerBounds, constUpperBounds, isUpper);
1564 if (!bound)
1565 return std::nullopt;
1566 return mlir::floorDiv(lhs: *bound, rhs: rhsConst.getValue());
1567 }
1568 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1569 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1570 if (rhsConst && rhsConst.getValue() >= 1) {
1571 auto bound =
1572 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1573 constLowerBounds, constUpperBounds, isUpper);
1574 if (!bound)
1575 return std::nullopt;
1576 return mlir::ceilDiv(lhs: *bound, rhs: rhsConst.getValue());
1577 }
1578 return std::nullopt;
1579 }
1580 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1581 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1582 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1583 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1584 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1585 if (rhsConst && rhsConst.getValue() >= 1) {
1586 int64_t rhsConstVal = rhsConst.getValue();
1587 auto lb = getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1588 constLowerBounds, constUpperBounds,
1589 /*isUpper=*/false);
1590 auto ub =
1591 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1592 constLowerBounds, constUpperBounds, isUpper);
1593 if (ub && lb &&
1594 floorDiv(lhs: *lb, rhs: rhsConstVal) == floorDiv(lhs: *ub, rhs: rhsConstVal))
1595 return isUpper ? mod(lhs: *ub, rhs: rhsConstVal) : mod(lhs: *lb, rhs: rhsConstVal);
1596 return isUpper ? rhsConstVal - 1 : 0;
1597 }
1598 }
1599 }
1600 // Flatten the expression.
1601 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1602 auto simpleResult = flattener.walkPostOrder(expr);
1603 // has poison expression
1604 if (failed(result: simpleResult))
1605 return std::nullopt;
1606 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1607 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1608 // get bound on the local expr recursively.
1609 if (flattener.numLocals > 0)
1610 return std::nullopt;
1611 int64_t bound = 0;
1612 // Substitute the constant lower or upper bound for the dimensional or
1613 // symbolic input depending on `isUpper` to determine the bound.
1614 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1615 if (flattenedExpr[i] > 0) {
1616 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1617 if (!constBound)
1618 return std::nullopt;
1619 bound += *constBound * flattenedExpr[i];
1620 } else if (flattenedExpr[i] < 0) {
1621 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1622 if (!constBound)
1623 return std::nullopt;
1624 bound += *constBound * flattenedExpr[i];
1625 }
1626 }
1627 // Constant term.
1628 bound += flattenedExpr.back();
1629 return bound;
1630}
1631

source code of mlir/lib/IR/AffineExpr.cpp