//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//

Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
  return elementsAttr.getShapedType().getElementType();
}

int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
  return elementsAttr.getShapedType().getNumElements();
}

bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
  // Verify that the rank of the indices matches the held type.
  int64_t rank = type.getRank();
  if (rank == 0 && index.size() == 1 && index[0] == 0)
    return true;
  if (rank != static_cast<int64_t>(index.size()))
    return false;

  // Verify that all of the indices are within the shape dimensions.
  ArrayRef<int64_t> shape = type.getShape();
  return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
    int64_t dim = static_cast<int64_t>(index[i]);
    return 0 <= dim && dim < shape[i];
  });
}
bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
                                ArrayRef<uint64_t> index) {
  return isValidIndex(elementsAttr.getShapedType(), index);
}
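// For illustration (not exercised here): for a shaped type of shape [2, 3],
// the index (1, 2) is in bounds while (1, 3) and (0, 1, 2) are not; a rank-0
// type accepts only the single index (0).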

uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
  ShapedType shapeType = llvm::cast<ShapedType>(type);
  assert(isValidIndex(shapeType, index) &&
         "expected valid multi-dimensional index");

  // Reduce the provided multidimensional index into a flattened 1D row-major
  // index.
  auto rank = shapeType.getRank();
  ArrayRef<int64_t> shape = shapeType.getShape();
  uint64_t valueIndex = 0;
  uint64_t dimMultiplier = 1;
  for (int i = rank - 1; i >= 0; --i) {
    valueIndex += index[i] * dimMultiplier;
    dimMultiplier *= shape[i];
  }
  return valueIndex;
}
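// Worked example for the row-major flattening above: with shape [2, 3, 4] and
// index (1, 2, 3), the loop accumulates 3 * 1 + 2 * 4 + 1 * 12 = 23.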

//===----------------------------------------------------------------------===//
// MemRefLayoutAttrInterface
//===----------------------------------------------------------------------===//

LogicalResult mlir::detail::verifyAffineMapAsLayout(
    AffineMap m, ArrayRef<int64_t> shape,
    function_ref<InFlightDiagnostic()> emitError) {
  if (m.getNumDims() != shape.size())
    return emitError() << "memref layout mismatch between rank and affine map: "
                       << shape.size() << " != " << m.getNumDims();

  return success();
}
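// For illustration: a rank-2 shape such as [42, 16] is compatible with the
// two-dimensional map (d0, d1) -> (d0 * 16 + d1), whereas a one-dimensional
// map such as (d0) -> (d0) triggers the rank mismatch diagnostic above.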

// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = dyn_cast<AffineDimExpr>(e))
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}
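// For illustration: visiting the lone dimension d1 with multiplicative factor
// 64 adds 64 to strides[1]; visiting a constant or symbol term instead folds
// it (scaled by the factor) into the offset expression.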

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expression for each dim position.
/// The convention is that the strides for dimensions d0, .., dn appear in
/// order, making indexing into the result intuitive.
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}
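// For illustration: extracting strides from the simplified expression
// d0 * 64 + d1 + 42 recurses through the Add and Mul cases so that d0 * 64
// contributes 64 to strides[0], d1 contributes 1 to strides[1], and the
// constant 42 falls through to the terminal case and lands in the offset,
// i.e. strides == [64, 1] and offset == 42.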

/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode the
/// distance, in number of elements, between successive entries along a
/// particular dimension.
///
/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
/// non-contiguous memory region of `42` by `16` `f32` elements in which the
/// distance between two consecutive elements along the outer dimension is `64`
/// and the distance between two consecutive elements along the inner dimension
/// is `1`.
///
/// The convention is that the strides for dimensions d0, .., dn appear in
/// order, making indexing into the result intuitive.
static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
                                         SmallVectorImpl<AffineExpr> &strides,
                                         AffineExpr &offset) {
  if (m.getNumResults() != 1 && !m.isIdentity())
    return failure();

  auto zero = getAffineConstantExpr(0, m.getContext());
  auto one = getAffineConstantExpr(1, m.getContext());
  offset = zero;
  strides.assign(shape.size(), zero);

  // Canonical case for empty map.
  if (m.isIdentity()) {
    // 0-D corner case, offset is already 0.
    if (shape.empty())
      return success();
    auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  return success();
}
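// For illustration: for shape [42, 16] with an identity layout map, the
// canonical strided expression is d0 * 16 + d1, so the extracted strides are
// [16, 1] and the offset is 0.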

LogicalResult mlir::detail::getAffineMapStridesAndOffset(
    AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
    int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamic;
  for (auto e : strideExprs) {
    if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamic);
  }
  return success();
}
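// A minimal usage sketch (hypothetical caller, not part of this file),
// assuming `layoutMap` and `shape` describe a strided memref layout:
//
//   SmallVector<int64_t> strides;
//   int64_t offset;
//   if (succeeded(mlir::detail::getAffineMapStridesAndOffset(
//           layoutMap, shape, strides, offset))) {
//     // Each entry of `strides` (and `offset`) is either a constant or
//     // ShapedType::kDynamic when it could not be folded to a constant.
//   }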