1//===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Interfaces/ViewLikeInterface.h"
10
11using namespace mlir;
12
13//===----------------------------------------------------------------------===//
14// ViewLike Interfaces
15//===----------------------------------------------------------------------===//
16
/// Include the definitions of the view-like interfaces.
18#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"
19
20LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op,
21 StringRef name,
22 unsigned numElements,
23 ArrayRef<int64_t> staticVals,
24 ValueRange values) {
25 // Check static and dynamic offsets/sizes/strides does not overflow type.
26 if (staticVals.size() != numElements)
27 return op->emitError(message: "expected ") << numElements << " " << name
28 << " values, got " << staticVals.size();
29 unsigned expectedNumDynamicEntries =
30 llvm::count_if(staticVals, ShapedType::isDynamic);
31 if (values.size() != expectedNumDynamicEntries)
32 return op->emitError(message: "expected ")
33 << expectedNumDynamicEntries << " dynamic " << name << " values";
34 return success();
35}
36
37SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
38 ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets,
39 ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides,
40 bool generateErrorMessage) {
41 SliceBoundsVerificationResult result;
42 result.isValid = true;
43 for (int64_t i = 0, e = shape.size(); i < e; ++i) {
44 // Nothing to verify for dynamic source dims.
45 if (ShapedType::isDynamic(shape[i]))
46 continue;
47 // Nothing to verify if the offset is dynamic.
48 if (ShapedType::isDynamic(staticOffsets[i]))
49 continue;
50 if (staticOffsets[i] >= shape[i]) {
51 result.errorMessage =
52 std::string("offset ") + std::to_string(val: i) +
53 " is out-of-bounds: " + std::to_string(val: staticOffsets[i]) +
54 " >= " + std::to_string(val: shape[i]);
55 result.isValid = false;
56 return result;
57 }
58 if (ShapedType::isDynamic(staticSizes[i]) ||
59 ShapedType::isDynamic(staticStrides[i]))
60 continue;
61 int64_t lastPos =
62 staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i];
63 if (lastPos >= shape[i]) {
64 result.errorMessage = std::string("slice along dimension ") +
65 std::to_string(val: i) +
66 " runs out-of-bounds: " + std::to_string(val: lastPos) +
67 " >= " + std::to_string(val: shape[i]);
68 result.isValid = false;
69 return result;
70 }
71 }
72 return result;
73}
74
75SliceBoundsVerificationResult mlir::verifyInBoundsSlice(
76 ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets,
77 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides,
78 bool generateErrorMessage) {
79 auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) {
80 SmallVector<int64_t> staticValues;
81 for (OpFoldResult ofr : ofrs) {
82 if (auto attr = dyn_cast<Attribute>(Val&: ofr)) {
83 staticValues.push_back(Elt: cast<IntegerAttr>(attr).getInt());
84 } else {
85 staticValues.push_back(ShapedType::kDynamic);
86 }
87 }
88 return staticValues;
89 };
90 return verifyInBoundsSlice(
91 shape, staticOffsets: getStaticValues(mixedOffsets), staticSizes: getStaticValues(mixedSizes),
92 staticStrides: getStaticValues(mixedStrides), generateErrorMessage);
93}
94
95LogicalResult
96mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
97 std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
98 // Offsets can come in 2 flavors:
99 // 1. Either single entry (when maxRanks == 1).
100 // 2. Or as an array whose rank must match that of the mixed sizes.
101 // So that the result type is well-formed.
102 if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT
103 op.getMixedOffsets().size() != op.getMixedSizes().size())
104 return op->emitError(
105 "expected mixed offsets rank to match mixed sizes rank (")
106 << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
107 << ") so the rank of the result type is well-formed.";
108 // Ranks of mixed sizes and strides must always match so the result type is
109 // well-formed.
110 if (op.getMixedSizes().size() != op.getMixedStrides().size())
111 return op->emitError(
112 "expected mixed sizes rank to match mixed strides rank (")
113 << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
114 << ") so the rank of the result type is well-formed.";
115
116 if (failed(verifyListOfOperandsOrIntegers(
117 op, "offset", maxRanks[0], op.getStaticOffsets(), op.getOffsets())))
118 return failure();
119 if (failed(verifyListOfOperandsOrIntegers(
120 op, "size", maxRanks[1], op.getStaticSizes(), op.getSizes())))
121 return failure();
122 if (failed(verifyListOfOperandsOrIntegers(
123 op, "stride", maxRanks[2], op.getStaticStrides(), op.getStrides())))
124 return failure();
125
126 for (int64_t offset : op.getStaticOffsets()) {
127 if (offset < 0 && !ShapedType::isDynamic(offset))
128 return op->emitError("expected offsets to be non-negative, but got ")
129 << offset;
130 }
131 for (int64_t size : op.getStaticSizes()) {
132 if (size < 0 && !ShapedType::isDynamic(size))
133 return op->emitError("expected sizes to be non-negative, but got ")
134 << size;
135 }
136 return success();
137}
138
139static char getLeftDelimiter(AsmParser::Delimiter delimiter) {
140 switch (delimiter) {
141 case AsmParser::Delimiter::Paren:
142 return '(';
143 case AsmParser::Delimiter::LessGreater:
144 return '<';
145 case AsmParser::Delimiter::Square:
146 return '[';
147 case AsmParser::Delimiter::Braces:
148 return '{';
149 default:
150 llvm_unreachable("unsupported delimiter");
151 }
152}
153
154static char getRightDelimiter(AsmParser::Delimiter delimiter) {
155 switch (delimiter) {
156 case AsmParser::Delimiter::Paren:
157 return ')';
158 case AsmParser::Delimiter::LessGreater:
159 return '>';
160 case AsmParser::Delimiter::Square:
161 return ']';
162 case AsmParser::Delimiter::Braces:
163 return '}';
164 default:
165 llvm_unreachable("unsupported delimiter");
166 }
167}
168
169void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
170 OperandRange values,
171 ArrayRef<int64_t> integers,
172 ArrayRef<bool> scalableFlags,
173 TypeRange valueTypes,
174 AsmParser::Delimiter delimiter) {
175 char leftDelimiter = getLeftDelimiter(delimiter);
176 char rightDelimiter = getRightDelimiter(delimiter);
177 printer << leftDelimiter;
178 if (integers.empty()) {
179 printer << rightDelimiter;
180 return;
181 }
182
183 unsigned dynamicValIdx = 0;
184 unsigned scalableIndexIdx = 0;
185 llvm::interleaveComma(c: integers, os&: printer, each_fn: [&](int64_t integer) {
186 if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
187 printer << "[";
188 if (ShapedType::isDynamic(integer)) {
189 printer << values[dynamicValIdx];
190 if (!valueTypes.empty())
191 printer << " : " << valueTypes[dynamicValIdx];
192 ++dynamicValIdx;
193 } else {
194 printer << integer;
195 }
196 if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx])
197 printer << "]";
198
199 scalableIndexIdx++;
200 });
201
202 printer << rightDelimiter;
203}
204
205ParseResult mlir::parseDynamicIndexList(
206 OpAsmParser &parser,
207 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
208 DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
209 SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) {
210
211 SmallVector<int64_t, 4> integerVals;
212 SmallVector<bool, 4> scalableVals;
213 auto parseIntegerOrValue = [&]() {
214 OpAsmParser::UnresolvedOperand operand;
215 auto res = parser.parseOptionalOperand(result&: operand);
216
217 // When encountering `[`, assume that this is a scalable index.
218 scalableVals.push_back(Elt: parser.parseOptionalLSquare().succeeded());
219
220 if (res.has_value() && succeeded(Result: res.value())) {
221 values.push_back(Elt: operand);
222 integerVals.push_back(ShapedType::kDynamic);
223 if (valueTypes && parser.parseColonType(result&: valueTypes->emplace_back()))
224 return failure();
225 } else {
226 int64_t integer;
227 if (failed(Result: parser.parseInteger(result&: integer)))
228 return failure();
229 integerVals.push_back(Elt: integer);
230 }
231
232 // If this is assumed to be a scalable index, verify that there's a closing
233 // `]`.
234 if (scalableVals.back() && parser.parseOptionalRSquare().failed())
235 return failure();
236 return success();
237 };
238 if (parser.parseCommaSeparatedList(delimiter, parseElementFn: parseIntegerOrValue,
239 contextMessage: " in dynamic index list"))
240 return parser.emitError(loc: parser.getNameLoc())
241 << "expected SSA value or integer";
242 integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
243 scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
244 return success();
245}
246
247bool mlir::detail::sameOffsetsSizesAndStrides(
248 OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
249 llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
250 if (a.getStaticOffsets().size() != b.getStaticOffsets().size())
251 return false;
252 if (a.getStaticSizes().size() != b.getStaticSizes().size())
253 return false;
254 if (a.getStaticStrides().size() != b.getStaticStrides().size())
255 return false;
256 for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets()))
257 if (!cmp(std::get<0>(it), std::get<1>(it)))
258 return false;
259 for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes()))
260 if (!cmp(std::get<0>(it), std::get<1>(it)))
261 return false;
262 for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides()))
263 if (!cmp(std::get<0>(it), std::get<1>(it)))
264 return false;
265 return true;
266}
267
268unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals,
269 unsigned idx) {
270 return std::count_if(staticVals.begin(), staticVals.begin() + idx,
271 ShapedType::isDynamic);
272}
273

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Interfaces/ViewLikeInterface.cpp