1 | //===- BuiltinAttributeInterfaces.cpp -------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
10 | #include "mlir/IR/BuiltinTypes.h" |
11 | #include "mlir/IR/Diagnostics.h" |
12 | #include "llvm/ADT/Sequence.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::detail; |
16 | |
17 | //===----------------------------------------------------------------------===// |
18 | /// Tablegen Interface Definitions |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc" |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // ElementsAttr |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | Type ElementsAttr::getElementType(ElementsAttr elementsAttr) { |
28 | return elementsAttr.getShapedType().getElementType(); |
29 | } |
30 | |
31 | int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) { |
32 | return elementsAttr.getShapedType().getNumElements(); |
33 | } |
34 | |
35 | bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) { |
36 | // Verify that the rank of the indices matches the held type. |
37 | int64_t rank = type.getRank(); |
38 | if (rank == 0 && index.size() == 1 && index[0] == 0) |
39 | return true; |
40 | if (rank != static_cast<int64_t>(index.size())) |
41 | return false; |
42 | |
43 | // Verify that all of the indices are within the shape dimensions. |
44 | ArrayRef<int64_t> shape = type.getShape(); |
45 | return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { |
46 | int64_t dim = static_cast<int64_t>(index[i]); |
47 | return 0 <= dim && dim < shape[i]; |
48 | }); |
49 | } |
50 | bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr, |
51 | ArrayRef<uint64_t> index) { |
52 | return isValidIndex(elementsAttr.getShapedType(), index); |
53 | } |
54 | |
55 | uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) { |
56 | ShapedType shapeType = llvm::cast<ShapedType>(type); |
57 | assert(isValidIndex(shapeType, index) && |
58 | "expected valid multi-dimensional index" ); |
59 | |
60 | // Reduce the provided multidimensional index into a flattended 1D row-major |
61 | // index. |
62 | auto rank = shapeType.getRank(); |
63 | ArrayRef<int64_t> shape = shapeType.getShape(); |
64 | uint64_t valueIndex = 0; |
65 | uint64_t dimMultiplier = 1; |
66 | for (int i = rank - 1; i >= 0; --i) { |
67 | valueIndex += index[i] * dimMultiplier; |
68 | dimMultiplier *= shape[i]; |
69 | } |
70 | return valueIndex; |
71 | } |
72 | |
73 | //===----------------------------------------------------------------------===// |
74 | // MemRefLayoutAttrInterface |
75 | //===----------------------------------------------------------------------===// |
76 | |
77 | LogicalResult mlir::detail::verifyAffineMapAsLayout( |
78 | AffineMap m, ArrayRef<int64_t> shape, |
79 | function_ref<InFlightDiagnostic()> emitError) { |
80 | if (m.getNumDims() != shape.size()) |
81 | return emitError() << "memref layout mismatch between rank and affine map: " |
82 | << shape.size() << " != " << m.getNumDims(); |
83 | |
84 | return success(); |
85 | } |
86 | |
87 | // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( |
88 | // i.e. single term). Accumulate the AffineExpr into the existing one. |
89 | static void (AffineExpr e, |
90 | AffineExpr multiplicativeFactor, |
91 | MutableArrayRef<AffineExpr> strides, |
92 | AffineExpr &offset) { |
93 | if (auto dim = dyn_cast<AffineDimExpr>(Val&: e)) |
94 | strides[dim.getPosition()] = |
95 | strides[dim.getPosition()] + multiplicativeFactor; |
96 | else |
97 | offset = offset + e * multiplicativeFactor; |
98 | } |
99 | |
100 | /// Takes a single AffineExpr `e` and populates the `strides` array with the |
101 | /// strides expressions for each dim position. |
102 | /// The convention is that the strides for dimensions d0, .. dn appear in |
103 | /// order to make indexing intuitive into the result. |
104 | static LogicalResult (AffineExpr e, |
105 | AffineExpr multiplicativeFactor, |
106 | MutableArrayRef<AffineExpr> strides, |
107 | AffineExpr &offset) { |
108 | auto bin = dyn_cast<AffineBinaryOpExpr>(Val&: e); |
109 | if (!bin) { |
110 | extractStridesFromTerm(e, multiplicativeFactor, strides, offset); |
111 | return success(); |
112 | } |
113 | |
114 | if (bin.getKind() == AffineExprKind::CeilDiv || |
115 | bin.getKind() == AffineExprKind::FloorDiv || |
116 | bin.getKind() == AffineExprKind::Mod) |
117 | return failure(); |
118 | |
119 | if (bin.getKind() == AffineExprKind::Mul) { |
120 | auto dim = dyn_cast<AffineDimExpr>(Val: bin.getLHS()); |
121 | if (dim) { |
122 | strides[dim.getPosition()] = |
123 | strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; |
124 | return success(); |
125 | } |
126 | // LHS and RHS may both contain complex expressions of dims. Try one path |
127 | // and if it fails try the other. This is guaranteed to succeed because |
128 | // only one path may have a `dim`, otherwise this is not an AffineExpr in |
129 | // the first place. |
130 | if (bin.getLHS().isSymbolicOrConstant()) |
131 | return extractStrides(e: bin.getRHS(), multiplicativeFactor: multiplicativeFactor * bin.getLHS(), |
132 | strides, offset); |
133 | return extractStrides(e: bin.getLHS(), multiplicativeFactor: multiplicativeFactor * bin.getRHS(), |
134 | strides, offset); |
135 | } |
136 | |
137 | if (bin.getKind() == AffineExprKind::Add) { |
138 | auto res1 = |
139 | extractStrides(e: bin.getLHS(), multiplicativeFactor, strides, offset); |
140 | auto res2 = |
141 | extractStrides(e: bin.getRHS(), multiplicativeFactor, strides, offset); |
142 | return success(IsSuccess: succeeded(Result: res1) && succeeded(Result: res2)); |
143 | } |
144 | |
145 | llvm_unreachable("unexpected binary operation" ); |
146 | } |
147 | |
148 | /// A stride specification is a list of integer values that are either static |
149 | /// or dynamic (encoded with ShapedType::kDynamic). Strides encode |
150 | /// the distance in the number of elements between successive entries along a |
151 | /// particular dimension. |
152 | /// |
153 | /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a |
154 | /// non-contiguous memory region of `42` by `16` `f32` elements in which the |
155 | /// distance between two consecutive elements along the outer dimension is `1` |
156 | /// and the distance between two consecutive elements along the inner dimension |
157 | /// is `64`. |
158 | /// |
159 | /// The convention is that the strides for dimensions d0, .. dn appear in |
160 | /// order to make indexing intuitive into the result. |
161 | static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape, |
162 | SmallVectorImpl<AffineExpr> &strides, |
163 | AffineExpr &offset) { |
164 | if (m.getNumResults() != 1 && !m.isIdentity()) |
165 | return failure(); |
166 | |
167 | auto zero = getAffineConstantExpr(constant: 0, context: m.getContext()); |
168 | auto one = getAffineConstantExpr(constant: 1, context: m.getContext()); |
169 | offset = zero; |
170 | strides.assign(NumElts: shape.size(), Elt: zero); |
171 | |
172 | // Canonical case for empty map. |
173 | if (m.isIdentity()) { |
174 | // 0-D corner case, offset is already 0. |
175 | if (shape.empty()) |
176 | return success(); |
177 | auto stridedExpr = makeCanonicalStridedLayoutExpr(sizes: shape, context: m.getContext()); |
178 | if (succeeded(Result: extractStrides(e: stridedExpr, multiplicativeFactor: one, strides, offset))) |
179 | return success(); |
180 | assert(false && "unexpected failure: extract strides in canonical layout" ); |
181 | } |
182 | |
183 | // Non-canonical case requires more work. |
184 | auto stridedExpr = |
185 | simplifyAffineExpr(expr: m.getResult(idx: 0), numDims: m.getNumDims(), numSymbols: m.getNumSymbols()); |
186 | if (failed(Result: extractStrides(e: stridedExpr, multiplicativeFactor: one, strides, offset))) { |
187 | offset = AffineExpr(); |
188 | strides.clear(); |
189 | return failure(); |
190 | } |
191 | |
192 | // Simplify results to allow folding to constants and simple checks. |
193 | unsigned numDims = m.getNumDims(); |
194 | unsigned numSymbols = m.getNumSymbols(); |
195 | offset = simplifyAffineExpr(expr: offset, numDims, numSymbols); |
196 | for (auto &stride : strides) |
197 | stride = simplifyAffineExpr(expr: stride, numDims, numSymbols); |
198 | |
199 | return success(); |
200 | } |
201 | |
202 | LogicalResult mlir::detail::getAffineMapStridesAndOffset( |
203 | AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides, |
204 | int64_t &offset) { |
205 | AffineExpr offsetExpr; |
206 | SmallVector<AffineExpr, 4> strideExprs; |
207 | if (failed(Result: ::getStridesAndOffset(m: map, shape, strides&: strideExprs, offset&: offsetExpr))) |
208 | return failure(); |
209 | if (auto cst = llvm::dyn_cast<AffineConstantExpr>(Val&: offsetExpr)) |
210 | offset = cst.getValue(); |
211 | else |
212 | offset = ShapedType::kDynamic; |
213 | for (auto e : strideExprs) { |
214 | if (auto c = llvm::dyn_cast<AffineConstantExpr>(Val&: e)) |
215 | strides.push_back(Elt: c.getValue()); |
216 | else |
217 | strides.push_back(ShapedType::kDynamic); |
218 | } |
219 | return success(); |
220 | } |
221 | |