1//===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributeInterfaces.h"
10#include "mlir/IR/BuiltinTypes.h"
11#include "mlir/IR/Diagnostics.h"
12#include "llvm/ADT/Sequence.h"
13
14using namespace mlir;
15using namespace mlir::detail;
16
17//===----------------------------------------------------------------------===//
18/// Tablegen Interface Definitions
19//===----------------------------------------------------------------------===//
20
21#include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// ElementsAttr
25//===----------------------------------------------------------------------===//
26
27Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
28 return elementsAttr.getShapedType().getElementType();
29}
30
31int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
32 return elementsAttr.getShapedType().getNumElements();
33}
34
35bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
36 // Verify that the rank of the indices matches the held type.
37 int64_t rank = type.getRank();
38 if (rank == 0 && index.size() == 1 && index[0] == 0)
39 return true;
40 if (rank != static_cast<int64_t>(index.size()))
41 return false;
42
43 // Verify that all of the indices are within the shape dimensions.
44 ArrayRef<int64_t> shape = type.getShape();
45 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
46 int64_t dim = static_cast<int64_t>(index[i]);
47 return 0 <= dim && dim < shape[i];
48 });
49}
50bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
51 ArrayRef<uint64_t> index) {
52 return isValidIndex(elementsAttr.getShapedType(), index);
53}
54
55uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
56 ShapedType shapeType = llvm::cast<ShapedType>(type);
57 assert(isValidIndex(shapeType, index) &&
58 "expected valid multi-dimensional index");
59
60 // Reduce the provided multidimensional index into a flattended 1D row-major
61 // index.
62 auto rank = shapeType.getRank();
63 ArrayRef<int64_t> shape = shapeType.getShape();
64 uint64_t valueIndex = 0;
65 uint64_t dimMultiplier = 1;
66 for (int i = rank - 1; i >= 0; --i) {
67 valueIndex += index[i] * dimMultiplier;
68 dimMultiplier *= shape[i];
69 }
70 return valueIndex;
71}
72
73//===----------------------------------------------------------------------===//
74// MemRefLayoutAttrInterface
75//===----------------------------------------------------------------------===//
76
77LogicalResult mlir::detail::verifyAffineMapAsLayout(
78 AffineMap m, ArrayRef<int64_t> shape,
79 function_ref<InFlightDiagnostic()> emitError) {
80 if (m.getNumDims() != shape.size())
81 return emitError() << "memref layout mismatch between rank and affine map: "
82 << shape.size() << " != " << m.getNumDims();
83
84 return success();
85}
86

source code of mlir/lib/IR/BuiltinAttributeInterfaces.cpp