1//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributeInterfaces.h"
10#include "mlir/IR/BuiltinTypes.h"
11#include "mlir/IR/Diagnostics.h"
12#include "llvm/ADT/Sequence.h"
13
14using namespace mlir;
15using namespace mlir::detail;
16
17//===----------------------------------------------------------------------===//
18/// Tablegen Interface Definitions
19//===----------------------------------------------------------------------===//
20
21#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// ElementsAttr
25//===----------------------------------------------------------------------===//
26
27Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
28 return elementsAttr.getShapedType().getElementType();
29}
30
31int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
32 return elementsAttr.getShapedType().getNumElements();
33}
34
35bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
36 // Verify that the rank of the indices matches the held type.
37 int64_t rank = type.getRank();
38 if (rank == 0 && index.size() == 1 && index[0] == 0)
39 return true;
40 if (rank != static_cast<int64_t>(index.size()))
41 return false;
42
43 // Verify that all of the indices are within the shape dimensions.
44 ArrayRef<int64_t> shape = type.getShape();
45 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
46 int64_t dim = static_cast<int64_t>(index[i]);
47 return 0 <= dim && dim < shape[i];
48 });
49}
50bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
51 ArrayRef<uint64_t> index) {
52 return isValidIndex(elementsAttr.getShapedType(), index);
53}
54
55uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
56 ShapedType shapeType = llvm::cast<ShapedType>(type);
57 assert(isValidIndex(shapeType, index) &&
58 "expected valid multi-dimensional index");
59
60 // Reduce the provided multidimensional index into a flattended 1D row-major
61 // index.
62 auto rank = shapeType.getRank();
63 ArrayRef<int64_t> shape = shapeType.getShape();
64 uint64_t valueIndex = 0;
65 uint64_t dimMultiplier = 1;
66 for (int i = rank - 1; i >= 0; --i) {
67 valueIndex += index[i] * dimMultiplier;
68 dimMultiplier *= shape[i];
69 }
70 return valueIndex;
71}
72
73//===----------------------------------------------------------------------===//
74// MemRefLayoutAttrInterface
75//===----------------------------------------------------------------------===//
76
77LogicalResult mlir::detail::verifyAffineMapAsLayout(
78 AffineMap m, ArrayRef<int64_t> shape,
79 function_ref<InFlightDiagnostic()> emitError) {
80 if (m.getNumDims() != shape.size())
81 return emitError() << "memref layout mismatch between rank and affine map: "
82 << shape.size() << " != " << m.getNumDims();
83
84 return success();
85}
86
87// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
88// i.e. single term). Accumulate the AffineExpr into the existing one.
89static void extractStridesFromTerm(AffineExpr e,
90 AffineExpr multiplicativeFactor,
91 MutableArrayRef<AffineExpr> strides,
92 AffineExpr &offset) {
93 if (auto dim = dyn_cast<AffineDimExpr>(Val&: e))
94 strides[dim.getPosition()] =
95 strides[dim.getPosition()] + multiplicativeFactor;
96 else
97 offset = offset + e * multiplicativeFactor;
98}
99
100/// Takes a single AffineExpr `e` and populates the `strides` array with the
101/// strides expressions for each dim position.
102/// The convention is that the strides for dimensions d0, .. dn appear in
103/// order to make indexing intuitive into the result.
104static LogicalResult extractStrides(AffineExpr e,
105 AffineExpr multiplicativeFactor,
106 MutableArrayRef<AffineExpr> strides,
107 AffineExpr &offset) {
108 auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e);
109 if (!bin) {
110 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
111 return success();
112 }
113
114 if (bin.getKind() == AffineExprKind::CeilDiv ||
115 bin.getKind() == AffineExprKind::FloorDiv ||
116 bin.getKind() == AffineExprKind::Mod)
117 return failure();
118
119 if (bin.getKind() == AffineExprKind::Mul) {
120 auto dim = dyn_cast<AffineDimExpr>(Val: bin.getLHS());
121 if (dim) {
122 strides[dim.getPosition()] =
123 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
124 return success();
125 }
126 // LHS and RHS may both contain complex expressions of dims. Try one path
127 // and if it fails try the other. This is guaranteed to succeed because
128 // only one path may have a `dim`, otherwise this is not an AffineExpr in
129 // the first place.
130 if (bin.getLHS().isSymbolicOrConstant())
131 return extractStrides(e: bin.getRHS(), multiplicativeFactor: multiplicativeFactor * bin.getLHS(),
132 strides, offset);
133 return extractStrides(e: bin.getLHS(), multiplicativeFactor: multiplicativeFactor * bin.getRHS(),
134 strides, offset);
135 }
136
137 if (bin.getKind() == AffineExprKind::Add) {
138 auto res1 =
139 extractStrides(e: bin.getLHS(), multiplicativeFactor, strides, offset);
140 auto res2 =
141 extractStrides(e: bin.getRHS(), multiplicativeFactor, strides, offset);
142 return success(IsSuccess: succeeded(Result: res1) && succeeded(Result: res2));
143 }
144
145 llvm_unreachable("unexpected binary operation");
146}
147
148/// A stride specification is a list of integer values that are either static
149/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
150/// the distance in the number of elements between successive entries along a
151/// particular dimension.
152///
153/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
154/// non-contiguous memory region of `42` by `16` `f32` elements in which the
155/// distance between two consecutive elements along the outer dimension is `1`
156/// and the distance between two consecutive elements along the inner dimension
157/// is `64`.
158///
159/// The convention is that the strides for dimensions d0, .. dn appear in
160/// order to make indexing intuitive into the result.
161static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
162 SmallVectorImpl<AffineExpr> &strides,
163 AffineExpr &offset) {
164 if (m.getNumResults() != 1 && !m.isIdentity())
165 return failure();
166
167 auto zero = getAffineConstantExpr(constant: 0, context: m.getContext());
168 auto one = getAffineConstantExpr(constant: 1, context: m.getContext());
169 offset = zero;
170 strides.assign(NumElts: shape.size(), Elt: zero);
171
172 // Canonical case for empty map.
173 if (m.isIdentity()) {
174 // 0-D corner case, offset is already 0.
175 if (shape.empty())
176 return success();
177 auto stridedExpr = makeCanonicalStridedLayoutExpr(sizes: shape, context: m.getContext());
178 if (succeeded(Result: extractStrides(e: stridedExpr, multiplicativeFactor: one, strides, offset)))
179 return success();
180 assert(false && "unexpected failure: extract strides in canonical layout");
181 }
182
183 // Non-canonical case requires more work.
184 auto stridedExpr =
185 simplifyAffineExpr(expr: m.getResult(idx: 0), numDims: m.getNumDims(), numSymbols: m.getNumSymbols());
186 if (failed(Result: extractStrides(e: stridedExpr, multiplicativeFactor: one, strides, offset))) {
187 offset = AffineExpr();
188 strides.clear();
189 return failure();
190 }
191
192 // Simplify results to allow folding to constants and simple checks.
193 unsigned numDims = m.getNumDims();
194 unsigned numSymbols = m.getNumSymbols();
195 offset = simplifyAffineExpr(expr: offset, numDims, numSymbols);
196 for (auto &stride : strides)
197 stride = simplifyAffineExpr(expr: stride, numDims, numSymbols);
198
199 return success();
200}
201
202LogicalResult mlir::detail::getAffineMapStridesAndOffset(
203 AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
204 int64_t &offset) {
205 AffineExpr offsetExpr;
206 SmallVector<AffineExpr, 4> strideExprs;
207 if (failed(Result: ::getStridesAndOffset(m: map, shape, strides&: strideExprs, offset&: offsetExpr)))
208 return failure();
209 if (auto cst = llvm::dyn_cast<AffineConstantExpr>(Val&: offsetExpr))
210 offset = cst.getValue();
211 else
212 offset = ShapedType::kDynamic;
213 for (auto e : strideExprs) {
214 if (auto c = llvm::dyn_cast<AffineConstantExpr>(Val&: e))
215 strides.push_back(Elt: c.getValue());
216 else
217 strides.push_back(ShapedType::kDynamic);
218 }
219 return success();
220}
221

source code of mlir/lib/IR/BuiltinAttributeInterfaces.cpp