1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cmath>
10#include <cstdint>
11#include <limits>
12#include <utility>
13
14#include "AffineExprDetail.h"
15#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineExprVisitor.h"
17#include "mlir/IR/AffineMap.h"
18#include "mlir/IR/IntegerSet.h"
19#include "mlir/Support/TypeID.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/MathExtras.h"
22#include <numeric>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::detail;
27
28using llvm::divideCeilSigned;
29using llvm::divideFloorSigned;
30using llvm::divideSignedWouldOverflow;
31using llvm::mod;
32
33MLIRContext *AffineExpr::getContext() const { return expr->context; }
34
35AffineExprKind AffineExpr::getKind() const { return expr->kind; }
36
37/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
38/// method to help handle lambda walk functions. Users should use the regular
39/// (non-static) `walk` method.
40template <typename WalkRetTy>
41WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
42 function_ref<WalkRetTy(AffineExpr)> callback) {
43 struct AffineExprWalker
44 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
45 function_ref<WalkRetTy(AffineExpr)> callback;
46
47 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
48 : callback(callback) {}
49
50 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
51 return callback(expr);
52 }
53 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
54 return callback(expr);
55 }
56 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
57 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
58 };
59
60 return AffineExprWalker(callback).walkPostOrder(e);
61}
62// Explicitly instantiate for the two supported return types.
63template void mlir::AffineExpr::walk(AffineExpr e,
64 function_ref<void(AffineExpr)> callback);
65template WalkResult
66mlir::AffineExpr::walk(AffineExpr e,
67 function_ref<WalkResult(AffineExpr)> callback);
68
69// Dispatch affine expression construction based on kind.
70AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
71 AffineExpr rhs) {
72 if (kind == AffineExprKind::Add)
73 return lhs + rhs;
74 if (kind == AffineExprKind::Mul)
75 return lhs * rhs;
76 if (kind == AffineExprKind::FloorDiv)
77 return lhs.floorDiv(other: rhs);
78 if (kind == AffineExprKind::CeilDiv)
79 return lhs.ceilDiv(other: rhs);
80 if (kind == AffineExprKind::Mod)
81 return lhs % rhs;
82
83 llvm_unreachable("unknown binary operation on affine expressions");
84}
85
86/// This method substitutes any uses of dimensions and symbols (e.g.
87/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
88AffineExpr
89AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
90 ArrayRef<AffineExpr> symReplacements) const {
91 switch (getKind()) {
92 case AffineExprKind::Constant:
93 return *this;
94 case AffineExprKind::DimId: {
95 unsigned dimId = llvm::cast<AffineDimExpr>(Val: *this).getPosition();
96 if (dimId >= dimReplacements.size())
97 return *this;
98 return dimReplacements[dimId];
99 }
100 case AffineExprKind::SymbolId: {
101 unsigned symId = llvm::cast<AffineSymbolExpr>(Val: *this).getPosition();
102 if (symId >= symReplacements.size())
103 return *this;
104 return symReplacements[symId];
105 }
106 case AffineExprKind::Add:
107 case AffineExprKind::Mul:
108 case AffineExprKind::FloorDiv:
109 case AffineExprKind::CeilDiv:
110 case AffineExprKind::Mod:
111 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
112 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
113 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
114 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
115 if (newLHS == lhs && newRHS == rhs)
116 return *this;
117 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
118 }
119 llvm_unreachable("Unknown AffineExpr");
120}
121
122AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
123 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
124}
125
126AffineExpr
127AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
128 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements);
129}
130
131/// Replace dims[offset ... numDims)
132/// by dims[offset + shift ... shift + numDims).
133AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
134 unsigned offset) const {
135 SmallVector<AffineExpr, 4> dims;
136 for (unsigned idx = 0; idx < offset; ++idx)
137 dims.push_back(Elt: getAffineDimExpr(position: idx, context: getContext()));
138 for (unsigned idx = offset; idx < numDims; ++idx)
139 dims.push_back(Elt: getAffineDimExpr(position: idx + shift, context: getContext()));
140 return replaceDimsAndSymbols(dimReplacements: dims, symReplacements: {});
141}
142
143/// Replace symbols[offset ... numSymbols)
144/// by symbols[offset + shift ... shift + numSymbols).
145AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
146 unsigned offset) const {
147 SmallVector<AffineExpr, 4> symbols;
148 for (unsigned idx = 0; idx < offset; ++idx)
149 symbols.push_back(Elt: getAffineSymbolExpr(position: idx, context: getContext()));
150 for (unsigned idx = offset; idx < numSymbols; ++idx)
151 symbols.push_back(Elt: getAffineSymbolExpr(position: idx + shift, context: getContext()));
152 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements: symbols);
153}
154
155/// Sparse replace method. Return the modified expression tree.
156AffineExpr
157AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
158 auto it = map.find(Val: *this);
159 if (it != map.end())
160 return it->second;
161 switch (getKind()) {
162 default:
163 return *this;
164 case AffineExprKind::Add:
165 case AffineExprKind::Mul:
166 case AffineExprKind::FloorDiv:
167 case AffineExprKind::CeilDiv:
168 case AffineExprKind::Mod:
169 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
170 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
171 auto newLHS = lhs.replace(map);
172 auto newRHS = rhs.replace(map);
173 if (newLHS == lhs && newRHS == rhs)
174 return *this;
175 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
176 }
177 llvm_unreachable("Unknown AffineExpr");
178}
179
180/// Sparse replace method. Return the modified expression tree.
181AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
182 DenseMap<AffineExpr, AffineExpr> map;
183 map.insert(KV: std::make_pair(x&: expr, y&: replacement));
184 return replace(map);
185}
186/// Returns true if this expression is made out of only symbols and
187/// constants (no dimensional identifiers).
188bool AffineExpr::isSymbolicOrConstant() const {
189 switch (getKind()) {
190 case AffineExprKind::Constant:
191 return true;
192 case AffineExprKind::DimId:
193 return false;
194 case AffineExprKind::SymbolId:
195 return true;
196
197 case AffineExprKind::Add:
198 case AffineExprKind::Mul:
199 case AffineExprKind::FloorDiv:
200 case AffineExprKind::CeilDiv:
201 case AffineExprKind::Mod: {
202 auto expr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
203 return expr.getLHS().isSymbolicOrConstant() &&
204 expr.getRHS().isSymbolicOrConstant();
205 }
206 }
207 llvm_unreachable("Unknown AffineExpr");
208}
209
210/// Returns true if this is a pure affine expression, i.e., multiplication,
211/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
212bool AffineExpr::isPureAffine() const {
213 switch (getKind()) {
214 case AffineExprKind::SymbolId:
215 case AffineExprKind::DimId:
216 case AffineExprKind::Constant:
217 return true;
218 case AffineExprKind::Add: {
219 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
220 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
221 }
222
223 case AffineExprKind::Mul: {
224 // TODO: Canonicalize the constants in binary operators to the RHS when
225 // possible, allowing this to merge into the next case.
226 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
227 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
228 (llvm::isa<AffineConstantExpr>(Val: op.getLHS()) ||
229 llvm::isa<AffineConstantExpr>(Val: op.getRHS()));
230 }
231 case AffineExprKind::FloorDiv:
232 case AffineExprKind::CeilDiv:
233 case AffineExprKind::Mod: {
234 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
235 return op.getLHS().isPureAffine() &&
236 llvm::isa<AffineConstantExpr>(Val: op.getRHS());
237 }
238 }
239 llvm_unreachable("Unknown AffineExpr");
240}
241
242// Returns the greatest known integral divisor of this affine expression.
243int64_t AffineExpr::getLargestKnownDivisor() const {
244 AffineBinaryOpExpr binExpr(nullptr);
245 switch (getKind()) {
246 case AffineExprKind::DimId:
247 [[fallthrough]];
248 case AffineExprKind::SymbolId:
249 return 1;
250 case AffineExprKind::CeilDiv:
251 [[fallthrough]];
252 case AffineExprKind::FloorDiv: {
253 // If the RHS is a constant and divides the known divisor on the LHS, the
254 // quotient is a known divisor of the expression.
255 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
256 auto rhs = llvm::dyn_cast<AffineConstantExpr>(Val: binExpr.getRHS());
257 // Leave alone undefined expressions.
258 if (rhs && rhs.getValue() != 0) {
259 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
260 if (lhsDiv % rhs.getValue() == 0)
261 return std::abs(i: lhsDiv / rhs.getValue());
262 }
263 return 1;
264 }
265 case AffineExprKind::Constant:
266 return std::abs(i: llvm::cast<AffineConstantExpr>(Val: *this).getValue());
267 case AffineExprKind::Mul: {
268 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
269 return binExpr.getLHS().getLargestKnownDivisor() *
270 binExpr.getRHS().getLargestKnownDivisor();
271 }
272 case AffineExprKind::Add:
273 [[fallthrough]];
274 case AffineExprKind::Mod: {
275 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
276 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
277 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
278 }
279 }
280 llvm_unreachable("Unknown AffineExpr");
281}
282
283bool AffineExpr::isMultipleOf(int64_t factor) const {
284 AffineBinaryOpExpr binExpr(nullptr);
285 uint64_t l, u;
286 switch (getKind()) {
287 case AffineExprKind::SymbolId:
288 [[fallthrough]];
289 case AffineExprKind::DimId:
290 return factor * factor == 1;
291 case AffineExprKind::Constant:
292 return llvm::cast<AffineConstantExpr>(Val: *this).getValue() % factor == 0;
293 case AffineExprKind::Mul: {
294 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
295 // It's probably not worth optimizing this further (to not traverse the
296 // whole sub-tree under - it that would require a version of isMultipleOf
297 // that on a 'false' return also returns the largest known divisor).
298 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
299 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
300 (l * u) % factor == 0;
301 }
302 case AffineExprKind::Add:
303 case AffineExprKind::FloorDiv:
304 case AffineExprKind::CeilDiv:
305 case AffineExprKind::Mod: {
306 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
307 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
308 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
309 factor ==
310 0;
311 }
312 }
313 llvm_unreachable("Unknown AffineExpr");
314}
315
316bool AffineExpr::isFunctionOfDim(unsigned position) const {
317 if (getKind() == AffineExprKind::DimId) {
318 return *this == mlir::getAffineDimExpr(position, context: getContext());
319 }
320 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
321 return expr.getLHS().isFunctionOfDim(position) ||
322 expr.getRHS().isFunctionOfDim(position);
323 }
324 return false;
325}
326
327bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
328 if (getKind() == AffineExprKind::SymbolId) {
329 return *this == mlir::getAffineSymbolExpr(position, context: getContext());
330 }
331 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
332 return expr.getLHS().isFunctionOfSymbol(position) ||
333 expr.getRHS().isFunctionOfSymbol(position);
334 }
335 return false;
336}
337
338AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
339 : AffineExpr(ptr) {}
340AffineExpr AffineBinaryOpExpr::getLHS() const {
341 return static_cast<ImplType *>(expr)->lhs;
342}
343AffineExpr AffineBinaryOpExpr::getRHS() const {
344 return static_cast<ImplType *>(expr)->rhs;
345}
346
347AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
348unsigned AffineDimExpr::getPosition() const {
349 return static_cast<ImplType *>(expr)->position;
350}
351
352/// Returns true if the expression is divisible by the given symbol with
353/// position `symbolPos`. The argument `opKind` specifies here what kind of
354/// division or mod operation called this division. It helps in implementing the
355/// commutative property of the floordiv and ceildiv operations. If the argument
356///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357/// operation, then the commutative property can be used otherwise, the floordiv
358/// operation is not divisible. The same argument holds for ceildiv operation.
359static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
360 AffineExprKind opKind,
361 bool fromMul = false) {
362 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
363 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
364 opKind == AffineExprKind::CeilDiv) &&
365 "unexpected opKind");
366 switch (expr.getKind()) {
367 case AffineExprKind::Constant:
368 return cast<AffineConstantExpr>(Val&: expr).getValue() == 0;
369 case AffineExprKind::DimId:
370 return false;
371 case AffineExprKind::SymbolId:
372 return (cast<AffineSymbolExpr>(Val&: expr).getPosition() == symbolPos);
373 // Checks divisibility by the given symbol for both operands.
374 case AffineExprKind::Add: {
375 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
376 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
377 opKind) &&
378 canSimplifyDivisionBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind);
379 }
380 // Checks divisibility by the given symbol for both operands. Consider the
381 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
382 // this is a division by s1 and both the operands of modulo are divisible by
383 // s1 but it is not divisible by s1 always. The third argument is
384 // `AffineExprKind::Mod` for this reason.
385 case AffineExprKind::Mod: {
386 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
387 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
388 opKind: AffineExprKind::Mod) &&
389 canSimplifyDivisionBySymbol(expr: binaryExpr.getRHS(), symbolPos,
390 opKind: AffineExprKind::Mod);
391 }
392 // Checks if any of the operand divisible by the given symbol.
393 case AffineExprKind::Mul: {
394 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
395 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind,
396 fromMul: true) ||
397 canSimplifyDivisionBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind,
398 fromMul: true);
399 }
400 // Floordiv and ceildiv are divisible by the given symbol when the first
401 // operand is divisible, and the affine expression kind of the argument expr
402 // is same as the argument `opKind`. This can be inferred from commutative
403 // property of floordiv and ceildiv operations and are as follow:
404 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
405 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
406 // It will fail 1.if operations are not same. For example:
407 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
408 // multiplication operation in the expression. For example:
409 // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
410 case AffineExprKind::FloorDiv:
411 case AffineExprKind::CeilDiv: {
412 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
413 if (opKind != expr.getKind())
414 return false;
415 if (fromMul)
416 return false;
417 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
418 opKind: expr.getKind());
419 }
420 }
421 llvm_unreachable("Unknown AffineExpr");
422}
423
424/// Divides the given expression by the given symbol at position `symbolPos`. It
425/// considers the divisibility condition is checked before calling itself. A
426/// null expression is returned whenever the divisibility condition fails.
427static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
428 AffineExprKind opKind) {
429 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
430 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
431 opKind == AffineExprKind::CeilDiv) &&
432 "unexpected opKind");
433 switch (expr.getKind()) {
434 case AffineExprKind::Constant:
435 if (cast<AffineConstantExpr>(Val&: expr).getValue() != 0)
436 return nullptr;
437 return getAffineConstantExpr(constant: 0, context: expr.getContext());
438 case AffineExprKind::DimId:
439 return nullptr;
440 case AffineExprKind::SymbolId:
441 return getAffineConstantExpr(constant: 1, context: expr.getContext());
442 // Dividing both operands by the given symbol.
443 case AffineExprKind::Add: {
444 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
445 return getAffineBinaryOpExpr(
446 kind: expr.getKind(), lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind),
447 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind));
448 }
449 // Dividing both operands by the given symbol.
450 case AffineExprKind::Mod: {
451 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
452 return getAffineBinaryOpExpr(
453 kind: expr.getKind(),
454 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
455 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind: expr.getKind()));
456 }
457 // Dividing any of the operand by the given symbol.
458 case AffineExprKind::Mul: {
459 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
460 if (!canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind))
461 return binaryExpr.getLHS() *
462 symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind);
463 return symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind) *
464 binaryExpr.getRHS();
465 }
466 // Dividing first operand only by the given symbol.
467 case AffineExprKind::FloorDiv:
468 case AffineExprKind::CeilDiv: {
469 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
470 return getAffineBinaryOpExpr(
471 kind: expr.getKind(),
472 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
473 rhs: binaryExpr.getRHS());
474 }
475 }
476 llvm_unreachable("Unknown AffineExpr");
477}
478
479/// Populate `result` with all summand operands of given (potentially nested)
480/// addition. If the given expression is not an addition, just populate the
481/// expression itself.
482/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
483static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
484 auto addExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
485 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
486 result.push_back(Elt: expr);
487 return;
488 }
489 getSummandExprs(expr: addExpr.getLHS(), result);
490 getSummandExprs(expr: addExpr.getRHS(), result);
491}
492
493/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
494/// If so, also return the non-negated expression via `expr`.
495static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
496 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: candidate);
497 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
498 return false;
499 if (auto lhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getLHS())) {
500 if (lhs.getValue() == -1) {
501 expr = mulExpr.getRHS();
502 return true;
503 }
504 }
505 if (auto rhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getRHS())) {
506 if (rhs.getValue() == -1) {
507 expr = mulExpr.getLHS();
508 return true;
509 }
510 }
511 return false;
512}
513
514/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
515/// the fact that `lhs` contains another modulo expression that ensures that
516/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
517/// after loop peeling.
518///
519/// Example: lhs = ub - ub % step
520/// rhs = step
521/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
522static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
523 unsigned numDims, unsigned numSymbols) {
524 // TODO: Try to unify this function with `getBoundForAffineExpr`.
525 // Collect all summands in lhs.
526 SmallVector<AffineExpr> summands;
527 getSummandExprs(expr: lhs, result&: summands);
528 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
529 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
530 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
531 AffineExpr current = summands[i];
532 AffineExpr beforeNegation;
533 if (!isNegatedAffineExpr(candidate: current, expr&: beforeNegation))
534 continue;
535 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(Val&: beforeNegation);
536 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
537 continue;
538 if (innerMod.getRHS() != rhs)
539 continue;
540 // Sum all remaining summands and subtract x. If that expression can be
541 // simplified to zero, then the remaining summands and x are equal.
542 AffineExpr diff = getAffineConstantExpr(constant: 0, context: lhs.getContext());
543 for (int64_t j = 0; j < e; ++j)
544 if (i != j)
545 diff = diff + summands[j];
546 diff = diff - innerMod.getLHS();
547 diff = simplifyAffineExpr(expr: diff, numDims, numSymbols);
548 auto constExpr = dyn_cast<AffineConstantExpr>(Val&: diff);
549 if (constExpr && constExpr.getValue() == 0)
550 return true;
551 }
552 return false;
553}
554
555/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
556/// operations when the second operand simplifies to a symbol and the first
557/// operand is divisible by that symbol. It can be applied to any semi-affine
558/// expression. Returned expression can either be a semi-affine or pure affine
559/// expression.
560static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
561 unsigned numSymbols) {
562 switch (expr.getKind()) {
563 case AffineExprKind::Constant:
564 case AffineExprKind::DimId:
565 case AffineExprKind::SymbolId:
566 return expr;
567 case AffineExprKind::Add:
568 case AffineExprKind::Mul: {
569 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
570 return getAffineBinaryOpExpr(
571 kind: expr.getKind(),
572 lhs: simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols),
573 rhs: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
574 }
575 // Check if the simplification of the second operand is a symbol, and the
576 // first operand is divisible by it. If the operation is a modulo, a constant
577 // zero expression is returned. In the case of floordiv and ceildiv, the
578 // symbol from the simplification of the second operand divides the first
579 // operand. Otherwise, simplification is not possible.
580 case AffineExprKind::FloorDiv:
581 case AffineExprKind::CeilDiv:
582 case AffineExprKind::Mod: {
583 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
584 AffineExpr sLHS =
585 simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols);
586 AffineExpr sRHS =
587 simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols);
588 if (isModOfModSubtraction(lhs: sLHS, rhs: sRHS, numDims, numSymbols))
589 return getAffineConstantExpr(constant: 0, context: expr.getContext());
590 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
591 Val: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
592 if (!symbolExpr)
593 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
594 unsigned symbolPos = symbolExpr.getPosition();
595 if (!canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
596 opKind: expr.getKind()))
597 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
598 if (expr.getKind() == AffineExprKind::Mod)
599 return getAffineConstantExpr(constant: 0, context: expr.getContext());
600 AffineExpr simplifiedQuotient =
601 symbolicDivide(expr: sLHS, symbolPos, opKind: expr.getKind());
602 return simplifiedQuotient
603 ? simplifiedQuotient
604 : getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
605 }
606 }
607 llvm_unreachable("Unknown AffineExpr");
608}
609
610static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
611 MLIRContext *context) {
612 auto assignCtx = [context](AffineDimExprStorage *storage) {
613 storage->context = context;
614 };
615
616 StorageUniquer &uniquer = context->getAffineUniquer();
617 return uniquer.get<AffineDimExprStorage>(
618 initFn: assignCtx, args: static_cast<unsigned>(kind), args&: position);
619}
620
621AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
622 return getAffineDimOrSymbol(kind: AffineExprKind::DimId, position, context);
623}
624
625AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
626 : AffineExpr(ptr) {}
627unsigned AffineSymbolExpr::getPosition() const {
628 return static_cast<ImplType *>(expr)->position;
629}
630
631AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
632 return getAffineDimOrSymbol(kind: AffineExprKind::SymbolId, position, context);
633}
634
635AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
636 : AffineExpr(ptr) {}
637int64_t AffineConstantExpr::getValue() const {
638 return static_cast<ImplType *>(expr)->constant;
639}
640
641bool AffineExpr::operator==(int64_t v) const {
642 return *this == getAffineConstantExpr(constant: v, context: getContext());
643}
644
645AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
646 auto assignCtx = [context](AffineConstantExprStorage *storage) {
647 storage->context = context;
648 };
649
650 StorageUniquer &uniquer = context->getAffineUniquer();
651 return uniquer.get<AffineConstantExprStorage>(initFn: assignCtx, args&: constant);
652}
653
654SmallVector<AffineExpr>
655mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
656 MLIRContext *context) {
657 return llvm::to_vector(Range: llvm::map_range(C&: constants, F: [&](int64_t constant) {
658 return getAffineConstantExpr(constant, context);
659 }));
660}
661
662/// Simplify add expression. Return nullptr if it can't be simplified.
663static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
664 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
665 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
666 // Fold if both LHS, RHS are a constant and the sum does not overflow.
667 if (lhsConst && rhsConst) {
668 int64_t sum;
669 if (llvm::AddOverflow(X: lhsConst.getValue(), Y: rhsConst.getValue(), Result&: sum)) {
670 return nullptr;
671 }
672 return getAffineConstantExpr(constant: sum, context: lhs.getContext());
673 }
674
675 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
676 // If only one of them is a symbolic expressions, make it the RHS.
677 if (isa<AffineConstantExpr>(Val: lhs) ||
678 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
679 return rhs + lhs;
680 }
681
682 // At this point, if there was a constant, it would be on the right.
683
684 // Addition with a zero is a noop, return the other input.
685 if (rhsConst) {
686 if (rhsConst.getValue() == 0)
687 return lhs;
688 }
689 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
690 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
691 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
692 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
693 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
694 }
695
696 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
697 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
698 // respective multiplicands.
699 std::optional<int64_t> rLhsConst, rRhsConst;
700 AffineExpr firstExpr, secondExpr;
701 AffineConstantExpr rLhsConstExpr;
702 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
703 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
704 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(Val: lBinOpExpr.getRHS()))) {
705 rLhsConst = rLhsConstExpr.getValue();
706 firstExpr = lBinOpExpr.getLHS();
707 } else {
708 rLhsConst = 1;
709 firstExpr = lhs;
710 }
711
712 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: rhs);
713 AffineConstantExpr rRhsConstExpr;
714 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
715 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(Val: rBinOpExpr.getRHS()))) {
716 rRhsConst = rRhsConstExpr.getValue();
717 secondExpr = rBinOpExpr.getLHS();
718 } else {
719 rRhsConst = 1;
720 secondExpr = rhs;
721 }
722
723 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
724 return getAffineBinaryOpExpr(
725 kind: AffineExprKind::Mul, lhs: firstExpr,
726 rhs: getAffineConstantExpr(constant: *rLhsConst + *rRhsConst, context: lhs.getContext()));
727
728 // When doing successive additions, bring constant to the right: turn (d0 + 2)
729 // + d1 into (d0 + d1) + 2.
730 if (lBin && lBin.getKind() == AffineExprKind::Add) {
731 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
732 return lBin.getLHS() + rhs + lrhs;
733 }
734 }
735
736 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
737 // q may be a constant or symbolic expression. This leads to a much more
738 // efficient form when 'c' is a power of two, and in general a more compact
739 // and readable form.
740
741 // Process '(expr floordiv c) * (-c)'.
742 if (!rBinOpExpr)
743 return nullptr;
744
745 auto lrhs = rBinOpExpr.getLHS();
746 auto rrhs = rBinOpExpr.getRHS();
747
748 AffineExpr llrhs, rlrhs;
749
750 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
751 // symbolic expression.
752 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
753 // Check rrhsConstOpExpr = -1.
754 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(Val&: rrhs);
755 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
756 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
757 // Check llrhs = expr floordiv q.
758 llrhs = lrhsBinOpExpr.getLHS();
759 // Check rlrhs = q.
760 rlrhs = lrhsBinOpExpr.getRHS();
761 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: llrhs);
762 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
763 return nullptr;
764 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
765 return lhs % rlrhs;
766 }
767
768 // Process lrhs, which is 'expr floordiv c'.
769 // expr + (expr // c * -c) = expr % c
770 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
771 if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
772 lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
773 return nullptr;
774
775 llrhs = lrBinOpExpr.getLHS();
776 rlrhs = lrBinOpExpr.getRHS();
777 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(Val&: rlrhs);
778 // We don't support modulo with a negative RHS.
779 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
780
781 if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
782 return lhs % rlrhs;
783 }
784 return nullptr;
785}
786
787AffineExpr AffineExpr::operator+(int64_t v) const {
788 return *this + getAffineConstantExpr(constant: v, context: getContext());
789}
790AffineExpr AffineExpr::operator+(AffineExpr other) const {
791 if (auto simplified = simplifyAdd(lhs: *this, rhs: other))
792 return simplified;
793
794 StorageUniquer &uniquer = getContext()->getAffineUniquer();
795 return uniquer.get<AffineBinaryOpExprStorage>(
796 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Add), args: *this, args&: other);
797}
798
799/// Simplify a multiply expression. Return nullptr if it can't be simplified.
800static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
801 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
802 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
803
804 if (lhsConst && rhsConst) {
805 int64_t product;
806 if (llvm::MulOverflow(X: lhsConst.getValue(), Y: rhsConst.getValue(), Result&: product)) {
807 return nullptr;
808 }
809 return getAffineConstantExpr(constant: product, context: lhs.getContext());
810 }
811
812 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
813 return nullptr;
814
815 // Canonicalize the mul expression so that the constant/symbolic term is the
816 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
817 // constant. (Note that a constant is trivially symbolic).
818 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(Val: lhs)) {
819 // At least one of them has to be symbolic.
820 return rhs * lhs;
821 }
822
823 // At this point, if there was a constant, it would be on the right.
824
825 // Multiplication with a one is a noop, return the other input.
826 if (rhsConst) {
827 if (rhsConst.getValue() == 1)
828 return lhs;
829 // Multiplication with zero.
830 if (rhsConst.getValue() == 0)
831 return rhsConst;
832 }
833
834 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
835 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
836 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
837 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
838 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
839 }
840
841 // When doing successive multiplication, bring constant to the right: turn (d0
842 // * 2) * d1 into (d0 * d1) * 2.
843 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
844 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
845 return (lBin.getLHS() * rhs) * lrhs;
846 }
847 }
848
849 return nullptr;
850}
851
852AffineExpr AffineExpr::operator*(int64_t v) const {
853 return *this * getAffineConstantExpr(constant: v, context: getContext());
854}
855AffineExpr AffineExpr::operator*(AffineExpr other) const {
856 if (auto simplified = simplifyMul(lhs: *this, rhs: other))
857 return simplified;
858
859 StorageUniquer &uniquer = getContext()->getAffineUniquer();
860 return uniquer.get<AffineBinaryOpExprStorage>(
861 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mul), args: *this, args&: other);
862}
863
864// Unary minus, delegate to operator*.
865AffineExpr AffineExpr::operator-() const {
866 return *this * getAffineConstantExpr(constant: -1, context: getContext());
867}
868
869// Delegate to operator+.
870AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
871AffineExpr AffineExpr::operator-(AffineExpr other) const {
872 return *this + (-other);
873}
874
875static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
876 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
877 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
878
879 if (!rhsConst || rhsConst.getValue() == 0)
880 return nullptr;
881
882 if (lhsConst) {
883 if (divideSignedWouldOverflow(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()))
884 return nullptr;
885 return getAffineConstantExpr(
886 constant: divideFloorSigned(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()),
887 context: lhs.getContext());
888 }
889
890 // Fold floordiv of a multiply with a constant that is a multiple of the
891 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
892 if (rhsConst == 1)
893 return lhs;
894
895 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
896 // multiple of `rhsConst`.
897 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
898 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
899 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
900 // `rhsConst` is known to be a nonzero constant.
901 if (lrhs.getValue() % rhsConst.getValue() == 0)
902 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
903 }
904 }
905
906 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
907 // known to be a multiple of divConst.
908 if (lBin && lBin.getKind() == AffineExprKind::Add) {
909 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
910 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
911 // rhsConst is known to be a nonzero constant.
912 if (llhsDiv % rhsConst.getValue() == 0 ||
913 lrhsDiv % rhsConst.getValue() == 0)
914 return lBin.getLHS().floorDiv(v: rhsConst.getValue()) +
915 lBin.getRHS().floorDiv(v: rhsConst.getValue());
916 }
917
918 return nullptr;
919}
920
921AffineExpr AffineExpr::floorDiv(uint64_t v) const {
922 return floorDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
923}
924AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
925 if (auto simplified = simplifyFloorDiv(lhs: *this, rhs: other))
926 return simplified;
927
928 StorageUniquer &uniquer = getContext()->getAffineUniquer();
929 return uniquer.get<AffineBinaryOpExprStorage>(
930 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::FloorDiv), args: *this,
931 args&: other);
932}
933
934static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
935 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
936 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
937
938 if (!rhsConst || rhsConst.getValue() == 0)
939 return nullptr;
940
941 if (lhsConst) {
942 if (divideSignedWouldOverflow(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()))
943 return nullptr;
944 return getAffineConstantExpr(
945 constant: divideCeilSigned(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()),
946 context: lhs.getContext());
947 }
948
949 // Fold ceildiv of a multiply with a constant that is a multiple of the
950 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
951 if (rhsConst.getValue() == 1)
952 return lhs;
953
954 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
955 // multiple of `rhsConst`.
956 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
957 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
958 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
959 // `rhsConst` is known to be a nonzero constant.
960 if (lrhs.getValue() % rhsConst.getValue() == 0)
961 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
962 }
963 }
964
965 return nullptr;
966}
967
968AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
969 return ceilDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
970}
971AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
972 if (auto simplified = simplifyCeilDiv(lhs: *this, rhs: other))
973 return simplified;
974
975 StorageUniquer &uniquer = getContext()->getAffineUniquer();
976 return uniquer.get<AffineBinaryOpExprStorage>(
977 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::CeilDiv), args: *this,
978 args&: other);
979}
980
981static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
982 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
983 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
984
985 // mod w.r.t zero or negative numbers is undefined and preserved as is.
986 if (!rhsConst || rhsConst.getValue() < 1)
987 return nullptr;
988
989 if (lhsConst) {
990 // mod never overflows.
991 return getAffineConstantExpr(constant: mod(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()),
992 context: lhs.getContext());
993 }
994
995 // Fold modulo of an expression that is known to be a multiple of a constant
996 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
997 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
998 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
999 return getAffineConstantExpr(constant: 0, context: lhs.getContext());
1000
1001 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
1002 // known to be a multiple of divConst.
1003 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
1004 if (lBin && lBin.getKind() == AffineExprKind::Add) {
1005 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1006 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1007 // rhsConst is known to be a positive constant.
1008 if (llhsDiv % rhsConst.getValue() == 0)
1009 return lBin.getRHS() % rhsConst.getValue();
1010 if (lrhsDiv % rhsConst.getValue() == 0)
1011 return lBin.getLHS() % rhsConst.getValue();
1012 }
1013
1014 // Simplify (e % a) % b to e % b when b evenly divides a
1015 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1016 auto intermediate = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS());
1017 if (intermediate && intermediate.getValue() >= 1 &&
1018 mod(Numerator: intermediate.getValue(), Denominator: rhsConst.getValue()) == 0) {
1019 return lBin.getLHS() % rhsConst.getValue();
1020 }
1021 }
1022
1023 return nullptr;
1024}
1025
1026AffineExpr AffineExpr::operator%(uint64_t v) const {
1027 return *this % getAffineConstantExpr(constant: v, context: getContext());
1028}
1029AffineExpr AffineExpr::operator%(AffineExpr other) const {
1030 if (auto simplified = simplifyMod(lhs: *this, rhs: other))
1031 return simplified;
1032
1033 StorageUniquer &uniquer = getContext()->getAffineUniquer();
1034 return uniquer.get<AffineBinaryOpExprStorage>(
1035 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mod), args: *this, args&: other);
1036}
1037
1038AffineExpr AffineExpr::compose(AffineMap map) const {
1039 SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1040 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
1041}
1042raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
1043 expr.print(os);
1044 return os;
1045}
1046
1047/// Constructs an affine expression from a flat ArrayRef. If there are local
1048/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1049/// products expression, `localExprs` is expected to have the AffineExpr
1050/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1051/// in the format [dims, symbols, locals, constant term].
1052AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1053 unsigned numDims,
1054 unsigned numSymbols,
1055 ArrayRef<AffineExpr> localExprs,
1056 MLIRContext *context) {
1057 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1058 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1059 "unexpected number of local expressions");
1060
1061 auto expr = getAffineConstantExpr(constant: 0, context);
1062 // Dimensions and symbols.
1063 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1064 if (flatExprs[j] == 0)
1065 continue;
1066 auto id = j < numDims ? getAffineDimExpr(position: j, context)
1067 : getAffineSymbolExpr(position: j - numDims, context);
1068 expr = expr + id * flatExprs[j];
1069 }
1070
1071 // Local identifiers.
1072 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1073 j++) {
1074 if (flatExprs[j] == 0)
1075 continue;
1076 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1077 expr = expr + term;
1078 }
1079
1080 // Constant term.
1081 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1082 if (constTerm != 0)
1083 expr = expr + constTerm;
1084 return expr;
1085}
1086
1087/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1088/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1089/// of products expression, `localExprs` is expected to have the AffineExprs for
1090/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1091/// the format [dims, symbols, locals, constant term]. The semi-affine
1092/// expression is constructed in the sorted order of dimension and symbol
1093/// position numbers. Note: local expressions/ids are used for mod, div as well
1094/// as symbolic RHS terms for terms that are not pure affine.
1095static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1096 unsigned numDims,
1097 unsigned numSymbols,
1098 ArrayRef<AffineExpr> localExprs,
1099 MLIRContext *context) {
1100 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1101
1102 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1103 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1104 "unexpected number of local expressions");
1105
1106 AffineExpr expr = getAffineConstantExpr(constant: 0, context);
1107
1108 // We design indices as a pair which help us present the semi-affine map as
1109 // sum of product where terms are sorted based on dimension or symbol
1110 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1111 // where keyA is the position number of the dimension and keyB is the
1112 // position number of the symbol. For dimensional expressions we set the index
1113 // as (position number of the dimension, -1), as we want dimensional
1114 // expressions to appear before symbolic and product of dimensional and
1115 // symbolic expressions having the dimension with the same position number.
1116 // For symbolic expression set the index as (position number of the symbol,
1117 // maximum of last dimension and symbol position) number. For example, we want
1118 // the expression we are constructing to look something like: d0 + d0 * s0 +
1119 // s0 + d1*s1 + s1.
1120
1121 // Stores the affine expression corresponding to a given index.
1122 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
1123 // Stores the constant coefficient value corresponding to a given
1124 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1125 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
1126 // Stores the indices as defined above, and later sorted to produce
1127 // the semi-affine expression in the desired form.
1128 SmallVector<std::pair<unsigned, signed>, 8> indices;
1129
1130 // Example: expression = d0 + d0 * s0 + 2 * s0.
1131 // indices = [{0,-1}, {0, 0}, {0, 1}]
1132 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1133 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1134
1135 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1136 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1137 AffineExpr expr) {
1138 assert(!llvm::is_contained(indices, index) &&
1139 "Key is already present in indices vector and overwriting will "
1140 "happen in `indexToExprMap` and `coefficients`!");
1141
1142 indices.push_back(Elt: index);
1143 coefficients.insert(KV: {index, coefficient});
1144 indexToExprMap.insert(KV: {index, expr});
1145 };
1146
1147 // Design indices for dimensional or symbolic terms, and store the indices,
1148 // constant coefficient corresponding to the indices in `coefficients` map,
1149 // and affine expression corresponding to indices in `indexToExprMap` map.
1150
1151 // Ensure we do not have duplicate keys in `indexToExpr` map.
1152 unsigned offsetSym = 0;
1153 signed offsetDim = -1;
1154 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1155 if (flatExprs[j] == 0)
1156 continue;
1157 // For symbolic expression set the index as <position number
1158 // of the symbol, max(dimCount, symCount)> number,
1159 // as we want symbolic expressions with the same positional number to
1160 // appear after dimensional expressions having the same positional number.
1161 std::pair<unsigned, signed> indexEntry(
1162 j - numDims, std::max(a: numDims, b: numSymbols) + offsetSym++);
1163 addEntry(indexEntry, flatExprs[j],
1164 getAffineSymbolExpr(position: j - numDims, context));
1165 }
1166
1167 // Denotes semi-affine product, modulo or division terms, which has been added
1168 // to the `indexToExpr` map.
1169 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1170 false);
1171 unsigned lhsPos, rhsPos;
1172 // Construct indices for product terms involving dimension, symbol or constant
1173 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1174 // the indices in `coefficients` map, and affine expression corresponding to
1175 // in indices in `indexToExprMap` map.
1176 for (const auto &it : llvm::enumerate(First&: localExprs)) {
1177 AffineExpr expr = it.value();
1178 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1179 continue;
1180 AffineExpr lhs = cast<AffineBinaryOpExpr>(Val&: expr).getLHS();
1181 AffineExpr rhs = cast<AffineBinaryOpExpr>(Val&: expr).getRHS();
1182 if (!((isa<AffineDimExpr>(Val: lhs) || isa<AffineSymbolExpr>(Val: lhs)) &&
1183 (isa<AffineDimExpr>(Val: rhs) || isa<AffineSymbolExpr>(Val: rhs) ||
1184 isa<AffineConstantExpr>(Val: rhs)))) {
1185 continue;
1186 }
1187 if (isa<AffineConstantExpr>(Val: rhs)) {
1188 // For product/modulo/division expressions, when rhs of modulo/division
1189 // expression is constant, we put 0 in place of keyB, because we want
1190 // them to appear earlier in the semi-affine expression we are
1191 // constructing. When rhs is constant, we place 0 in place of keyB.
1192 if (isa<AffineDimExpr>(Val: lhs)) {
1193 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1194 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1195 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1196 expr);
1197 } else {
1198 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1199 std::pair<unsigned, signed> indexEntry(
1200 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1201 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1202 expr);
1203 }
1204 } else if (isa<AffineDimExpr>(Val: lhs)) {
1205 // For product/modulo/division expressions having lhs as dimension and rhs
1206 // as symbol, we order the terms in the semi-affine expression based on
1207 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1208 // where keyA is the position number of the dimension and keyB is the
1209 // position number of the symbol.
1210 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1211 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1212 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1213 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1214 } else {
1215 // For product/modulo/division expressions having both lhs and rhs as
1216 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1217 // of the form dimension * symbol, where keyA is the position number of
1218 // the dimension and keyB is the position number of the symbol.
1219 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1220 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1221 std::pair<unsigned, signed> indexEntry(
1222 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1223 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1224 }
1225 addedToMap[it.index()] = true;
1226 }
1227
1228 for (unsigned j = 0; j < numDims; ++j) {
1229 if (flatExprs[j] == 0)
1230 continue;
1231 // For dimensional expressions we set the index as <position number of the
1232 // dimension, 0>, as we want dimensional expressions to appear before
1233 // symbolic ones and products of dimensional and symbolic expressions
1234 // having the dimension with the same position number.
1235 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1236 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(position: j, context));
1237 }
1238
1239 // Constructing the simplified semi-affine sum of product/division/mod
1240 // expression from the flattened form in the desired sorted order of indices
1241 // of the various individual product/division/mod expressions.
1242 llvm::sort(C&: indices);
1243 for (const std::pair<unsigned, unsigned> index : indices) {
1244 assert(indexToExprMap.lookup(index) &&
1245 "cannot find key in `indexToExprMap` map");
1246 expr = expr + indexToExprMap.lookup(Val: index) * coefficients.lookup(Val: index);
1247 }
1248
1249 // Local identifiers.
1250 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1251 j++) {
1252 // If the coefficient of the local expression is 0, continue as we need not
1253 // add it in out final expression.
1254 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1255 continue;
1256 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1257 expr = expr + term;
1258 }
1259
1260 // Constant term.
1261 int64_t constTerm = flatExprs.back();
1262 if (constTerm != 0)
1263 expr = expr + constTerm;
1264 return expr;
1265}
1266
1267SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1268 unsigned numSymbols)
1269 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1270 operandExprStack.reserve(n: 8);
1271}
1272
1273// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1274//
1275// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1276// introduce a local variable p (= expr * symbolic_expr), and the affine
1277// expression expr * symbolic_expr is added to `localExprs`.
1278LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1279 assert(operandExprStack.size() >= 2);
1280 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1281 operandExprStack.pop_back();
1282 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1283
1284 // Flatten semi-affine multiplication expressions by introducing a local
1285 // variable in place of the product; the affine expression
1286 // corresponding to the quantifier is added to `localExprs`.
1287 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1288 SmallVector<int64_t, 8> mulLhs(lhs);
1289 MLIRContext *context = expr.getContext();
1290 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1291 localExprs, context);
1292 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1293 localExprs, context);
1294 return addLocalVariableSemiAffine(lhs: mulLhs, rhs, localExpr: a * b, result&: lhs, resultSize: lhs.size());
1295 }
1296
1297 // Get the RHS constant.
1298 int64_t rhsConst = rhs[getConstantIndex()];
1299 for (int64_t &lhsElt : lhs)
1300 lhsElt *= rhsConst;
1301
1302 return success();
1303}
1304
1305LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1306 assert(operandExprStack.size() >= 2);
1307 const auto &rhs = operandExprStack.back();
1308 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1309 assert(lhs.size() == rhs.size());
1310 // Update the LHS in place.
1311 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1312 lhs[i] += rhs[i];
1313 }
1314 // Pop off the RHS.
1315 operandExprStack.pop_back();
1316 return success();
1317}
1318
1319//
1320// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1321//
1322// A mod expression "expr mod c" is thus flattened by introducing a new local
1323// variable q (= expr floordiv c), such that expr mod c is replaced with
1324// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1325//
1326// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1327// introduce a local variable m (= expr mod symbolic_expr), and the affine
1328// expression expr mod symbolic_expr is added to `localExprs`.
1329LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1330 assert(operandExprStack.size() >= 2);
1331
1332 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1333 operandExprStack.pop_back();
1334 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1335 MLIRContext *context = expr.getContext();
1336
1337 // Flatten semi affine modulo expressions by introducing a local
1338 // variable in place of the modulo value, and the affine expression
1339 // corresponding to the quantifier is added to `localExprs`.
1340 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1341 SmallVector<int64_t, 8> modLhs(lhs);
1342 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1343 flatExprs: lhs, numDims, numSymbols, localExprs, context);
1344 AffineExpr divisorExpr = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1345 localExprs, context);
1346 AffineExpr modExpr = dividendExpr % divisorExpr;
1347 return addLocalVariableSemiAffine(lhs: modLhs, rhs, localExpr: modExpr, result&: lhs, resultSize: lhs.size());
1348 }
1349
1350 int64_t rhsConst = rhs[getConstantIndex()];
1351 if (rhsConst <= 0)
1352 return failure();
1353
1354 // Check if the LHS expression is a multiple of modulo factor.
1355 unsigned i, e;
1356 for (i = 0, e = lhs.size(); i < e; i++)
1357 if (lhs[i] % rhsConst != 0)
1358 break;
1359 // If yes, modulo expression here simplifies to zero.
1360 if (i == lhs.size()) {
1361 std::fill(first: lhs.begin(), last: lhs.end(), value: 0);
1362 return success();
1363 }
1364
1365 // Add a local variable for the quotient, i.e., expr % c is replaced by
1366 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1367 // the GCD of expr and c.
1368 SmallVector<int64_t, 8> floorDividend(lhs);
1369 uint64_t gcd = rhsConst;
1370 for (int64_t lhsElt : lhs)
1371 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1372 // Simplify the numerator and the denominator.
1373 if (gcd != 1) {
1374 for (int64_t &floorDividendElt : floorDividend)
1375 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1376 }
1377 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1378
1379 // Construct the AffineExpr form of the floordiv to store in localExprs.
1380
1381 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1382 flatExprs: floorDividend, numDims, numSymbols, localExprs, context);
1383 AffineExpr divisorExpr = getAffineConstantExpr(constant: floorDivisor, context);
1384 AffineExpr floorDivExpr = dividendExpr.floorDiv(other: divisorExpr);
1385 int loc;
1386 if ((loc = findLocalId(localExpr: floorDivExpr)) == -1) {
1387 addLocalFloorDivId(dividend: floorDividend, divisor: floorDivisor, localExpr: floorDivExpr);
1388 // Set result at top of stack to "lhs - rhsConst * q".
1389 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1390 } else {
1391 // Reuse the existing local id.
1392 lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1393 }
1394 return success();
1395}
1396
1397LogicalResult
1398SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1399 return visitDivExpr(expr, /*isCeil=*/true);
1400}
1401LogicalResult
1402SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1403 return visitDivExpr(expr, /*isCeil=*/false);
1404}
1405
1406LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1407 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1408 auto &eq = operandExprStack.back();
1409 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1410 eq[getDimStartIndex() + expr.getPosition()] = 1;
1411 return success();
1412}
1413
1414LogicalResult
1415SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1416 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1417 auto &eq = operandExprStack.back();
1418 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1419 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1420 return success();
1421}
1422
1423LogicalResult
1424SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1425 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1426 auto &eq = operandExprStack.back();
1427 eq[getConstantIndex()] = expr.getValue();
1428 return success();
1429}
1430
1431LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1432 ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr,
1433 SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1434 assert(result.size() == resultSize &&
1435 "`result` vector passed is not of correct size");
1436 int loc;
1437 if ((loc = findLocalId(localExpr)) == -1) {
1438 if (failed(Result: addLocalIdSemiAffine(lhs, rhs, localExpr)))
1439 return failure();
1440 }
1441 std::fill(first: result.begin(), last: result.end(), value: 0);
1442 if (loc == -1)
1443 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1444 else
1445 result[getLocalVarStartIndex() + loc] = 1;
1446 return success();
1447}
1448
1449// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1450// A floordiv is thus flattened by introducing a new local variable q, and
1451// replacing that expression with 'q' while adding the constraints
1452// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1453// IntegerRelation::addLocalFloorDiv).
1454//
1455// A ceildiv is similarly flattened:
1456// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1457//
1458// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1459// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1460// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1461// `localExprs`.
1462LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1463 bool isCeil) {
1464 assert(operandExprStack.size() >= 2);
1465
1466 MLIRContext *context = expr.getContext();
1467 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1468 operandExprStack.pop_back();
1469 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1470
1471 // Flatten semi affine division expressions by introducing a local
1472 // variable in place of the quotient, and the affine expression corresponding
1473 // to the quantifier is added to `localExprs`.
1474 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1475 SmallVector<int64_t, 8> divLhs(lhs);
1476 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1477 localExprs, context);
1478 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1479 localExprs, context);
1480 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1481 return addLocalVariableSemiAffine(lhs: divLhs, rhs, localExpr: divExpr, result&: lhs, resultSize: lhs.size());
1482 }
1483
1484 // This is a pure affine expr; the RHS is a positive constant.
1485 int64_t rhsConst = rhs[getConstantIndex()];
1486 if (rhsConst <= 0)
1487 return failure();
1488
1489 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1490 // common divisors of the numerator and denominator.
1491 uint64_t gcd = std::abs(i: rhsConst);
1492 for (int64_t lhsElt : lhs)
1493 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1494 // Simplify the numerator and the denominator.
1495 if (gcd != 1) {
1496 for (int64_t &lhsElt : lhs)
1497 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1498 }
1499 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1500 // If the divisor becomes 1, the updated LHS is the result. (The
1501 // divisor can't be negative since rhsConst is positive).
1502 if (divisor == 1)
1503 return success();
1504
1505 // If the divisor cannot be simplified to one, we will have to retain
1506 // the ceil/floor expr (simplified up until here). Add an existential
1507 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1508 // by a new identifier, q.
1509 AffineExpr a =
1510 getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols, localExprs, context);
1511 AffineExpr b = getAffineConstantExpr(constant: divisor, context);
1512
1513 int loc;
1514 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1515 if ((loc = findLocalId(localExpr: divExpr)) == -1) {
1516 if (!isCeil) {
1517 SmallVector<int64_t, 8> dividend(lhs);
1518 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1519 } else {
1520 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1521 SmallVector<int64_t, 8> dividend(lhs);
1522 dividend.back() += divisor - 1;
1523 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1524 }
1525 }
1526 // Set the expression on stack to the local var introduced to capture the
1527 // result of the division (floor or ceil).
1528 std::fill(first: lhs.begin(), last: lhs.end(), value: 0);
1529 if (loc == -1)
1530 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1531 else
1532 lhs[getLocalVarStartIndex() + loc] = 1;
1533 return success();
1534}
1535
1536// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1537// The local identifier added is always a floordiv of a pure add/mul affine
1538// function of other identifiers, coefficients of which are specified in
1539// dividend and with respect to a positive constant divisor. localExpr is the
1540// simplified tree expression (AffineExpr) corresponding to the quantifier.
1541void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1542 int64_t divisor,
1543 AffineExpr localExpr) {
1544 assert(divisor > 0 && "positive constant divisor expected");
1545 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1546 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1547 localExprs.push_back(Elt: localExpr);
1548 numLocals++;
1549 // dividend and divisor are not used here; an override of this method uses it.
1550}
1551
1552LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
1553 ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
1554 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1555 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1556 localExprs.push_back(Elt: localExpr);
1557 ++numLocals;
1558 // lhs and rhs are not used here; an override of this method uses them.
1559 return success();
1560}
1561
1562int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1563 SmallVectorImpl<AffineExpr>::iterator it;
1564 if ((it = llvm::find(Range&: localExprs, Val: localExpr)) == localExprs.end())
1565 return -1;
1566 return it - localExprs.begin();
1567}
1568
1569/// Simplify the affine expression by flattening it and reconstructing it.
1570AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1571 unsigned numSymbols) {
1572 // Simplify semi-affine expressions separately.
1573 if (!expr.isPureAffine())
1574 expr = simplifySemiAffine(expr, numDims, numSymbols);
1575
1576 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1577 // has poison expression
1578 if (failed(Result: flattener.walkPostOrder(expr)))
1579 return expr;
1580 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1581 if (!expr.isPureAffine() &&
1582 expr == getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1583 localExprs: flattener.localExprs,
1584 context: expr.getContext()))
1585 return expr;
1586 AffineExpr simplifiedExpr =
1587 expr.isPureAffine()
1588 ? getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1589 localExprs: flattener.localExprs, context: expr.getContext())
1590 : getSemiAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1591 localExprs: flattener.localExprs,
1592 context: expr.getContext());
1593
1594 flattener.operandExprStack.pop_back();
1595 assert(flattener.operandExprStack.empty());
1596 return simplifiedExpr;
1597}
1598
1599std::optional<int64_t> mlir::getBoundForAffineExpr(
1600 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1601 ArrayRef<std::optional<int64_t>> constLowerBounds,
1602 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1603 // Handle divs and mods.
1604 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) {
1605 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1606 // can compute an upper bound.
1607 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1608 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1609 if (!rhsConst || rhsConst.getValue() < 1)
1610 return std::nullopt;
1611 auto bound =
1612 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1613 constLowerBounds, constUpperBounds, isUpper);
1614 if (!bound)
1615 return std::nullopt;
1616 return divideFloorSigned(Numerator: *bound, Denominator: rhsConst.getValue());
1617 }
1618 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1619 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1620 if (rhsConst && rhsConst.getValue() >= 1) {
1621 auto bound =
1622 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1623 constLowerBounds, constUpperBounds, isUpper);
1624 if (!bound)
1625 return std::nullopt;
1626 return divideCeilSigned(Numerator: *bound, Denominator: rhsConst.getValue());
1627 }
1628 return std::nullopt;
1629 }
1630 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1631 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1632 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1633 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1634 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1635 if (rhsConst && rhsConst.getValue() >= 1) {
1636 int64_t rhsConstVal = rhsConst.getValue();
1637 auto lb = getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1638 constLowerBounds, constUpperBounds,
1639 /*isUpper=*/false);
1640 auto ub =
1641 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1642 constLowerBounds, constUpperBounds, isUpper);
1643 if (ub && lb &&
1644 divideFloorSigned(Numerator: *lb, Denominator: rhsConstVal) ==
1645 divideFloorSigned(Numerator: *ub, Denominator: rhsConstVal))
1646 return isUpper ? mod(Numerator: *ub, Denominator: rhsConstVal) : mod(Numerator: *lb, Denominator: rhsConstVal);
1647 return isUpper ? rhsConstVal - 1 : 0;
1648 }
1649 }
1650 }
1651 // Flatten the expression.
1652 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1653 auto simpleResult = flattener.walkPostOrder(expr);
1654 // has poison expression
1655 if (failed(Result: simpleResult))
1656 return std::nullopt;
1657 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1658 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1659 // get bound on the local expr recursively.
1660 if (flattener.numLocals > 0)
1661 return std::nullopt;
1662 int64_t bound = 0;
1663 // Substitute the constant lower or upper bound for the dimensional or
1664 // symbolic input depending on `isUpper` to determine the bound.
1665 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1666 if (flattenedExpr[i] > 0) {
1667 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1668 if (!constBound)
1669 return std::nullopt;
1670 bound += *constBound * flattenedExpr[i];
1671 } else if (flattenedExpr[i] < 0) {
1672 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1673 if (!constBound)
1674 return std::nullopt;
1675 bound += *constBound * flattenedExpr[i];
1676 }
1677 }
1678 // Constant term.
1679 bound += flattenedExpr.back();
1680 return bound;
1681}
1682

source code of mlir/lib/IR/AffineExpr.cpp