1//===- AffineExpr.h - MLIR Affine Expr Class --------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// An affine expression is an affine combination of dimension identifiers and
10// symbols, including ceildiv/floordiv/mod by a constant integer.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_IR_AFFINEEXPR_H
15#define MLIR_IR_AFFINEEXPR_H
16
17#include "mlir/IR/Visitors.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/ADT/DenseMapInfo.h"
20#include "llvm/ADT/Hashing.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/Casting.h"
23#include <functional>
24#include <type_traits>
25
26namespace mlir {
27
class AffineMap;
class IntegerSet;
class MLIRContext;

namespace detail {

// Storage classes backing the AffineExpr value handles. They are only
// forward-declared in this header; their definitions live elsewhere.
struct AffineBinaryOpExprStorage;
struct AffineConstantExprStorage;
struct AffineDimExprStorage;
struct AffineExprStorage;

} // namespace detail
40
/// Discriminator for the concrete kind of an AffineExpr.
/// The binary operation kinds are listed first so that the test
/// `kind <= LAST_AFFINE_BINARY_OP` identifies a binary op; the relative order
/// of the enumerators is therefore significant and must not be changed.
enum class AffineExprKind {
  /// Sum of two affine expressions.
  Add,
  /// RHS of mul is always a constant or a symbolic expression.
  Mul,
  /// RHS of mod is always a constant or a symbolic expression with a positive
  /// value.
  Mod,
  /// RHS of floordiv is always a constant or a symbolic expression.
  FloorDiv,
  /// RHS of ceildiv is always a constant or a symbolic expression.
  CeilDiv,

  /// Marker for the last affine binary op: every binary op kind compares less
  /// than or equal to this value, every other kind compares greater.
  LAST_AFFINE_BINARY_OP = CeilDiv,

  /// Constant integer.
  Constant,
  /// Dimensional identifier.
  DimId,
  /// Symbolic identifier.
  SymbolId,
};
64
65/// Base type for affine expression.
66/// AffineExpr's are immutable value types with intuitive operators to
67/// operate on chainable, lightweight compositions.
68/// An AffineExpr is an interface to the underlying storage type pointer.
69class AffineExpr {
70public:
71 using ImplType = detail::AffineExprStorage;
72
73 constexpr AffineExpr() {}
74 /* implicit */ AffineExpr(const ImplType *expr)
75 : expr(const_cast<ImplType *>(expr)) {}
76
77 bool operator==(AffineExpr other) const { return expr == other.expr; }
78 bool operator!=(AffineExpr other) const { return !(*this == other); }
79 bool operator==(int64_t v) const;
80 bool operator!=(int64_t v) const { return !(*this == v); }
81 explicit operator bool() const { return expr; }
82
83 bool operator!() const { return expr == nullptr; }
84
85 template <typename U>
86 [[deprecated("Use llvm::isa<U>() instead")]] constexpr bool isa() const;
87
88 template <typename U>
89 [[deprecated("Use llvm::dyn_cast<U>() instead")]] U dyn_cast() const;
90
91 template <typename U>
92 [[deprecated("Use llvm::dyn_cast_or_null<U>() instead")]] U
93 dyn_cast_or_null() const;
94
95 template <typename U>
96 [[deprecated("Use llvm::cast<U>() instead")]] U cast() const;
97
98 MLIRContext *getContext() const;
99
100 /// Return the classification for this type.
101 AffineExprKind getKind() const;
102
103 void print(raw_ostream &os) const;
104 void dump() const;
105
106 /// Returns true if this expression is made out of only symbols and
107 /// constants, i.e., it does not involve dimensional identifiers.
108 bool isSymbolicOrConstant() const;
109
110 /// Returns true if this is a pure affine expression, i.e., multiplication,
111 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
112 bool isPureAffine() const;
113
114 /// Returns the greatest known integral divisor of this affine expression. The
115 /// result is always positive.
116 int64_t getLargestKnownDivisor() const;
117
118 /// Return true if the affine expression is a multiple of 'factor'.
119 bool isMultipleOf(int64_t factor) const;
120
121 /// Return true if the affine expression involves AffineDimExpr `position`.
122 bool isFunctionOfDim(unsigned position) const;
123
124 /// Return true if the affine expression involves AffineSymbolExpr `position`.
125 bool isFunctionOfSymbol(unsigned position) const;
126
127 /// Walk all of the AffineExpr's in this expression in postorder. This allows
128 /// a lambda walk function that can either return `void` or a WalkResult. With
129 /// a WalkResult, interrupting is supported.
130 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
131 RetT walk(FnT &&callback) const {
132 return walk<RetT>(*this, callback);
133 }
134
135 /// This method substitutes any uses of dimensions and symbols (e.g.
136 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
137 /// This is a dense replacement method: a replacement must be specified for
138 /// every single dim and symbol.
139 AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
140 ArrayRef<AffineExpr> symReplacements) const;
141
142 /// Dim-only version of replaceDimsAndSymbols.
143 AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
144
145 /// Symbol-only version of replaceDimsAndSymbols.
146 AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
147
148 /// Sparse replace method. Replace `expr` by `replacement` and return the
149 /// modified expression tree.
150 AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
151
152 /// Sparse replace method. If `*this` appears in `map` replaces it by
153 /// `map[*this]` and return the modified expression tree. Otherwise traverse
154 /// `*this` and apply replace with `map` on its subexpressions.
155 AffineExpr replace(const DenseMap<AffineExpr, AffineExpr> &map) const;
156
157 /// Replace dims[offset ... numDims)
158 /// by dims[offset + shift ... shift + numDims).
159 AffineExpr shiftDims(unsigned numDims, unsigned shift,
160 unsigned offset = 0) const;
161
162 /// Replace symbols[offset ... numSymbols)
163 /// by symbols[offset + shift ... shift + numSymbols).
164 AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift,
165 unsigned offset = 0) const;
166
167 AffineExpr operator+(int64_t v) const;
168 AffineExpr operator+(AffineExpr other) const;
169 AffineExpr operator-() const;
170 AffineExpr operator-(int64_t v) const;
171 AffineExpr operator-(AffineExpr other) const;
172 AffineExpr operator*(int64_t v) const;
173 AffineExpr operator*(AffineExpr other) const;
174 AffineExpr floorDiv(uint64_t v) const;
175 AffineExpr floorDiv(AffineExpr other) const;
176 AffineExpr ceilDiv(uint64_t v) const;
177 AffineExpr ceilDiv(AffineExpr other) const;
178 AffineExpr operator%(uint64_t v) const;
179 AffineExpr operator%(AffineExpr other) const;
180
181 /// Compose with an AffineMap.
182 /// Returns the composition of this AffineExpr with `map`.
183 ///
184 /// Prerequisites:
185 /// `this` and `map` are composable, i.e. that the number of AffineDimExpr of
186 /// `this` is smaller than the number of results of `map`. If a result of a
187 /// map does not have a corresponding AffineDimExpr, that result simply does
188 /// not appear in the produced AffineExpr.
189 ///
190 /// Example:
191 /// expr: `d0 + d2`
192 /// map: `(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)`
193 /// returned expr: `d0 * 2 + d1 + d2 + s1`
194 AffineExpr compose(AffineMap map) const;
195
196 friend ::llvm::hash_code hash_value(AffineExpr arg);
197
198 /// Methods supporting C API.
199 const void *getAsOpaquePointer() const {
200 return static_cast<const void *>(expr);
201 }
202 static AffineExpr getFromOpaquePointer(const void *pointer) {
203 return AffineExpr(
204 reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
205 }
206
207 ImplType *getImpl() const { return expr; }
208
209protected:
210 ImplType *expr{nullptr};
211
212private:
213 /// A trampoline for the templated non-static AffineExpr::walk method to
214 /// dispatch lambda `callback`'s of either a void result type or a
215 /// WalkResult type. Walk all of the AffineExprs in `e` in postorder. Users
216 /// should use the regular (non-static) `walk` method.
217 template <typename WalkRetTy>
218 static WalkRetTy walk(AffineExpr e,
219 function_ref<WalkRetTy(AffineExpr)> callback);
220};
221
222/// Affine binary operation expression. An affine binary operation could be an
223/// add, mul, floordiv, ceildiv, or a modulo operation. (Subtraction is
224/// represented through a multiply by -1 and add.) These expressions are always
225/// constructed in a simplified form. For eg., the LHS and RHS operands can't
226/// both be constants. There are additional canonicalizing rules depending on
227/// the op type: see checks in the constructor.
228class AffineBinaryOpExpr : public AffineExpr {
229public:
230 using ImplType = detail::AffineBinaryOpExprStorage;
231 /* implicit */ AffineBinaryOpExpr(AffineExpr::ImplType *ptr);
232 AffineExpr getLHS() const;
233 AffineExpr getRHS() const;
234};
235
236/// A dimensional identifier appearing in an affine expression.
237class AffineDimExpr : public AffineExpr {
238public:
239 using ImplType = detail::AffineDimExprStorage;
240 /* implicit */ AffineDimExpr(AffineExpr::ImplType *ptr);
241 unsigned getPosition() const;
242};
243
244/// A symbolic identifier appearing in an affine expression.
245class AffineSymbolExpr : public AffineExpr {
246public:
247 using ImplType = detail::AffineDimExprStorage;
248 /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
249 unsigned getPosition() const;
250};
251
252/// An integer constant appearing in affine expression.
253class AffineConstantExpr : public AffineExpr {
254public:
255 using ImplType = detail::AffineConstantExprStorage;
256 /* implicit */ AffineConstantExpr(AffineExpr::ImplType *ptr = nullptr);
257 int64_t getValue() const;
258};
259
260/// Make AffineExpr hashable.
261inline ::llvm::hash_code hash_value(AffineExpr arg) {
262 return ::llvm::hash_value(ptr: arg.expr);
263}
264
265inline AffineExpr operator+(int64_t val, AffineExpr expr) { return expr + val; }
266inline AffineExpr operator*(int64_t val, AffineExpr expr) { return expr * val; }
267inline AffineExpr operator-(int64_t val, AffineExpr expr) {
268 return expr * (-1) + val;
269}
270
271/// These free functions allow clients of the API to not use classes in detail.
272AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context);
273AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context);
274AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context);
275SmallVector<AffineExpr> getAffineConstantExprs(ArrayRef<int64_t> constants,
276 MLIRContext *context);
277AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
278 AffineExpr rhs);
279
280/// Constructs an affine expression from a flat ArrayRef. If there are local
281/// identifiers (neither dimensional nor symbolic) that appear in the sum of
282/// products expression, 'localExprs' is expected to have the AffineExpr
283/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
284/// format [dims, symbols, locals, constant term].
285AffineExpr getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
286 unsigned numDims, unsigned numSymbols,
287 ArrayRef<AffineExpr> localExprs,
288 MLIRContext *context);
289
290raw_ostream &operator<<(raw_ostream &os, AffineExpr expr);
291
292template <typename U>
293constexpr bool AffineExpr::isa() const {
294 if constexpr (std::is_same_v<U, AffineBinaryOpExpr>)
295 return getKind() <= AffineExprKind::LAST_AFFINE_BINARY_OP;
296 if constexpr (std::is_same_v<U, AffineDimExpr>)
297 return getKind() == AffineExprKind::DimId;
298 if constexpr (std::is_same_v<U, AffineSymbolExpr>)
299 return getKind() == AffineExprKind::SymbolId;
300 if constexpr (std::is_same_v<U, AffineConstantExpr>)
301 return getKind() == AffineExprKind::Constant;
302}
303template <typename U>
304U AffineExpr::dyn_cast() const {
305 return llvm::dyn_cast<U>(*this);
306}
307template <typename U>
308U AffineExpr::dyn_cast_or_null() const {
309 return llvm::dyn_cast_or_null<U>(*this);
310}
311template <typename U>
312U AffineExpr::cast() const {
313 return llvm::cast<U>(*this);
314}
315
316/// Simplify an affine expression by flattening and some amount of simple
317/// analysis. This has complexity linear in the number of nodes in 'expr'.
318/// Returns the simplified expression, which is the same as the input expression
319/// if it can't be simplified. When `expr` is semi-affine, a simplified
320/// semi-affine expression is constructed in the sorted order of dimension and
321/// symbol positions.
322AffineExpr simplifyAffineExpr(AffineExpr expr, unsigned numDims,
323 unsigned numSymbols);
324
325namespace detail {
326template <int N>
327void bindDims(MLIRContext *ctx) {}
328
329template <int N, typename AffineExprTy, typename... AffineExprTy2>
330void bindDims(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
331 e = getAffineDimExpr(position: N, context: ctx);
332 bindDims<N + 1, AffineExprTy2 &...>(ctx, exprs...);
333}
334
335template <int N>
336void bindSymbols(MLIRContext *ctx) {}
337
338template <int N, typename AffineExprTy, typename... AffineExprTy2>
339void bindSymbols(MLIRContext *ctx, AffineExprTy &e, AffineExprTy2 &...exprs) {
340 e = getAffineSymbolExpr(position: N, context: ctx);
341 bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
342}
343
344} // namespace detail
345
346/// Bind a list of AffineExpr references to DimExpr at positions:
347/// [0 .. sizeof...(exprs)]
348template <typename... AffineExprTy>
349void bindDims(MLIRContext *ctx, AffineExprTy &...exprs) {
350 detail::bindDims<0>(ctx, exprs...);
351}
352
353template <typename AffineExprTy>
354void bindDimsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
355 int idx = 0;
356 for (AffineExprTy &e : exprs)
357 e = getAffineDimExpr(position: idx++, context: ctx);
358}
359
360/// Bind a list of AffineExpr references to SymbolExpr at positions:
361/// [0 .. sizeof...(exprs)]
362template <typename... AffineExprTy>
363void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs) {
364 detail::bindSymbols<0>(ctx, exprs...);
365}
366
367template <typename AffineExprTy>
368void bindSymbolsList(MLIRContext *ctx, MutableArrayRef<AffineExprTy> exprs) {
369 int idx = 0;
370 for (AffineExprTy &e : exprs)
371 e = getAffineSymbolExpr(position: idx++, context: ctx);
372}
373
374/// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
375/// the constant lower and upper bounds for its inputs provided in
376/// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
377/// bound can't be computed. This method only handles simple sum of product
378/// expressions (w.r.t constant coefficients) so as to not depend on anything
379/// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
380/// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
381/// semi-affine ones will lead a none being returned.
382std::optional<int64_t>
383getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
384 ArrayRef<std::optional<int64_t>> constLowerBounds,
385 ArrayRef<std::optional<int64_t>> constUpperBounds,
386 bool isUpper);
387
388} // namespace mlir
389
390namespace llvm {
391
392// AffineExpr hash just like pointers
393template <>
394struct DenseMapInfo<mlir::AffineExpr> {
395 static mlir::AffineExpr getEmptyKey() {
396 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
397 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
398 }
399 static mlir::AffineExpr getTombstoneKey() {
400 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
401 return mlir::AffineExpr(static_cast<mlir::AffineExpr::ImplType *>(pointer));
402 }
403 static unsigned getHashValue(mlir::AffineExpr val) {
404 return mlir::hash_value(arg: val);
405 }
406 static bool isEqual(mlir::AffineExpr LHS, mlir::AffineExpr RHS) {
407 return LHS == RHS;
408 }
409};
410
411/// Add support for llvm style casts. We provide a cast between To and From if
412/// From is mlir::AffineExpr or derives from it.
413template <typename To, typename From>
414struct CastInfo<To, From,
415 std::enable_if_t<std::is_same_v<mlir::AffineExpr,
416 std::remove_const_t<From>> ||
417 std::is_base_of_v<mlir::AffineExpr, From>>>
418 : NullableValueCastFailed<To>,
419 DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
420
421 static inline bool isPossible(mlir::AffineExpr expr) {
422 /// Return a constant true instead of a dynamic true when casting to self or
423 /// up the hierarchy.
424 if constexpr (std::is_base_of_v<To, From>) {
425 return true;
426 } else {
427 if constexpr (std::is_same_v<To, ::mlir::AffineBinaryOpExpr>)
428 return expr.getKind() <= ::mlir::AffineExprKind::LAST_AFFINE_BINARY_OP;
429 if constexpr (std::is_same_v<To, ::mlir::AffineDimExpr>)
430 return expr.getKind() == ::mlir::AffineExprKind::DimId;
431 if constexpr (std::is_same_v<To, ::mlir::AffineSymbolExpr>)
432 return expr.getKind() == ::mlir::AffineExprKind::SymbolId;
433 if constexpr (std::is_same_v<To, ::mlir::AffineConstantExpr>)
434 return expr.getKind() == ::mlir::AffineExprKind::Constant;
435 }
436 }
437 static inline To doCast(mlir::AffineExpr expr) { return To(expr.getImpl()); }
438};
439
440} // namespace llvm
441
442#endif // MLIR_IR_AFFINEEXPR_H
443

source code of mlir/include/mlir/IR/AffineExpr.h