1//===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Interfaces/ViewLikeInterface.h"

#include <cassert>
10
11using namespace mlir;
12
13//===----------------------------------------------------------------------===//
14// ViewLike Interfaces
15//===----------------------------------------------------------------------===//
16
/// Include the definitions of the view-like interfaces.
18#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19
20LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
21 StringRef name,
22 unsigned numElements,
23 ArrayRef<int64_t> staticVals,
24 ValueRange values) {
25 // Check static and dynamic offsets/sizes/strides does not overflow type.
26 if (staticVals.size() != numElements)
27 return op->emitError(message: "expected ") << numElements << " " << name
28 << " values, got " << staticVals.size();
29 unsigned expectedNumDynamicEntries =
30 llvm::count_if(staticVals, [](int64_t staticVal) {
31 return ShapedType::isDynamic(staticVal);
32 });
33 if (values.size() != expectedNumDynamicEntries)
34 return op->emitError(message: "expected ")
35 << expectedNumDynamicEntries << " dynamic " << name << " values";
36 return success();
37}
38
39LogicalResult
40mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
41 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
42 // Offsets can come in 2 flavors:
43 // 1. Either single entry (when maxRanks == 1).
44 // 2. Or as an array whose rank must match that of the mixed sizes.
45 // So that the result type is well-formed.
46 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
47 op.getMixedOffsets().size() != op.getMixedSizes().size())
48 return op->emitError(
49 "expected mixed offsets rank to match mixed sizes rank (")
50 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
51 << ") so the rank of the result type is well-formed.";
52 // Ranks of mixed sizes and strides must always match so the result type is
53 // well-formed.
54 if (op.getMixedSizes().size() != op.getMixedStrides().size())
55 return op->emitError(
56 "expected mixed sizes rank to match mixed strides rank (")
57 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
58 << ") so the rank of the result type is well-formed.";
59
60 if (failed(verifyListOfOperandsOrIntegers(
61 op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
62 return failure();
63 if (failed(verifyListOfOperandsOrIntegers(
64 op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
65 return failure();
66 if (failed(verifyListOfOperandsOrIntegers(
67 op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
68 return failure();
69
70 for (int64_t offset : op.getStaticOffsets()) {
71 if (offset < 0 && !ShapedType::isDynamic(offset))
72 return op->emitError("expected offsets to be non-negative, but got ")
73 << offset;
74 }
75 for (int64_t size : op.getStaticSizes()) {
76 if (size < 0 && !ShapedType::isDynamic(size))
77 return op->emitError("expected sizes to be non-negative, but got ")
78 << size;
79 }
80 return success();
81}
82
83static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
84 switch (delimiter) {
85 case AsmParser::Delimiter::Paren:
86 return '(';
87 case AsmParser::Delimiter::LessGreater:
88 return '<';
89 case AsmParser::Delimiter::Square:
90 return '[';
91 case AsmParser::Delimiter::Braces:
92 return '{';
93 default:
94 llvm_unreachable("unsupported delimiter");
95 }
96}
97
98static char getRightDelimiter(AsmParser::Delimiter delimiter) {
99 switch (delimiter) {
100 case AsmParser::Delimiter::Paren:
101 return ')';
102 case AsmParser::Delimiter::LessGreater:
103 return '>';
104 case AsmParser::Delimiter::Square:
105 return ']';
106 case AsmParser::Delimiter::Braces:
107 return '}';
108 default:
109 llvm_unreachable("unsupported delimiter");
110 }
111}
112
113void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
114 OperandRange values,
115 ArrayRef<int64_t> integers,
116 TypeRange valueTypes, ArrayRef<bool> scalables,
117 AsmParser::Delimiter delimiter) {
118 char leftDelimiter = getLeftDelimiter(delimiter);
119 char rightDelimiter = getRightDelimiter(delimiter);
120 printer << leftDelimiter;
121 if (integers.empty()) {
122 printer << rightDelimiter;
123 return;
124 }
125
126 unsigned dynamicValIdx = 0;
127 unsigned scalableIndexIdx = 0;
128 llvm::interleaveComma(c: integers, os&: printer, each_fn: [&](int64_t integer) {
129 if (!scalables.empty() && scalables[scalableIndexIdx])
130 printer << "[";
131 if (ShapedType::isDynamic(integer)) {
132 printer << values[dynamicValIdx];
133 if (!valueTypes.empty())
134 printer << " : " << valueTypes[dynamicValIdx];
135 ++dynamicValIdx;
136 } else {
137 printer << integer;
138 }
139 if (!scalables.empty() && scalables[scalableIndexIdx])
140 printer << "]";
141
142 scalableIndexIdx++;
143 });
144
145 printer << rightDelimiter;
146}
147
148ParseResult mlir::parseDynamicIndexList(
149 OpAsmParser &parser,
150 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
151 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables,
152 SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
153
154 SmallVector<int64_t, 4> integerVals;
155 SmallVector<bool, 4> scalableVals;
156 auto parseIntegerOrValue = [&]() {
157 OpAsmParser::UnresolvedOperand operand;
158 auto res = parser.parseOptionalOperand(result&: operand);
159
160 // When encountering `[`, assume that this is a scalable index.
161 scalableVals.push_back(Elt: parser.parseOptionalLSquare().succeeded());
162
163 if (res.has_value() && succeeded(result: res.value())) {
164 values.push_back(Elt: operand);
165 integerVals.push_back(ShapedType::kDynamic);
166 if (valueTypes && parser.parseColonType(result&: valueTypes->emplace_back()))
167 return failure();
168 } else {
169 int64_t integer;
170 if (failed(result: parser.parseInteger(result&: integer)))
171 return failure();
172 integerVals.push_back(Elt: integer);
173 }
174
175 // If this is assumed to be a scalable index, verify that there's a closing
176 // `]`.
177 if (scalableVals.back() && parser.parseOptionalRSquare().failed())
178 return failure();
179 return success();
180 };
181 if (parser.parseCommaSeparatedList(delimiter, parseElementFn: parseIntegerOrValue,
182 contextMessage: " in dynamic index list"))
183 return parser.emitError(loc: parser.getNameLoc())
184 << "expected SSA value or integer";
185 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
186 scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
187 return success();
188}
189
190bool mlir::detail::sameOffsetsSizesAndStrides(
191 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
192 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
193 if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
194 return false;
195 if (a.getStaticSizes().size() != b.getStaticSizes().size())
196 return false;
197 if (a.getStaticStrides().size() != b.getStaticStrides().size())
198 return false;
199 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
200 if (!cmp(std::get<0>(it), std::get<1>(it)))
201 return false;
202 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
203 if (!cmp(std::get<0>(it), std::get<1>(it)))
204 return false;
205 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
206 if (!cmp(std::get<0>(it), std::get<1>(it)))
207 return false;
208 return true;
209}
210
211unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
212 unsigned idx) {
213 return std::count_if(first: staticVals.begin(), last: staticVals.begin() + idx,
214 pred: [&](int64_t val) { return ShapedType::isDynamic(val); });
215}
216

source code of mlir/lib/Interfaces/ViewLikeInterface.cpp