1//===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cmath>
10#include <cstdint>
11#include <limits>
12#include <utility>
13
14#include "AffineExprDetail.h"
15#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/AffineExprVisitor.h"
17#include "mlir/IR/AffineMap.h"
18#include "mlir/IR/IntegerSet.h"
19#include "mlir/Support/TypeID.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/MathExtras.h"
22#include <numeric>
23#include <optional>
24
25using namespace mlir;
26using namespace mlir::detail;
27
28using llvm::divideCeilSigned;
29using llvm::divideFloorSigned;
30using llvm::divideSignedWouldOverflow;
31using llvm::mod;
32
33MLIRContext *AffineExpr::getContext() const { return expr->context; }
34
35AffineExprKind AffineExpr::getKind() const { return expr->kind; }
36
37/// Walk all of the AffineExprs in `e` in postorder. This is a private factory
38/// method to help handle lambda walk functions. Users should use the regular
39/// (non-static) `walk` method.
40template <typename WalkRetTy>
41WalkRetTy mlir::AffineExpr::walk(AffineExpr e,
42 function_ref<WalkRetTy(AffineExpr)> callback) {
43 struct AffineExprWalker
44 : public AffineExprVisitor<AffineExprWalker, WalkRetTy> {
45 function_ref<WalkRetTy(AffineExpr)> callback;
46
47 AffineExprWalker(function_ref<WalkRetTy(AffineExpr)> callback)
48 : callback(callback) {}
49
50 WalkRetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
51 return callback(expr);
52 }
53 WalkRetTy visitConstantExpr(AffineConstantExpr expr) {
54 return callback(expr);
55 }
56 WalkRetTy visitDimExpr(AffineDimExpr expr) { return callback(expr); }
57 WalkRetTy visitSymbolExpr(AffineSymbolExpr expr) { return callback(expr); }
58 };
59
60 return AffineExprWalker(callback).walkPostOrder(e);
61}
62// Explicitly instantiate for the two supported return types.
63template void mlir::AffineExpr::walk(AffineExpr e,
64 function_ref<void(AffineExpr)> callback);
65template WalkResult
66mlir::AffineExpr::walk(AffineExpr e,
67 function_ref<WalkResult(AffineExpr)> callback);
68
69// Dispatch affine expression construction based on kind.
70AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
71 AffineExpr rhs) {
72 if (kind == AffineExprKind::Add)
73 return lhs + rhs;
74 if (kind == AffineExprKind::Mul)
75 return lhs * rhs;
76 if (kind == AffineExprKind::FloorDiv)
77 return lhs.floorDiv(other: rhs);
78 if (kind == AffineExprKind::CeilDiv)
79 return lhs.ceilDiv(other: rhs);
80 if (kind == AffineExprKind::Mod)
81 return lhs % rhs;
82
83 llvm_unreachable("unknown binary operation on affine expressions");
84}
85
86/// This method substitutes any uses of dimensions and symbols (e.g.
87/// dim#0 with dimReplacements[0]) and returns the modified expression tree.
88AffineExpr
89AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
90 ArrayRef<AffineExpr> symReplacements) const {
91 switch (getKind()) {
92 case AffineExprKind::Constant:
93 return *this;
94 case AffineExprKind::DimId: {
95 unsigned dimId = llvm::cast<AffineDimExpr>(Val: *this).getPosition();
96 if (dimId >= dimReplacements.size())
97 return *this;
98 return dimReplacements[dimId];
99 }
100 case AffineExprKind::SymbolId: {
101 unsigned symId = llvm::cast<AffineSymbolExpr>(Val: *this).getPosition();
102 if (symId >= symReplacements.size())
103 return *this;
104 return symReplacements[symId];
105 }
106 case AffineExprKind::Add:
107 case AffineExprKind::Mul:
108 case AffineExprKind::FloorDiv:
109 case AffineExprKind::CeilDiv:
110 case AffineExprKind::Mod:
111 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
112 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
113 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
114 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
115 if (newLHS == lhs && newRHS == rhs)
116 return *this;
117 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
118 }
119 llvm_unreachable("Unknown AffineExpr");
120}
121
122AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
123 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
124}
125
126AffineExpr
127AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
128 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements);
129}
130
131/// Replace dims[offset ... numDims)
132/// by dims[offset + shift ... shift + numDims).
133AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
134 unsigned offset) const {
135 SmallVector<AffineExpr, 4> dims;
136 for (unsigned idx = 0; idx < offset; ++idx)
137 dims.push_back(Elt: getAffineDimExpr(position: idx, context: getContext()));
138 for (unsigned idx = offset; idx < numDims; ++idx)
139 dims.push_back(Elt: getAffineDimExpr(position: idx + shift, context: getContext()));
140 return replaceDimsAndSymbols(dimReplacements: dims, symReplacements: {});
141}
142
143/// Replace symbols[offset ... numSymbols)
144/// by symbols[offset + shift ... shift + numSymbols).
145AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
146 unsigned offset) const {
147 SmallVector<AffineExpr, 4> symbols;
148 for (unsigned idx = 0; idx < offset; ++idx)
149 symbols.push_back(Elt: getAffineSymbolExpr(position: idx, context: getContext()));
150 for (unsigned idx = offset; idx < numSymbols; ++idx)
151 symbols.push_back(Elt: getAffineSymbolExpr(position: idx + shift, context: getContext()));
152 return replaceDimsAndSymbols(dimReplacements: {}, symReplacements: symbols);
153}
154
155/// Sparse replace method. Return the modified expression tree.
156AffineExpr
157AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
158 auto it = map.find(Val: *this);
159 if (it != map.end())
160 return it->second;
161 switch (getKind()) {
162 default:
163 return *this;
164 case AffineExprKind::Add:
165 case AffineExprKind::Mul:
166 case AffineExprKind::FloorDiv:
167 case AffineExprKind::CeilDiv:
168 case AffineExprKind::Mod:
169 auto binOp = llvm::cast<AffineBinaryOpExpr>(Val: *this);
170 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
171 auto newLHS = lhs.replace(map);
172 auto newRHS = rhs.replace(map);
173 if (newLHS == lhs && newRHS == rhs)
174 return *this;
175 return getAffineBinaryOpExpr(kind: getKind(), lhs: newLHS, rhs: newRHS);
176 }
177 llvm_unreachable("Unknown AffineExpr");
178}
179
180/// Sparse replace method. Return the modified expression tree.
181AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
182 DenseMap<AffineExpr, AffineExpr> map;
183 map.insert(KV: std::make_pair(x&: expr, y&: replacement));
184 return replace(map);
185}
186/// Returns true if this expression is made out of only symbols and
187/// constants (no dimensional identifiers).
188bool AffineExpr::isSymbolicOrConstant() const {
189 switch (getKind()) {
190 case AffineExprKind::Constant:
191 return true;
192 case AffineExprKind::DimId:
193 return false;
194 case AffineExprKind::SymbolId:
195 return true;
196
197 case AffineExprKind::Add:
198 case AffineExprKind::Mul:
199 case AffineExprKind::FloorDiv:
200 case AffineExprKind::CeilDiv:
201 case AffineExprKind::Mod: {
202 auto expr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
203 return expr.getLHS().isSymbolicOrConstant() &&
204 expr.getRHS().isSymbolicOrConstant();
205 }
206 }
207 llvm_unreachable("Unknown AffineExpr");
208}
209
210/// Returns true if this is a pure affine expression, i.e., multiplication,
211/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
212bool AffineExpr::isPureAffine() const {
213 switch (getKind()) {
214 case AffineExprKind::SymbolId:
215 case AffineExprKind::DimId:
216 case AffineExprKind::Constant:
217 return true;
218 case AffineExprKind::Add: {
219 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
220 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
221 }
222
223 case AffineExprKind::Mul: {
224 // TODO: Canonicalize the constants in binary operators to the RHS when
225 // possible, allowing this to merge into the next case.
226 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
227 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
228 (llvm::isa<AffineConstantExpr>(Val: op.getLHS()) ||
229 llvm::isa<AffineConstantExpr>(Val: op.getRHS()));
230 }
231 case AffineExprKind::FloorDiv:
232 case AffineExprKind::CeilDiv:
233 case AffineExprKind::Mod: {
234 auto op = llvm::cast<AffineBinaryOpExpr>(Val: *this);
235 return op.getLHS().isPureAffine() &&
236 llvm::isa<AffineConstantExpr>(Val: op.getRHS());
237 }
238 }
239 llvm_unreachable("Unknown AffineExpr");
240}
241
242// Returns the greatest known integral divisor of this affine expression.
243int64_t AffineExpr::getLargestKnownDivisor() const {
244 AffineBinaryOpExpr binExpr(nullptr);
245 switch (getKind()) {
246 case AffineExprKind::DimId:
247 [[fallthrough]];
248 case AffineExprKind::SymbolId:
249 return 1;
250 case AffineExprKind::CeilDiv:
251 [[fallthrough]];
252 case AffineExprKind::FloorDiv: {
253 // If the RHS is a constant and divides the known divisor on the LHS, the
254 // quotient is a known divisor of the expression.
255 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
256 auto rhs = llvm::dyn_cast<AffineConstantExpr>(Val: binExpr.getRHS());
257 // Leave alone undefined expressions.
258 if (rhs && rhs.getValue() != 0) {
259 int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
260 if (lhsDiv % rhs.getValue() == 0)
261 return std::abs(i: lhsDiv / rhs.getValue());
262 }
263 return 1;
264 }
265 case AffineExprKind::Constant:
266 return std::abs(i: llvm::cast<AffineConstantExpr>(Val: *this).getValue());
267 case AffineExprKind::Mul: {
268 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
269 return binExpr.getLHS().getLargestKnownDivisor() *
270 binExpr.getRHS().getLargestKnownDivisor();
271 }
272 case AffineExprKind::Add:
273 [[fallthrough]];
274 case AffineExprKind::Mod: {
275 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
276 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
277 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor());
278 }
279 }
280 llvm_unreachable("Unknown AffineExpr");
281}
282
283bool AffineExpr::isMultipleOf(int64_t factor) const {
284 AffineBinaryOpExpr binExpr(nullptr);
285 uint64_t l, u;
286 switch (getKind()) {
287 case AffineExprKind::SymbolId:
288 [[fallthrough]];
289 case AffineExprKind::DimId:
290 return factor * factor == 1;
291 case AffineExprKind::Constant:
292 return llvm::cast<AffineConstantExpr>(Val: *this).getValue() % factor == 0;
293 case AffineExprKind::Mul: {
294 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
295 // It's probably not worth optimizing this further (to not traverse the
296 // whole sub-tree under - it that would require a version of isMultipleOf
297 // that on a 'false' return also returns the largest known divisor).
298 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
299 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
300 (l * u) % factor == 0;
301 }
302 case AffineExprKind::Add:
303 case AffineExprKind::FloorDiv:
304 case AffineExprKind::CeilDiv:
305 case AffineExprKind::Mod: {
306 binExpr = llvm::cast<AffineBinaryOpExpr>(Val: *this);
307 return std::gcd(m: (uint64_t)binExpr.getLHS().getLargestKnownDivisor(),
308 n: (uint64_t)binExpr.getRHS().getLargestKnownDivisor()) %
309 factor ==
310 0;
311 }
312 }
313 llvm_unreachable("Unknown AffineExpr");
314}
315
316bool AffineExpr::isFunctionOfDim(unsigned position) const {
317 if (getKind() == AffineExprKind::DimId) {
318 return *this == mlir::getAffineDimExpr(position, context: getContext());
319 }
320 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
321 return expr.getLHS().isFunctionOfDim(position) ||
322 expr.getRHS().isFunctionOfDim(position);
323 }
324 return false;
325}
326
327bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
328 if (getKind() == AffineExprKind::SymbolId) {
329 return *this == mlir::getAffineSymbolExpr(position, context: getContext());
330 }
331 if (auto expr = llvm::dyn_cast<AffineBinaryOpExpr>(Val: *this)) {
332 return expr.getLHS().isFunctionOfSymbol(position) ||
333 expr.getRHS().isFunctionOfSymbol(position);
334 }
335 return false;
336}
337
338AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
339 : AffineExpr(ptr) {}
340AffineExpr AffineBinaryOpExpr::getLHS() const {
341 return static_cast<ImplType *>(expr)->lhs;
342}
343AffineExpr AffineBinaryOpExpr::getRHS() const {
344 return static_cast<ImplType *>(expr)->rhs;
345}
346
347AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
348unsigned AffineDimExpr::getPosition() const {
349 return static_cast<ImplType *>(expr)->position;
350}
351
352/// Returns true if the expression is divisible by the given symbol with
353/// position `symbolPos`. The argument `opKind` specifies here what kind of
354/// division or mod operation called this division. It helps in implementing the
355/// commutative property of the floordiv and ceildiv operations. If the argument
356///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
357/// operation, then the commutative property can be used otherwise, the floordiv
358/// operation is not divisible. The same argument holds for ceildiv operation.
359static bool canSimplifyDivisionBySymbol(AffineExpr expr, unsigned symbolPos,
360 AffineExprKind opKind,
361 bool fromMul = false) {
362 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
363 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
364 opKind == AffineExprKind::CeilDiv) &&
365 "unexpected opKind");
366 switch (expr.getKind()) {
367 case AffineExprKind::Constant:
368 return cast<AffineConstantExpr>(Val&: expr).getValue() == 0;
369 case AffineExprKind::DimId:
370 return false;
371 case AffineExprKind::SymbolId:
372 return (cast<AffineSymbolExpr>(Val&: expr).getPosition() == symbolPos);
373 // Checks divisibility by the given symbol for both operands.
374 case AffineExprKind::Add: {
375 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
376 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
377 opKind) &&
378 canSimplifyDivisionBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind);
379 }
380 // Checks divisibility by the given symbol for both operands. Consider the
381 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
382 // this is a division by s1 and both the operands of modulo are divisible by
383 // s1 but it is not divisible by s1 always. The third argument is
384 // `AffineExprKind::Mod` for this reason.
385 case AffineExprKind::Mod: {
386 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
387 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
388 opKind: AffineExprKind::Mod) &&
389 canSimplifyDivisionBySymbol(expr: binaryExpr.getRHS(), symbolPos,
390 opKind: AffineExprKind::Mod);
391 }
392 // Checks if any of the operand divisible by the given symbol.
393 case AffineExprKind::Mul: {
394 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
395 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind,
396 fromMul: true) ||
397 canSimplifyDivisionBySymbol(expr: binaryExpr.getRHS(), symbolPos, opKind,
398 fromMul: true);
399 }
400 // Floordiv and ceildiv are divisible by the given symbol when the first
401 // operand is divisible, and the affine expression kind of the argument expr
402 // is same as the argument `opKind`. This can be inferred from commutative
403 // property of floordiv and ceildiv operations and are as follow:
404 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
405 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
406 // It will fail 1.if operations are not same. For example:
407 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified. 2.if there is a
408 // multiplication operation in the expression. For example:
409 // (exps1 ceildiv exp2) mul exp3 ceildiv exp4 can not be simplified.
410 case AffineExprKind::FloorDiv:
411 case AffineExprKind::CeilDiv: {
412 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
413 if (opKind != expr.getKind())
414 return false;
415 if (fromMul)
416 return false;
417 return canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
418 opKind: expr.getKind());
419 }
420 }
421 llvm_unreachable("Unknown AffineExpr");
422}
423
424/// Divides the given expression by the given symbol at position `symbolPos`. It
425/// considers the divisibility condition is checked before calling itself. A
426/// null expression is returned whenever the divisibility condition fails.
427static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
428 AffineExprKind opKind) {
429 // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
430 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
431 opKind == AffineExprKind::CeilDiv) &&
432 "unexpected opKind");
433 switch (expr.getKind()) {
434 case AffineExprKind::Constant:
435 if (cast<AffineConstantExpr>(Val&: expr).getValue() != 0)
436 return nullptr;
437 return getAffineConstantExpr(constant: 0, context: expr.getContext());
438 case AffineExprKind::DimId:
439 return nullptr;
440 case AffineExprKind::SymbolId:
441 return getAffineConstantExpr(constant: 1, context: expr.getContext());
442 // Dividing both operands by the given symbol.
443 case AffineExprKind::Add: {
444 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
445 return getAffineBinaryOpExpr(
446 kind: expr.getKind(), lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind),
447 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind));
448 }
449 // Dividing both operands by the given symbol.
450 case AffineExprKind::Mod: {
451 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
452 return getAffineBinaryOpExpr(
453 kind: expr.getKind(),
454 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
455 rhs: symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind: expr.getKind()));
456 }
457 // Dividing any of the operand by the given symbol.
458 case AffineExprKind::Mul: {
459 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
460 if (!canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos, opKind))
461 return binaryExpr.getLHS() *
462 symbolicDivide(expr: binaryExpr.getRHS(), symbolPos, opKind);
463 return symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind) *
464 binaryExpr.getRHS();
465 }
466 // Dividing first operand only by the given symbol.
467 case AffineExprKind::FloorDiv:
468 case AffineExprKind::CeilDiv: {
469 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
470 return getAffineBinaryOpExpr(
471 kind: expr.getKind(),
472 lhs: symbolicDivide(expr: binaryExpr.getLHS(), symbolPos, opKind: expr.getKind()),
473 rhs: binaryExpr.getRHS());
474 }
475 }
476 llvm_unreachable("Unknown AffineExpr");
477}
478
479/// Populate `result` with all summand operands of given (potentially nested)
480/// addition. If the given expression is not an addition, just populate the
481/// expression itself.
482/// Example: Add(Add(7, 8), Mul(9, 10)) will return [7, 8, Mul(9, 10)].
483static void getSummandExprs(AffineExpr expr, SmallVector<AffineExpr> &result) {
484 auto addExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
485 if (!addExpr || addExpr.getKind() != AffineExprKind::Add) {
486 result.push_back(Elt: expr);
487 return;
488 }
489 getSummandExprs(expr: addExpr.getLHS(), result);
490 getSummandExprs(expr: addExpr.getRHS(), result);
491}
492
493/// Return "true" if `candidate` is a negated expression, i.e., Mul(-1, expr).
494/// If so, also return the non-negated expression via `expr`.
495static bool isNegatedAffineExpr(AffineExpr candidate, AffineExpr &expr) {
496 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(Val&: candidate);
497 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
498 return false;
499 if (auto lhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getLHS())) {
500 if (lhs.getValue() == -1) {
501 expr = mulExpr.getRHS();
502 return true;
503 }
504 }
505 if (auto rhs = dyn_cast<AffineConstantExpr>(Val: mulExpr.getRHS())) {
506 if (rhs.getValue() == -1) {
507 expr = mulExpr.getLHS();
508 return true;
509 }
510 }
511 return false;
512}
513
514/// Return "true" if `lhs` % `rhs` is guaranteed to evaluate to zero based on
515/// the fact that `lhs` contains another modulo expression that ensures that
516/// `lhs` is divisible by `rhs`. This is a common pattern in the resulting IR
517/// after loop peeling.
518///
519/// Example: lhs = ub - ub % step
520/// rhs = step
521/// => (ub - ub % step) % step is guaranteed to evaluate to 0.
522static bool isModOfModSubtraction(AffineExpr lhs, AffineExpr rhs,
523 unsigned numDims, unsigned numSymbols) {
524 // TODO: Try to unify this function with `getBoundForAffineExpr`.
525 // Collect all summands in lhs.
526 SmallVector<AffineExpr> summands;
527 getSummandExprs(expr: lhs, result&: summands);
528 // Look for Mul(-1, Mod(x, rhs)) among the summands. If x matches the
529 // remaining summands, then lhs % rhs is guaranteed to evaluate to 0.
530 for (int64_t i = 0, e = summands.size(); i < e; ++i) {
531 AffineExpr current = summands[i];
532 AffineExpr beforeNegation;
533 if (!isNegatedAffineExpr(candidate: current, expr&: beforeNegation))
534 continue;
535 AffineBinaryOpExpr innerMod = dyn_cast<AffineBinaryOpExpr>(Val&: beforeNegation);
536 if (!innerMod || innerMod.getKind() != AffineExprKind::Mod)
537 continue;
538 if (innerMod.getRHS() != rhs)
539 continue;
540 // Sum all remaining summands and subtract x. If that expression can be
541 // simplified to zero, then the remaining summands and x are equal.
542 AffineExpr diff = getAffineConstantExpr(constant: 0, context: lhs.getContext());
543 for (int64_t j = 0; j < e; ++j)
544 if (i != j)
545 diff = diff + summands[j];
546 diff = diff - innerMod.getLHS();
547 diff = simplifyAffineExpr(expr: diff, numDims, numSymbols);
548 auto constExpr = dyn_cast<AffineConstantExpr>(Val&: diff);
549 if (constExpr && constExpr.getValue() == 0)
550 return true;
551 }
552 return false;
553}
554
555/// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
556/// operations when the second operand simplifies to a symbol and the first
557/// operand is divisible by that symbol. It can be applied to any semi-affine
558/// expression. Returned expression can either be a semi-affine or pure affine
559/// expression.
560static AffineExpr simplifySemiAffine(AffineExpr expr, unsigned numDims,
561 unsigned numSymbols) {
562 switch (expr.getKind()) {
563 case AffineExprKind::Constant:
564 case AffineExprKind::DimId:
565 case AffineExprKind::SymbolId:
566 return expr;
567 case AffineExprKind::Add:
568 case AffineExprKind::Mul: {
569 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
570 return getAffineBinaryOpExpr(
571 kind: expr.getKind(),
572 lhs: simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols),
573 rhs: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
574 }
575 // Check if the simplification of the second operand is a symbol, and the
576 // first operand is divisible by it. If the operation is a modulo, a constant
577 // zero expression is returned. In the case of floordiv and ceildiv, the
578 // symbol from the simplification of the second operand divides the first
579 // operand. Otherwise, simplification is not possible.
580 case AffineExprKind::FloorDiv:
581 case AffineExprKind::CeilDiv:
582 case AffineExprKind::Mod: {
583 AffineBinaryOpExpr binaryExpr = cast<AffineBinaryOpExpr>(Val&: expr);
584 AffineExpr sLHS =
585 simplifySemiAffine(expr: binaryExpr.getLHS(), numDims, numSymbols);
586 AffineExpr sRHS =
587 simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols);
588 if (isModOfModSubtraction(lhs: sLHS, rhs: sRHS, numDims, numSymbols))
589 return getAffineConstantExpr(constant: 0, context: expr.getContext());
590 AffineSymbolExpr symbolExpr = dyn_cast<AffineSymbolExpr>(
591 Val: simplifySemiAffine(expr: binaryExpr.getRHS(), numDims, numSymbols));
592 if (!symbolExpr)
593 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
594 unsigned symbolPos = symbolExpr.getPosition();
595 if (!canSimplifyDivisionBySymbol(expr: binaryExpr.getLHS(), symbolPos,
596 opKind: expr.getKind()))
597 return getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
598 if (expr.getKind() == AffineExprKind::Mod)
599 return getAffineConstantExpr(constant: 0, context: expr.getContext());
600 AffineExpr simplifiedQuotient =
601 symbolicDivide(expr: sLHS, symbolPos, opKind: expr.getKind());
602 return simplifiedQuotient
603 ? simplifiedQuotient
604 : getAffineBinaryOpExpr(kind: expr.getKind(), lhs: sLHS, rhs: sRHS);
605 }
606 }
607 llvm_unreachable("Unknown AffineExpr");
608}
609
610static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
611 MLIRContext *context) {
612 auto assignCtx = [context](AffineDimExprStorage *storage) {
613 storage->context = context;
614 };
615
616 StorageUniquer &uniquer = context->getAffineUniquer();
617 return uniquer.get<AffineDimExprStorage>(
618 initFn: assignCtx, args: static_cast<unsigned>(kind), args&: position);
619}
620
621AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
622 return getAffineDimOrSymbol(kind: AffineExprKind::DimId, position, context);
623}
624
625AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
626 : AffineExpr(ptr) {}
627unsigned AffineSymbolExpr::getPosition() const {
628 return static_cast<ImplType *>(expr)->position;
629}
630
631AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
632 return getAffineDimOrSymbol(kind: AffineExprKind::SymbolId, position, context);
633}
634
635AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
636 : AffineExpr(ptr) {}
637int64_t AffineConstantExpr::getValue() const {
638 return static_cast<ImplType *>(expr)->constant;
639}
640
641bool AffineExpr::operator==(int64_t v) const {
642 return *this == getAffineConstantExpr(constant: v, context: getContext());
643}
644
645AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
646 auto assignCtx = [context](AffineConstantExprStorage *storage) {
647 storage->context = context;
648 };
649
650 StorageUniquer &uniquer = context->getAffineUniquer();
651 return uniquer.get<AffineConstantExprStorage>(initFn: assignCtx, args&: constant);
652}
653
654SmallVector<AffineExpr>
655mlir::getAffineConstantExprs(ArrayRef<int64_t> constants,
656 MLIRContext *context) {
657 return llvm::to_vector(Range: llvm::map_range(C&: constants, F: [&](int64_t constant) {
658 return getAffineConstantExpr(constant, context);
659 }));
660}
661
662/// Simplify add expression. Return nullptr if it can't be simplified.
663static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
664 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
665 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
666 // Fold if both LHS, RHS are a constant and the sum does not overflow.
667 if (lhsConst && rhsConst) {
668 int64_t sum;
669 if (llvm::AddOverflow(X: lhsConst.getValue(), Y: rhsConst.getValue(), Result&: sum)) {
670 return nullptr;
671 }
672 return getAffineConstantExpr(constant: sum, context: lhs.getContext());
673 }
674
675 // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
676 // If only one of them is a symbolic expressions, make it the RHS.
677 if (isa<AffineConstantExpr>(Val: lhs) ||
678 (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
679 return rhs + lhs;
680 }
681
682 // At this point, if there was a constant, it would be on the right.
683
684 // Addition with a zero is a noop, return the other input.
685 if (rhsConst) {
686 if (rhsConst.getValue() == 0)
687 return lhs;
688 }
689 // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
690 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
691 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
692 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
693 return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
694 }
695
696 // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
697 // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
698 // respective multiplicands.
699 std::optional<int64_t> rLhsConst, rRhsConst;
700 AffineExpr firstExpr, secondExpr;
701 AffineConstantExpr rLhsConstExpr;
702 auto lBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
703 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
704 (rLhsConstExpr = dyn_cast<AffineConstantExpr>(Val: lBinOpExpr.getRHS()))) {
705 rLhsConst = rLhsConstExpr.getValue();
706 firstExpr = lBinOpExpr.getLHS();
707 } else {
708 rLhsConst = 1;
709 firstExpr = lhs;
710 }
711
712 auto rBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: rhs);
713 AffineConstantExpr rRhsConstExpr;
714 if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
715 (rRhsConstExpr = dyn_cast<AffineConstantExpr>(Val: rBinOpExpr.getRHS()))) {
716 rRhsConst = rRhsConstExpr.getValue();
717 secondExpr = rBinOpExpr.getLHS();
718 } else {
719 rRhsConst = 1;
720 secondExpr = rhs;
721 }
722
723 if (rLhsConst && rRhsConst && firstExpr == secondExpr)
724 return getAffineBinaryOpExpr(
725 kind: AffineExprKind::Mul, lhs: firstExpr,
726 rhs: getAffineConstantExpr(constant: *rLhsConst + *rRhsConst, context: lhs.getContext()));
727
728 // When doing successive additions, bring constant to the right: turn (d0 + 2)
729 // + d1 into (d0 + d1) + 2.
730 if (lBin && lBin.getKind() == AffineExprKind::Add) {
731 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
732 return lBin.getLHS() + rhs + lrhs;
733 }
734 }
735
736 // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
737 // q may be a constant or symbolic expression. This leads to a much more
738 // efficient form when 'c' is a power of two, and in general a more compact
739 // and readable form.
740
741 // Process '(expr floordiv c) * (-c)'.
742 if (!rBinOpExpr)
743 return nullptr;
744
745 auto lrhs = rBinOpExpr.getLHS();
746 auto rrhs = rBinOpExpr.getRHS();
747
748 AffineExpr llrhs, rlrhs;
749
750 // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
751 // symbolic expression.
752 auto lrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
753 // Check rrhsConstOpExpr = -1.
754 auto rrhsConstOpExpr = dyn_cast<AffineConstantExpr>(Val&: rrhs);
755 if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
756 lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
757 // Check llrhs = expr floordiv q.
758 llrhs = lrhsBinOpExpr.getLHS();
759 // Check rlrhs = q.
760 rlrhs = lrhsBinOpExpr.getRHS();
761 auto llrhsBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: llrhs);
762 if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
763 return nullptr;
764 if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
765 return lhs % rlrhs;
766 }
767
768 // Process lrhs, which is 'expr floordiv c'.
769 // expr + (expr // c * -c) = expr % c
770 AffineBinaryOpExpr lrBinOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: lrhs);
771 if (!lrBinOpExpr || rhs.getKind() != AffineExprKind::Mul ||
772 lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
773 return nullptr;
774
775 llrhs = lrBinOpExpr.getLHS();
776 rlrhs = lrBinOpExpr.getRHS();
777 auto rlrhsConstOpExpr = dyn_cast<AffineConstantExpr>(Val&: rlrhs);
778 // We don't support modulo with a negative RHS.
779 bool isPositiveRhs = rlrhsConstOpExpr && rlrhsConstOpExpr.getValue() > 0;
780
781 if (isPositiveRhs && lhs == llrhs && rlrhs == -rrhs) {
782 return lhs % rlrhs;
783 }
784
785 // Try simplify lhs's last operand with rhs. e.g:
786 // (s0 * 64 + s1) + (s1 // c * -c) --->
787 // s0 * 64 + (s1 + s1 // c * -c) -->
788 // s0 * 64 + s1 % c
789 if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Add) {
790 if (auto simplified = simplifyAdd(lhs: lBinOpExpr.getRHS(), rhs))
791 return lBinOpExpr.getLHS() + simplified;
792 }
793 return nullptr;
794}
795
796/// Get the canonical order of two commutative exprs arguments.
797static std::pair<AffineExpr, AffineExpr>
798orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) {
799 auto sym1 = dyn_cast<AffineSymbolExpr>(Val&: expr1);
800 auto sym2 = dyn_cast<AffineSymbolExpr>(Val&: expr2);
801 // Try to order by symbol/dim position first.
802 if (sym1 && sym2)
803 return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2}
804 : std::pair{expr2, expr1};
805
806 auto dim1 = dyn_cast<AffineDimExpr>(Val&: expr1);
807 auto dim2 = dyn_cast<AffineDimExpr>(Val&: expr2);
808 if (dim1 && dim2)
809 return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2}
810 : std::pair{expr2, expr1};
811
812 // Put dims before symbols.
813 if (dim1 && sym2)
814 return {dim1, sym2};
815
816 if (sym1 && dim2)
817 return {dim2, sym1};
818
819 // Otherwise, keep original order.
820 return {expr1, expr2};
821}
822
823AffineExpr AffineExpr::operator+(int64_t v) const {
824 return *this + getAffineConstantExpr(constant: v, context: getContext());
825}
826AffineExpr AffineExpr::operator+(AffineExpr other) const {
827 if (auto simplified = simplifyAdd(lhs: *this, rhs: other))
828 return simplified;
829
830 auto [lhs, rhs] = orderCommutativeArgs(expr1: *this, expr2: other);
831
832 StorageUniquer &uniquer = getContext()->getAffineUniquer();
833 return uniquer.get<AffineBinaryOpExprStorage>(
834 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Add), args&: lhs, args&: rhs);
835}
836
837/// Simplify a multiply expression. Return nullptr if it can't be simplified.
838static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
839 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
840 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
841
842 if (lhsConst && rhsConst) {
843 int64_t product;
844 if (llvm::MulOverflow(X: lhsConst.getValue(), Y: rhsConst.getValue(), Result&: product)) {
845 return nullptr;
846 }
847 return getAffineConstantExpr(constant: product, context: lhs.getContext());
848 }
849
850 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())
851 return nullptr;
852
853 // Canonicalize the mul expression so that the constant/symbolic term is the
854 // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
855 // constant. (Note that a constant is trivially symbolic).
856 if (!rhs.isSymbolicOrConstant() || isa<AffineConstantExpr>(Val: lhs)) {
857 // At least one of them has to be symbolic.
858 return rhs * lhs;
859 }
860
861 // At this point, if there was a constant, it would be on the right.
862
863 // Multiplication with a one is a noop, return the other input.
864 if (rhsConst) {
865 if (rhsConst.getValue() == 1)
866 return lhs;
867 // Multiplication with zero.
868 if (rhsConst.getValue() == 0)
869 return rhsConst;
870 }
871
872 // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
873 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
874 if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
875 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS()))
876 return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
877 }
878
879 // When doing successive multiplication, bring constant to the right: turn (d0
880 // * 2) * d1 into (d0 * d1) * 2.
881 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
882 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
883 return (lBin.getLHS() * rhs) * lrhs;
884 }
885 }
886
887 return nullptr;
888}
889
890AffineExpr AffineExpr::operator*(int64_t v) const {
891 return *this * getAffineConstantExpr(constant: v, context: getContext());
892}
893AffineExpr AffineExpr::operator*(AffineExpr other) const {
894 if (auto simplified = simplifyMul(lhs: *this, rhs: other))
895 return simplified;
896
897 auto [lhs, rhs] = orderCommutativeArgs(expr1: *this, expr2: other);
898
899 StorageUniquer &uniquer = getContext()->getAffineUniquer();
900 return uniquer.get<AffineBinaryOpExprStorage>(
901 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mul), args&: lhs, args&: rhs);
902}
903
904// Unary minus, delegate to operator*.
905AffineExpr AffineExpr::operator-() const {
906 return *this * getAffineConstantExpr(constant: -1, context: getContext());
907}
908
909// Delegate to operator+.
910AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
911AffineExpr AffineExpr::operator-(AffineExpr other) const {
912 return *this + (-other);
913}
914
915static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
916 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
917 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
918
919 if (!rhsConst || rhsConst.getValue() == 0)
920 return nullptr;
921
922 if (lhsConst) {
923 if (divideSignedWouldOverflow(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()))
924 return nullptr;
925 return getAffineConstantExpr(
926 constant: divideFloorSigned(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()),
927 context: lhs.getContext());
928 }
929
930 // Fold floordiv of a multiply with a constant that is a multiple of the
931 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
932 if (rhsConst == 1)
933 return lhs;
934
935 // Simplify `(expr * lrhs) floordiv rhsConst` when `lrhs` is known to be a
936 // multiple of `rhsConst`.
937 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
938 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
939 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
940 // `rhsConst` is known to be a nonzero constant.
941 if (lrhs.getValue() % rhsConst.getValue() == 0)
942 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
943 }
944 }
945
946 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
947 // known to be a multiple of divConst.
948 if (lBin && lBin.getKind() == AffineExprKind::Add) {
949 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
950 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
951 // rhsConst is known to be a nonzero constant.
952 if (llhsDiv % rhsConst.getValue() == 0 ||
953 lrhsDiv % rhsConst.getValue() == 0)
954 return lBin.getLHS().floorDiv(v: rhsConst.getValue()) +
955 lBin.getRHS().floorDiv(v: rhsConst.getValue());
956 }
957
958 return nullptr;
959}
960
961AffineExpr AffineExpr::floorDiv(uint64_t v) const {
962 return floorDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
963}
964AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
965 if (auto simplified = simplifyFloorDiv(lhs: *this, rhs: other))
966 return simplified;
967
968 StorageUniquer &uniquer = getContext()->getAffineUniquer();
969 return uniquer.get<AffineBinaryOpExprStorage>(
970 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::FloorDiv), args: *this,
971 args&: other);
972}
973
974static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
975 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
976 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
977
978 if (!rhsConst || rhsConst.getValue() == 0)
979 return nullptr;
980
981 if (lhsConst) {
982 if (divideSignedWouldOverflow(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()))
983 return nullptr;
984 return getAffineConstantExpr(
985 constant: divideCeilSigned(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()),
986 context: lhs.getContext());
987 }
988
989 // Fold ceildiv of a multiply with a constant that is a multiple of the
990 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
991 if (rhsConst.getValue() == 1)
992 return lhs;
993
994 // Simplify `(expr * lrhs) ceildiv rhsConst` when `lrhs` is known to be a
995 // multiple of `rhsConst`.
996 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
997 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
998 if (auto lrhs = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS())) {
999 // `rhsConst` is known to be a nonzero constant.
1000 if (lrhs.getValue() % rhsConst.getValue() == 0)
1001 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
1002 }
1003 }
1004
1005 return nullptr;
1006}
1007
1008AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
1009 return ceilDiv(other: getAffineConstantExpr(constant: v, context: getContext()));
1010}
1011AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
1012 if (auto simplified = simplifyCeilDiv(lhs: *this, rhs: other))
1013 return simplified;
1014
1015 StorageUniquer &uniquer = getContext()->getAffineUniquer();
1016 return uniquer.get<AffineBinaryOpExprStorage>(
1017 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::CeilDiv), args: *this,
1018 args&: other);
1019}
1020
1021static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
1022 auto lhsConst = dyn_cast<AffineConstantExpr>(Val&: lhs);
1023 auto rhsConst = dyn_cast<AffineConstantExpr>(Val&: rhs);
1024
1025 // mod w.r.t zero or negative numbers is undefined and preserved as is.
1026 if (!rhsConst || rhsConst.getValue() < 1)
1027 return nullptr;
1028
1029 if (lhsConst) {
1030 // mod never overflows.
1031 return getAffineConstantExpr(constant: mod(Numerator: lhsConst.getValue(), Denominator: rhsConst.getValue()),
1032 context: lhs.getContext());
1033 }
1034
1035 // Fold modulo of an expression that is known to be a multiple of a constant
1036 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
1037 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
1038 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
1039 return getAffineConstantExpr(constant: 0, context: lhs.getContext());
1040
1041 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
1042 // known to be a multiple of divConst.
1043 auto lBin = dyn_cast<AffineBinaryOpExpr>(Val&: lhs);
1044 if (lBin && lBin.getKind() == AffineExprKind::Add) {
1045 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
1046 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
1047 // rhsConst is known to be a positive constant.
1048 if (llhsDiv % rhsConst.getValue() == 0)
1049 return lBin.getRHS() % rhsConst.getValue();
1050 if (lrhsDiv % rhsConst.getValue() == 0)
1051 return lBin.getLHS() % rhsConst.getValue();
1052 }
1053
1054 // Simplify (e % a) % b to e % b when b evenly divides a
1055 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
1056 auto intermediate = dyn_cast<AffineConstantExpr>(Val: lBin.getRHS());
1057 if (intermediate && intermediate.getValue() >= 1 &&
1058 mod(Numerator: intermediate.getValue(), Denominator: rhsConst.getValue()) == 0) {
1059 return lBin.getLHS() % rhsConst.getValue();
1060 }
1061 }
1062
1063 return nullptr;
1064}
1065
1066AffineExpr AffineExpr::operator%(uint64_t v) const {
1067 return *this % getAffineConstantExpr(constant: v, context: getContext());
1068}
1069AffineExpr AffineExpr::operator%(AffineExpr other) const {
1070 if (auto simplified = simplifyMod(lhs: *this, rhs: other))
1071 return simplified;
1072
1073 StorageUniquer &uniquer = getContext()->getAffineUniquer();
1074 return uniquer.get<AffineBinaryOpExprStorage>(
1075 /*initFn=*/{}, args: static_cast<unsigned>(AffineExprKind::Mod), args: *this, args&: other);
1076}
1077
1078AffineExpr AffineExpr::compose(AffineMap map) const {
1079 SmallVector<AffineExpr, 8> dimReplacements(map.getResults());
1080 return replaceDimsAndSymbols(dimReplacements, symReplacements: {});
1081}
1082raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
1083 expr.print(os);
1084 return os;
1085}
1086
1087/// Constructs an affine expression from a flat ArrayRef. If there are local
1088/// identifiers (neither dimensional nor symbolic) that appear in the sum of
1089/// products expression, `localExprs` is expected to have the AffineExpr
1090/// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
1091/// in the format [dims, symbols, locals, constant term].
1092AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1093 unsigned numDims,
1094 unsigned numSymbols,
1095 ArrayRef<AffineExpr> localExprs,
1096 MLIRContext *context) {
1097 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1098 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1099 "unexpected number of local expressions");
1100
1101 auto expr = getAffineConstantExpr(constant: 0, context);
1102 // Dimensions and symbols.
1103 for (unsigned j = 0; j < numDims + numSymbols; j++) {
1104 if (flatExprs[j] == 0)
1105 continue;
1106 auto id = j < numDims ? getAffineDimExpr(position: j, context)
1107 : getAffineSymbolExpr(position: j - numDims, context);
1108 expr = expr + id * flatExprs[j];
1109 }
1110
1111 // Local identifiers.
1112 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1113 j++) {
1114 if (flatExprs[j] == 0)
1115 continue;
1116 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1117 expr = expr + term;
1118 }
1119
1120 // Constant term.
1121 int64_t constTerm = flatExprs[flatExprs.size() - 1];
1122 if (constTerm != 0)
1123 expr = expr + constTerm;
1124 return expr;
1125}
1126
1127/// Constructs a semi-affine expression from a flat ArrayRef. If there are
1128/// local identifiers (neither dimensional nor symbolic) that appear in the sum
1129/// of products expression, `localExprs` is expected to have the AffineExprs for
1130/// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
1131/// the format [dims, symbols, locals, constant term]. The semi-affine
1132/// expression is constructed in the sorted order of dimension and symbol
1133/// position numbers. Note: local expressions/ids are used for mod, div as well
1134/// as symbolic RHS terms for terms that are not pure affine.
1135static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
1136 unsigned numDims,
1137 unsigned numSymbols,
1138 ArrayRef<AffineExpr> localExprs,
1139 MLIRContext *context) {
1140 assert(!flatExprs.empty() && "flatExprs cannot be empty");
1141
1142 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
1143 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
1144 "unexpected number of local expressions");
1145
1146 AffineExpr expr = getAffineConstantExpr(constant: 0, context);
1147
1148 // We design indices as a pair which help us present the semi-affine map as
1149 // sum of product where terms are sorted based on dimension or symbol
1150 // position: <keyA, keyB> for expressions of the form dimension * symbol,
1151 // where keyA is the position number of the dimension and keyB is the
1152 // position number of the symbol. For dimensional expressions we set the index
1153 // as (position number of the dimension, -1), as we want dimensional
1154 // expressions to appear before symbolic and product of dimensional and
1155 // symbolic expressions having the dimension with the same position number.
1156 // For symbolic expression set the index as (position number of the symbol,
1157 // maximum of last dimension and symbol position) number. For example, we want
1158 // the expression we are constructing to look something like: d0 + d0 * s0 +
1159 // s0 + d1*s1 + s1.
1160
1161 // Stores the affine expression corresponding to a given index.
1162 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
1163 // Stores the constant coefficient value corresponding to a given
1164 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
1165 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
1166 // Stores the indices as defined above, and later sorted to produce
1167 // the semi-affine expression in the desired form.
1168 SmallVector<std::pair<unsigned, signed>, 8> indices;
1169
1170 // Example: expression = d0 + d0 * s0 + 2 * s0.
1171 // indices = [{0,-1}, {0, 0}, {0, 1}]
1172 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
1173 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
1174
1175 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
1176 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
1177 AffineExpr expr) {
1178 assert(!llvm::is_contained(indices, index) &&
1179 "Key is already present in indices vector and overwriting will "
1180 "happen in `indexToExprMap` and `coefficients`!");
1181
1182 indices.push_back(Elt: index);
1183 coefficients.insert(KV: {index, coefficient});
1184 indexToExprMap.insert(KV: {index, expr});
1185 };
1186
1187 // Design indices for dimensional or symbolic terms, and store the indices,
1188 // constant coefficient corresponding to the indices in `coefficients` map,
1189 // and affine expression corresponding to indices in `indexToExprMap` map.
1190
1191 // Ensure we do not have duplicate keys in `indexToExpr` map.
1192 unsigned offsetSym = 0;
1193 signed offsetDim = -1;
1194 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1195 if (flatExprs[j] == 0)
1196 continue;
1197 // For symbolic expression set the index as <position number
1198 // of the symbol, max(dimCount, symCount)> number,
1199 // as we want symbolic expressions with the same positional number to
1200 // appear after dimensional expressions having the same positional number.
1201 std::pair<unsigned, signed> indexEntry(
1202 j - numDims, std::max(a: numDims, b: numSymbols) + offsetSym++);
1203 addEntry(indexEntry, flatExprs[j],
1204 getAffineSymbolExpr(position: j - numDims, context));
1205 }
1206
1207 // Denotes semi-affine product, modulo or division terms, which has been added
1208 // to the `indexToExpr` map.
1209 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1210 false);
1211 unsigned lhsPos, rhsPos;
1212 // Construct indices for product terms involving dimension, symbol or constant
1213 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1214 // the indices in `coefficients` map, and affine expression corresponding to
1215 // in indices in `indexToExprMap` map.
1216 for (const auto &it : llvm::enumerate(First&: localExprs)) {
1217 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1218 continue;
1219 AffineExpr expr = it.value();
1220 auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr);
1221 if (!binaryExpr)
1222 continue;
1223
1224 AffineExpr lhs = binaryExpr.getLHS();
1225 AffineExpr rhs = binaryExpr.getRHS();
1226 if (!((isa<AffineDimExpr>(Val: lhs) || isa<AffineSymbolExpr>(Val: lhs)) &&
1227 (isa<AffineDimExpr>(Val: rhs) || isa<AffineSymbolExpr>(Val: rhs) ||
1228 isa<AffineConstantExpr>(Val: rhs)))) {
1229 continue;
1230 }
1231 if (isa<AffineConstantExpr>(Val: rhs)) {
1232 // For product/modulo/division expressions, when rhs of modulo/division
1233 // expression is constant, we put 0 in place of keyB, because we want
1234 // them to appear earlier in the semi-affine expression we are
1235 // constructing. When rhs is constant, we place 0 in place of keyB.
1236 if (isa<AffineDimExpr>(Val: lhs)) {
1237 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1238 std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
1239 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1240 expr);
1241 } else {
1242 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1243 std::pair<unsigned, signed> indexEntry(
1244 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1245 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1246 expr);
1247 }
1248 } else if (isa<AffineDimExpr>(Val: lhs)) {
1249 // For product/modulo/division expressions having lhs as dimension and rhs
1250 // as symbol, we order the terms in the semi-affine expression based on
1251 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1252 // where keyA is the position number of the dimension and keyB is the
1253 // position number of the symbol.
1254 lhsPos = cast<AffineDimExpr>(Val&: lhs).getPosition();
1255 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1256 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1257 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1258 } else {
1259 // For product/modulo/division expressions having both lhs and rhs as
1260 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1261 // of the form dimension * symbol, where keyA is the position number of
1262 // the dimension and keyB is the position number of the symbol.
1263 lhsPos = cast<AffineSymbolExpr>(Val&: lhs).getPosition();
1264 rhsPos = cast<AffineSymbolExpr>(Val&: rhs).getPosition();
1265 std::pair<unsigned, signed> indexEntry(
1266 lhsPos, std::max(a: numDims, b: numSymbols) + offsetSym++);
1267 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1268 }
1269 addedToMap[it.index()] = true;
1270 }
1271
1272 for (unsigned j = 0; j < numDims; ++j) {
1273 if (flatExprs[j] == 0)
1274 continue;
1275 // For dimensional expressions we set the index as <position number of the
1276 // dimension, 0>, as we want dimensional expressions to appear before
1277 // symbolic ones and products of dimensional and symbolic expressions
1278 // having the dimension with the same position number.
1279 std::pair<unsigned, signed> indexEntry(j, offsetDim--);
1280 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(position: j, context));
1281 }
1282
1283 // Constructing the simplified semi-affine sum of product/division/mod
1284 // expression from the flattened form in the desired sorted order of indices
1285 // of the various individual product/division/mod expressions.
1286 llvm::sort(C&: indices);
1287 for (const std::pair<unsigned, unsigned> index : indices) {
1288 assert(indexToExprMap.lookup(index) &&
1289 "cannot find key in `indexToExprMap` map");
1290 expr = expr + indexToExprMap.lookup(Val: index) * coefficients.lookup(Val: index);
1291 }
1292
1293 // Local identifiers.
1294 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1295 j++) {
1296 // If the coefficient of the local expression is 0, continue as we need not
1297 // add it in out final expression.
1298 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1299 continue;
1300 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1301 expr = expr + term;
1302 }
1303
1304 // Constant term.
1305 int64_t constTerm = flatExprs.back();
1306 if (constTerm != 0)
1307 expr = expr + constTerm;
1308 return expr;
1309}
1310
1311SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1312 unsigned numSymbols)
1313 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1314 operandExprStack.reserve(n: 8);
1315}
1316
1317// In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1318//
1319// In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1320// introduce a local variable p (= expr * symbolic_expr), and the affine
1321// expression expr * symbolic_expr is added to `localExprs`.
1322LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1323 assert(operandExprStack.size() >= 2);
1324 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1325 operandExprStack.pop_back();
1326 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1327
1328 // Flatten semi-affine multiplication expressions by introducing a local
1329 // variable in place of the product; the affine expression
1330 // corresponding to the quantifier is added to `localExprs`.
1331 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1332 SmallVector<int64_t, 8> mulLhs(lhs);
1333 MLIRContext *context = expr.getContext();
1334 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1335 localExprs, context);
1336 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1337 localExprs, context);
1338 return addLocalVariableSemiAffine(lhs: mulLhs, rhs, localExpr: a * b, result&: lhs, resultSize: lhs.size());
1339 }
1340
1341 // Get the RHS constant.
1342 int64_t rhsConst = rhs[getConstantIndex()];
1343 for (int64_t &lhsElt : lhs)
1344 lhsElt *= rhsConst;
1345
1346 return success();
1347}
1348
1349LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1350 assert(operandExprStack.size() >= 2);
1351 const auto &rhs = operandExprStack.back();
1352 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1353 assert(lhs.size() == rhs.size());
1354 // Update the LHS in place.
1355 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1356 lhs[i] += rhs[i];
1357 }
1358 // Pop off the RHS.
1359 operandExprStack.pop_back();
1360 return success();
1361}
1362
1363//
1364// t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1365//
1366// A mod expression "expr mod c" is thus flattened by introducing a new local
1367// variable q (= expr floordiv c), such that expr mod c is replaced with
1368// 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1369//
1370// In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1371// introduce a local variable m (= expr mod symbolic_expr), and the affine
1372// expression expr mod symbolic_expr is added to `localExprs`.
1373LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1374 assert(operandExprStack.size() >= 2);
1375
1376 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1377 operandExprStack.pop_back();
1378 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1379 MLIRContext *context = expr.getContext();
1380
1381 // Flatten semi affine modulo expressions by introducing a local
1382 // variable in place of the modulo value, and the affine expression
1383 // corresponding to the quantifier is added to `localExprs`.
1384 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1385 SmallVector<int64_t, 8> modLhs(lhs);
1386 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1387 flatExprs: lhs, numDims, numSymbols, localExprs, context);
1388 AffineExpr divisorExpr = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1389 localExprs, context);
1390 AffineExpr modExpr = dividendExpr % divisorExpr;
1391 return addLocalVariableSemiAffine(lhs: modLhs, rhs, localExpr: modExpr, result&: lhs, resultSize: lhs.size());
1392 }
1393
1394 int64_t rhsConst = rhs[getConstantIndex()];
1395 if (rhsConst <= 0)
1396 return failure();
1397
1398 // Check if the LHS expression is a multiple of modulo factor.
1399 unsigned i, e;
1400 for (i = 0, e = lhs.size(); i < e; i++)
1401 if (lhs[i] % rhsConst != 0)
1402 break;
1403 // If yes, modulo expression here simplifies to zero.
1404 if (i == lhs.size()) {
1405 llvm::fill(Range&: lhs, Value: 0);
1406 return success();
1407 }
1408
1409 // Add a local variable for the quotient, i.e., expr % c is replaced by
1410 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1411 // the GCD of expr and c.
1412 SmallVector<int64_t, 8> floorDividend(lhs);
1413 uint64_t gcd = rhsConst;
1414 for (int64_t lhsElt : lhs)
1415 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1416 // Simplify the numerator and the denominator.
1417 if (gcd != 1) {
1418 for (int64_t &floorDividendElt : floorDividend)
1419 floorDividendElt = floorDividendElt / static_cast<int64_t>(gcd);
1420 }
1421 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1422
1423 // Construct the AffineExpr form of the floordiv to store in localExprs.
1424
1425 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1426 flatExprs: floorDividend, numDims, numSymbols, localExprs, context);
1427 AffineExpr divisorExpr = getAffineConstantExpr(constant: floorDivisor, context);
1428 AffineExpr floorDivExpr = dividendExpr.floorDiv(other: divisorExpr);
1429 int loc;
1430 if ((loc = findLocalId(localExpr: floorDivExpr)) == -1) {
1431 addLocalFloorDivId(dividend: floorDividend, divisor: floorDivisor, localExpr: floorDivExpr);
1432 // Set result at top of stack to "lhs - rhsConst * q".
1433 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1434 } else {
1435 // Reuse the existing local id.
1436 lhs[getLocalVarStartIndex() + loc] -= rhsConst;
1437 }
1438 return success();
1439}
1440
1441LogicalResult
1442SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1443 return visitDivExpr(expr, /*isCeil=*/true);
1444}
1445LogicalResult
1446SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1447 return visitDivExpr(expr, /*isCeil=*/false);
1448}
1449
1450LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1451 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1452 auto &eq = operandExprStack.back();
1453 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1454 eq[getDimStartIndex() + expr.getPosition()] = 1;
1455 return success();
1456}
1457
1458LogicalResult
1459SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1460 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1461 auto &eq = operandExprStack.back();
1462 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1463 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1464 return success();
1465}
1466
1467LogicalResult
1468SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1469 operandExprStack.emplace_back(args: SmallVector<int64_t, 32>(getNumCols(), 0));
1470 auto &eq = operandExprStack.back();
1471 eq[getConstantIndex()] = expr.getValue();
1472 return success();
1473}
1474
1475LogicalResult SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1476 ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr,
1477 SmallVectorImpl<int64_t> &result, unsigned long resultSize) {
1478 assert(result.size() == resultSize &&
1479 "`result` vector passed is not of correct size");
1480 int loc;
1481 if ((loc = findLocalId(localExpr)) == -1) {
1482 if (failed(Result: addLocalIdSemiAffine(lhs, rhs, localExpr)))
1483 return failure();
1484 }
1485 llvm::fill(Range&: result, Value: 0);
1486 if (loc == -1)
1487 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1488 else
1489 result[getLocalVarStartIndex() + loc] = 1;
1490 return success();
1491}
1492
1493// t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1494// A floordiv is thus flattened by introducing a new local variable q, and
1495// replacing that expression with 'q' while adding the constraints
1496// c * q <= expr <= c * q + c - 1 to localVarCst (done by
1497// IntegerRelation::addLocalFloorDiv).
1498//
1499// A ceildiv is similarly flattened:
1500// t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1501//
1502// In case of semi affine division expressions, t = expr floordiv symbolic_expr
1503// or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1504// floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1505// `localExprs`.
1506LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1507 bool isCeil) {
1508 assert(operandExprStack.size() >= 2);
1509
1510 MLIRContext *context = expr.getContext();
1511 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1512 operandExprStack.pop_back();
1513 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1514
1515 // Flatten semi affine division expressions by introducing a local
1516 // variable in place of the quotient, and the affine expression corresponding
1517 // to the quantifier is added to `localExprs`.
1518 if (!isa<AffineConstantExpr>(Val: expr.getRHS())) {
1519 SmallVector<int64_t, 8> divLhs(lhs);
1520 AffineExpr a = getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols,
1521 localExprs, context);
1522 AffineExpr b = getAffineExprFromFlatForm(flatExprs: rhs, numDims, numSymbols,
1523 localExprs, context);
1524 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1525 return addLocalVariableSemiAffine(lhs: divLhs, rhs, localExpr: divExpr, result&: lhs, resultSize: lhs.size());
1526 }
1527
1528 // This is a pure affine expr; the RHS is a positive constant.
1529 int64_t rhsConst = rhs[getConstantIndex()];
1530 if (rhsConst <= 0)
1531 return failure();
1532
1533 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1534 // common divisors of the numerator and denominator.
1535 uint64_t gcd = std::abs(i: rhsConst);
1536 for (int64_t lhsElt : lhs)
1537 gcd = std::gcd(m: gcd, n: (uint64_t)std::abs(i: lhsElt));
1538 // Simplify the numerator and the denominator.
1539 if (gcd != 1) {
1540 for (int64_t &lhsElt : lhs)
1541 lhsElt = lhsElt / static_cast<int64_t>(gcd);
1542 }
1543 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1544 // If the divisor becomes 1, the updated LHS is the result. (The
1545 // divisor can't be negative since rhsConst is positive).
1546 if (divisor == 1)
1547 return success();
1548
1549 // If the divisor cannot be simplified to one, we will have to retain
1550 // the ceil/floor expr (simplified up until here). Add an existential
1551 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1552 // by a new identifier, q.
1553 AffineExpr a =
1554 getAffineExprFromFlatForm(flatExprs: lhs, numDims, numSymbols, localExprs, context);
1555 AffineExpr b = getAffineConstantExpr(constant: divisor, context);
1556
1557 int loc;
1558 AffineExpr divExpr = isCeil ? a.ceilDiv(other: b) : a.floorDiv(other: b);
1559 if ((loc = findLocalId(localExpr: divExpr)) == -1) {
1560 if (!isCeil) {
1561 SmallVector<int64_t, 8> dividend(lhs);
1562 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1563 } else {
1564 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1565 SmallVector<int64_t, 8> dividend(lhs);
1566 dividend.back() += divisor - 1;
1567 addLocalFloorDivId(dividend, divisor, localExpr: divExpr);
1568 }
1569 }
1570 // Set the expression on stack to the local var introduced to capture the
1571 // result of the division (floor or ceil).
1572 llvm::fill(Range&: lhs, Value: 0);
1573 if (loc == -1)
1574 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1575 else
1576 lhs[getLocalVarStartIndex() + loc] = 1;
1577 return success();
1578}
1579
1580// Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1581// The local identifier added is always a floordiv of a pure add/mul affine
1582// function of other identifiers, coefficients of which are specified in
1583// dividend and with respect to a positive constant divisor. localExpr is the
1584// simplified tree expression (AffineExpr) corresponding to the quantifier.
1585void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1586 int64_t divisor,
1587 AffineExpr localExpr) {
1588 assert(divisor > 0 && "positive constant divisor expected");
1589 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1590 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1591 localExprs.push_back(Elt: localExpr);
1592 numLocals++;
1593 // dividend and divisor are not used here; an override of this method uses it.
1594}
1595
1596LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
1597 ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
1598 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1599 subExpr.insert(I: subExpr.begin() + getLocalVarStartIndex() + numLocals, Elt: 0);
1600 localExprs.push_back(Elt: localExpr);
1601 ++numLocals;
1602 // lhs and rhs are not used here; an override of this method uses them.
1603 return success();
1604}
1605
1606int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1607 SmallVectorImpl<AffineExpr>::iterator it;
1608 if ((it = llvm::find(Range&: localExprs, Val: localExpr)) == localExprs.end())
1609 return -1;
1610 return it - localExprs.begin();
1611}
1612
1613/// Simplify the affine expression by flattening it and reconstructing it.
1614AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1615 unsigned numSymbols) {
1616 // Simplify semi-affine expressions separately.
1617 if (!expr.isPureAffine())
1618 expr = simplifySemiAffine(expr, numDims, numSymbols);
1619
1620 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1621 // has poison expression
1622 if (failed(Result: flattener.walkPostOrder(expr)))
1623 return expr;
1624 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1625 if (!expr.isPureAffine() &&
1626 expr == getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1627 localExprs: flattener.localExprs,
1628 context: expr.getContext()))
1629 return expr;
1630 AffineExpr simplifiedExpr =
1631 expr.isPureAffine()
1632 ? getAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1633 localExprs: flattener.localExprs, context: expr.getContext())
1634 : getSemiAffineExprFromFlatForm(flatExprs: flattenedExpr, numDims, numSymbols,
1635 localExprs: flattener.localExprs,
1636 context: expr.getContext());
1637
1638 flattener.operandExprStack.pop_back();
1639 assert(flattener.operandExprStack.empty());
1640 return simplifiedExpr;
1641}
1642
1643std::optional<int64_t> mlir::getBoundForAffineExpr(
1644 AffineExpr expr, unsigned numDims, unsigned numSymbols,
1645 ArrayRef<std::optional<int64_t>> constLowerBounds,
1646 ArrayRef<std::optional<int64_t>> constUpperBounds, bool isUpper) {
1647 // Handle divs and mods.
1648 if (auto binOpExpr = dyn_cast<AffineBinaryOpExpr>(Val&: expr)) {
1649 // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
1650 // can compute an upper bound.
1651 if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
1652 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1653 if (!rhsConst || rhsConst.getValue() < 1)
1654 return std::nullopt;
1655 auto bound =
1656 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1657 constLowerBounds, constUpperBounds, isUpper);
1658 if (!bound)
1659 return std::nullopt;
1660 return divideFloorSigned(Numerator: *bound, Denominator: rhsConst.getValue());
1661 }
1662 if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
1663 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1664 if (rhsConst && rhsConst.getValue() >= 1) {
1665 auto bound =
1666 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1667 constLowerBounds, constUpperBounds, isUpper);
1668 if (!bound)
1669 return std::nullopt;
1670 return divideCeilSigned(Numerator: *bound, Denominator: rhsConst.getValue());
1671 }
1672 return std::nullopt;
1673 }
1674 if (binOpExpr.getKind() == AffineExprKind::Mod) {
1675 // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
1676 // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
1677 // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
1678 auto rhsConst = dyn_cast<AffineConstantExpr>(Val: binOpExpr.getRHS());
1679 if (rhsConst && rhsConst.getValue() >= 1) {
1680 int64_t rhsConstVal = rhsConst.getValue();
1681 auto lb = getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1682 constLowerBounds, constUpperBounds,
1683 /*isUpper=*/false);
1684 auto ub =
1685 getBoundForAffineExpr(expr: binOpExpr.getLHS(), numDims, numSymbols,
1686 constLowerBounds, constUpperBounds, isUpper);
1687 if (ub && lb &&
1688 divideFloorSigned(Numerator: *lb, Denominator: rhsConstVal) ==
1689 divideFloorSigned(Numerator: *ub, Denominator: rhsConstVal))
1690 return isUpper ? mod(Numerator: *ub, Denominator: rhsConstVal) : mod(Numerator: *lb, Denominator: rhsConstVal);
1691 return isUpper ? rhsConstVal - 1 : 0;
1692 }
1693 }
1694 }
1695 // Flatten the expression.
1696 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1697 auto simpleResult = flattener.walkPostOrder(expr);
1698 // has poison expression
1699 if (failed(Result: simpleResult))
1700 return std::nullopt;
1701 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1702 // TODO: Handle local variables. We can get hold of flattener.localExprs and
1703 // get bound on the local expr recursively.
1704 if (flattener.numLocals > 0)
1705 return std::nullopt;
1706 int64_t bound = 0;
1707 // Substitute the constant lower or upper bound for the dimensional or
1708 // symbolic input depending on `isUpper` to determine the bound.
1709 for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
1710 if (flattenedExpr[i] > 0) {
1711 auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
1712 if (!constBound)
1713 return std::nullopt;
1714 bound += *constBound * flattenedExpr[i];
1715 } else if (flattenedExpr[i] < 0) {
1716 auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
1717 if (!constBound)
1718 return std::nullopt;
1719 bound += *constBound * flattenedExpr[i];
1720 }
1721 }
1722 // Constant term.
1723 bound += flattenedExpr.back();
1724 return bound;
1725}
1726

source code of mlir/lib/IR/AffineExpr.cpp