1 | //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Interfaces/ViewLikeInterface.h" |
10 | |
11 | using namespace mlir; |
12 | |
13 | //===----------------------------------------------------------------------===// |
14 | // ViewLike Interfaces |
15 | //===----------------------------------------------------------------------===// |
16 | |
/// Include the definitions of the view-like interfaces.
18 | #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" |
19 | |
20 | LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, |
21 | StringRef name, |
22 | unsigned numElements, |
23 | ArrayRef<int64_t> staticVals, |
24 | ValueRange values) { |
25 | // Check static and dynamic offsets/sizes/strides does not overflow type. |
26 | if (staticVals.size() != numElements) |
27 | return op->emitError(message: "expected " ) << numElements << " " << name |
28 | << " values, got " << staticVals.size(); |
29 | unsigned expectedNumDynamicEntries = |
30 | llvm::count_if(staticVals, ShapedType::isDynamic); |
31 | if (values.size() != expectedNumDynamicEntries) |
32 | return op->emitError(message: "expected " ) |
33 | << expectedNumDynamicEntries << " dynamic " << name << " values" ; |
34 | return success(); |
35 | } |
36 | |
37 | SliceBoundsVerificationResult mlir::verifyInBoundsSlice( |
38 | ArrayRef<int64_t> shape, ArrayRef<int64_t> staticOffsets, |
39 | ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, |
40 | bool generateErrorMessage) { |
41 | SliceBoundsVerificationResult result; |
42 | result.isValid = true; |
43 | for (int64_t i = 0, e = shape.size(); i < e; ++i) { |
44 | // Nothing to verify for dynamic source dims. |
45 | if (ShapedType::isDynamic(shape[i])) |
46 | continue; |
47 | // Nothing to verify if the offset is dynamic. |
48 | if (ShapedType::isDynamic(staticOffsets[i])) |
49 | continue; |
50 | if (staticOffsets[i] >= shape[i]) { |
51 | result.errorMessage = |
52 | std::string("offset " ) + std::to_string(val: i) + |
53 | " is out-of-bounds: " + std::to_string(val: staticOffsets[i]) + |
54 | " >= " + std::to_string(val: shape[i]); |
55 | result.isValid = false; |
56 | return result; |
57 | } |
58 | if (ShapedType::isDynamic(staticSizes[i]) || |
59 | ShapedType::isDynamic(staticStrides[i])) |
60 | continue; |
61 | int64_t lastPos = |
62 | staticOffsets[i] + (staticSizes[i] - 1) * staticStrides[i]; |
63 | if (lastPos >= shape[i]) { |
64 | result.errorMessage = std::string("slice along dimension " ) + |
65 | std::to_string(val: i) + |
66 | " runs out-of-bounds: " + std::to_string(val: lastPos) + |
67 | " >= " + std::to_string(val: shape[i]); |
68 | result.isValid = false; |
69 | return result; |
70 | } |
71 | } |
72 | return result; |
73 | } |
74 | |
75 | SliceBoundsVerificationResult mlir::verifyInBoundsSlice( |
76 | ArrayRef<int64_t> shape, ArrayRef<OpFoldResult> mixedOffsets, |
77 | ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides, |
78 | bool generateErrorMessage) { |
79 | auto getStaticValues = [](ArrayRef<OpFoldResult> ofrs) { |
80 | SmallVector<int64_t> staticValues; |
81 | for (OpFoldResult ofr : ofrs) { |
82 | if (auto attr = dyn_cast<Attribute>(Val&: ofr)) { |
83 | staticValues.push_back(Elt: cast<IntegerAttr>(attr).getInt()); |
84 | } else { |
85 | staticValues.push_back(ShapedType::kDynamic); |
86 | } |
87 | } |
88 | return staticValues; |
89 | }; |
90 | return verifyInBoundsSlice( |
91 | shape, staticOffsets: getStaticValues(mixedOffsets), staticSizes: getStaticValues(mixedSizes), |
92 | staticStrides: getStaticValues(mixedStrides), generateErrorMessage); |
93 | } |
94 | |
95 | LogicalResult |
96 | mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { |
97 | std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); |
98 | // Offsets can come in 2 flavors: |
99 | // 1. Either single entry (when maxRanks == 1). |
100 | // 2. Or as an array whose rank must match that of the mixed sizes. |
101 | // So that the result type is well-formed. |
102 | if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT |
103 | op.getMixedOffsets().size() != op.getMixedSizes().size()) |
104 | return op->emitError( |
105 | "expected mixed offsets rank to match mixed sizes rank (" ) |
106 | << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() |
107 | << ") so the rank of the result type is well-formed." ; |
108 | // Ranks of mixed sizes and strides must always match so the result type is |
109 | // well-formed. |
110 | if (op.getMixedSizes().size() != op.getMixedStrides().size()) |
111 | return op->emitError( |
112 | "expected mixed sizes rank to match mixed strides rank (" ) |
113 | << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() |
114 | << ") so the rank of the result type is well-formed." ; |
115 | |
116 | if (failed(verifyListOfOperandsOrIntegers( |
117 | op, "offset" , maxRanks[0], op.getStaticOffsets(), op.getOffsets()))) |
118 | return failure(); |
119 | if (failed(verifyListOfOperandsOrIntegers( |
120 | op, "size" , maxRanks[1], op.getStaticSizes(), op.getSizes()))) |
121 | return failure(); |
122 | if (failed(verifyListOfOperandsOrIntegers( |
123 | op, "stride" , maxRanks[2], op.getStaticStrides(), op.getStrides()))) |
124 | return failure(); |
125 | |
126 | for (int64_t offset : op.getStaticOffsets()) { |
127 | if (offset < 0 && !ShapedType::isDynamic(offset)) |
128 | return op->emitError("expected offsets to be non-negative, but got " ) |
129 | << offset; |
130 | } |
131 | for (int64_t size : op.getStaticSizes()) { |
132 | if (size < 0 && !ShapedType::isDynamic(size)) |
133 | return op->emitError("expected sizes to be non-negative, but got " ) |
134 | << size; |
135 | } |
136 | return success(); |
137 | } |
138 | |
139 | static char getLeftDelimiter(AsmParser::Delimiter delimiter) { |
140 | switch (delimiter) { |
141 | case AsmParser::Delimiter::Paren: |
142 | return '('; |
143 | case AsmParser::Delimiter::LessGreater: |
144 | return '<'; |
145 | case AsmParser::Delimiter::Square: |
146 | return '['; |
147 | case AsmParser::Delimiter::Braces: |
148 | return '{'; |
149 | default: |
150 | llvm_unreachable("unsupported delimiter" ); |
151 | } |
152 | } |
153 | |
154 | static char getRightDelimiter(AsmParser::Delimiter delimiter) { |
155 | switch (delimiter) { |
156 | case AsmParser::Delimiter::Paren: |
157 | return ')'; |
158 | case AsmParser::Delimiter::LessGreater: |
159 | return '>'; |
160 | case AsmParser::Delimiter::Square: |
161 | return ']'; |
162 | case AsmParser::Delimiter::Braces: |
163 | return '}'; |
164 | default: |
165 | llvm_unreachable("unsupported delimiter" ); |
166 | } |
167 | } |
168 | |
169 | void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, |
170 | OperandRange values, |
171 | ArrayRef<int64_t> integers, |
172 | ArrayRef<bool> scalableFlags, |
173 | TypeRange valueTypes, |
174 | AsmParser::Delimiter delimiter) { |
175 | char leftDelimiter = getLeftDelimiter(delimiter); |
176 | char rightDelimiter = getRightDelimiter(delimiter); |
177 | printer << leftDelimiter; |
178 | if (integers.empty()) { |
179 | printer << rightDelimiter; |
180 | return; |
181 | } |
182 | |
183 | unsigned dynamicValIdx = 0; |
184 | unsigned scalableIndexIdx = 0; |
185 | llvm::interleaveComma(c: integers, os&: printer, each_fn: [&](int64_t integer) { |
186 | if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) |
187 | printer << "[" ; |
188 | if (ShapedType::isDynamic(integer)) { |
189 | printer << values[dynamicValIdx]; |
190 | if (!valueTypes.empty()) |
191 | printer << " : " << valueTypes[dynamicValIdx]; |
192 | ++dynamicValIdx; |
193 | } else { |
194 | printer << integer; |
195 | } |
196 | if (!scalableFlags.empty() && scalableFlags[scalableIndexIdx]) |
197 | printer << "]" ; |
198 | |
199 | scalableIndexIdx++; |
200 | }); |
201 | |
202 | printer << rightDelimiter; |
203 | } |
204 | |
205 | ParseResult mlir::parseDynamicIndexList( |
206 | OpAsmParser &parser, |
207 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
208 | DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, |
209 | SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { |
210 | |
211 | SmallVector<int64_t, 4> integerVals; |
212 | SmallVector<bool, 4> scalableVals; |
213 | auto parseIntegerOrValue = [&]() { |
214 | OpAsmParser::UnresolvedOperand operand; |
215 | auto res = parser.parseOptionalOperand(result&: operand); |
216 | |
217 | // When encountering `[`, assume that this is a scalable index. |
218 | scalableVals.push_back(Elt: parser.parseOptionalLSquare().succeeded()); |
219 | |
220 | if (res.has_value() && succeeded(Result: res.value())) { |
221 | values.push_back(Elt: operand); |
222 | integerVals.push_back(ShapedType::kDynamic); |
223 | if (valueTypes && parser.parseColonType(result&: valueTypes->emplace_back())) |
224 | return failure(); |
225 | } else { |
226 | int64_t integer; |
227 | if (failed(Result: parser.parseInteger(result&: integer))) |
228 | return failure(); |
229 | integerVals.push_back(Elt: integer); |
230 | } |
231 | |
232 | // If this is assumed to be a scalable index, verify that there's a closing |
233 | // `]`. |
234 | if (scalableVals.back() && parser.parseOptionalRSquare().failed()) |
235 | return failure(); |
236 | return success(); |
237 | }; |
238 | if (parser.parseCommaSeparatedList(delimiter, parseElementFn: parseIntegerOrValue, |
239 | contextMessage: " in dynamic index list" )) |
240 | return parser.emitError(loc: parser.getNameLoc()) |
241 | << "expected SSA value or integer" ; |
242 | integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); |
243 | scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); |
244 | return success(); |
245 | } |
246 | |
247 | bool mlir::detail::sameOffsetsSizesAndStrides( |
248 | OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, |
249 | llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { |
250 | if (a.getStaticOffsets().size() != b.getStaticOffsets().size()) |
251 | return false; |
252 | if (a.getStaticSizes().size() != b.getStaticSizes().size()) |
253 | return false; |
254 | if (a.getStaticStrides().size() != b.getStaticStrides().size()) |
255 | return false; |
256 | for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) |
257 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
258 | return false; |
259 | for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) |
260 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
261 | return false; |
262 | for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) |
263 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
264 | return false; |
265 | return true; |
266 | } |
267 | |
268 | unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, |
269 | unsigned idx) { |
270 | return std::count_if(staticVals.begin(), staticVals.begin() + idx, |
271 | ShapedType::isDynamic); |
272 | } |
273 | |