1 | //===- BuiltinAttributeInterfaces.cpp -------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/BuiltinAttributeInterfaces.h" |
10 | #include "mlir/IR/BuiltinTypes.h" |
11 | #include "mlir/IR/Diagnostics.h" |
12 | #include "llvm/ADT/Sequence.h" |
13 | |
14 | using namespace mlir; |
15 | using namespace mlir::detail; |
16 | |
17 | //===----------------------------------------------------------------------===// |
// Tablegen Interface Definitions
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc" |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // ElementsAttr |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | Type ElementsAttr::getElementType(ElementsAttr elementsAttr) { |
28 | return elementsAttr.getShapedType().getElementType(); |
29 | } |
30 | |
31 | int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) { |
32 | return elementsAttr.getShapedType().getNumElements(); |
33 | } |
34 | |
35 | bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) { |
36 | // Verify that the rank of the indices matches the held type. |
37 | int64_t rank = type.getRank(); |
38 | if (rank == 0 && index.size() == 1 && index[0] == 0) |
39 | return true; |
40 | if (rank != static_cast<int64_t>(index.size())) |
41 | return false; |
42 | |
43 | // Verify that all of the indices are within the shape dimensions. |
44 | ArrayRef<int64_t> shape = type.getShape(); |
45 | return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { |
46 | int64_t dim = static_cast<int64_t>(index[i]); |
47 | return 0 <= dim && dim < shape[i]; |
48 | }); |
49 | } |
50 | bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr, |
51 | ArrayRef<uint64_t> index) { |
52 | return isValidIndex(elementsAttr.getShapedType(), index); |
53 | } |
54 | |
55 | uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) { |
56 | ShapedType shapeType = llvm::cast<ShapedType>(type); |
57 | assert(isValidIndex(shapeType, index) && |
58 | "expected valid multi-dimensional index" ); |
59 | |
60 | // Reduce the provided multidimensional index into a flattended 1D row-major |
61 | // index. |
62 | auto rank = shapeType.getRank(); |
63 | ArrayRef<int64_t> shape = shapeType.getShape(); |
64 | uint64_t valueIndex = 0; |
65 | uint64_t dimMultiplier = 1; |
66 | for (int i = rank - 1; i >= 0; --i) { |
67 | valueIndex += index[i] * dimMultiplier; |
68 | dimMultiplier *= shape[i]; |
69 | } |
70 | return valueIndex; |
71 | } |
72 | |
73 | //===----------------------------------------------------------------------===// |
74 | // MemRefLayoutAttrInterface |
75 | //===----------------------------------------------------------------------===// |
76 | |
77 | LogicalResult mlir::detail::verifyAffineMapAsLayout( |
78 | AffineMap m, ArrayRef<int64_t> shape, |
79 | function_ref<InFlightDiagnostic()> emitError) { |
80 | if (m.getNumDims() != shape.size()) |
81 | return emitError() << "memref layout mismatch between rank and affine map: " |
82 | << shape.size() << " != " << m.getNumDims(); |
83 | |
84 | return success(); |
85 | } |
86 | |