1 | //===- ViewLikeInterface.cpp - View-like operations in MLIR ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Interfaces/ViewLikeInterface.h" |
10 | |
11 | using namespace mlir; |
12 | |
13 | //===----------------------------------------------------------------------===// |
14 | // ViewLike Interfaces |
15 | //===----------------------------------------------------------------------===// |
16 | |
/// Include the definitions of the view-like interfaces.
18 | #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" |
19 | |
20 | LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, |
21 | StringRef name, |
22 | unsigned numElements, |
23 | ArrayRef<int64_t> staticVals, |
24 | ValueRange values) { |
25 | // Check static and dynamic offsets/sizes/strides does not overflow type. |
26 | if (staticVals.size() != numElements) |
27 | return op->emitError(message: "expected " ) << numElements << " " << name |
28 | << " values, got " << staticVals.size(); |
29 | unsigned expectedNumDynamicEntries = |
30 | llvm::count_if(staticVals, [](int64_t staticVal) { |
31 | return ShapedType::isDynamic(staticVal); |
32 | }); |
33 | if (values.size() != expectedNumDynamicEntries) |
34 | return op->emitError(message: "expected " ) |
35 | << expectedNumDynamicEntries << " dynamic " << name << " values" ; |
36 | return success(); |
37 | } |
38 | |
39 | LogicalResult |
40 | mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) { |
41 | std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks(); |
42 | // Offsets can come in 2 flavors: |
43 | // 1. Either single entry (when maxRanks == 1). |
44 | // 2. Or as an array whose rank must match that of the mixed sizes. |
45 | // So that the result type is well-formed. |
46 | if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) && // NOLINT |
47 | op.getMixedOffsets().size() != op.getMixedSizes().size()) |
48 | return op->emitError( |
49 | "expected mixed offsets rank to match mixed sizes rank (" ) |
50 | << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size() |
51 | << ") so the rank of the result type is well-formed." ; |
52 | // Ranks of mixed sizes and strides must always match so the result type is |
53 | // well-formed. |
54 | if (op.getMixedSizes().size() != op.getMixedStrides().size()) |
55 | return op->emitError( |
56 | "expected mixed sizes rank to match mixed strides rank (" ) |
57 | << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size() |
58 | << ") so the rank of the result type is well-formed." ; |
59 | |
60 | if (failed(verifyListOfOperandsOrIntegers( |
61 | op, "offset" , maxRanks[0], op.getStaticOffsets(), op.getOffsets()))) |
62 | return failure(); |
63 | if (failed(verifyListOfOperandsOrIntegers( |
64 | op, "size" , maxRanks[1], op.getStaticSizes(), op.getSizes()))) |
65 | return failure(); |
66 | if (failed(verifyListOfOperandsOrIntegers( |
67 | op, "stride" , maxRanks[2], op.getStaticStrides(), op.getStrides()))) |
68 | return failure(); |
69 | |
70 | for (int64_t offset : op.getStaticOffsets()) { |
71 | if (offset < 0 && !ShapedType::isDynamic(offset)) |
72 | return op->emitError("expected offsets to be non-negative, but got " ) |
73 | << offset; |
74 | } |
75 | for (int64_t size : op.getStaticSizes()) { |
76 | if (size < 0 && !ShapedType::isDynamic(size)) |
77 | return op->emitError("expected sizes to be non-negative, but got " ) |
78 | << size; |
79 | } |
80 | return success(); |
81 | } |
82 | |
83 | static char getLeftDelimiter(AsmParser::Delimiter delimiter) { |
84 | switch (delimiter) { |
85 | case AsmParser::Delimiter::Paren: |
86 | return '('; |
87 | case AsmParser::Delimiter::LessGreater: |
88 | return '<'; |
89 | case AsmParser::Delimiter::Square: |
90 | return '['; |
91 | case AsmParser::Delimiter::Braces: |
92 | return '{'; |
93 | default: |
94 | llvm_unreachable("unsupported delimiter" ); |
95 | } |
96 | } |
97 | |
98 | static char getRightDelimiter(AsmParser::Delimiter delimiter) { |
99 | switch (delimiter) { |
100 | case AsmParser::Delimiter::Paren: |
101 | return ')'; |
102 | case AsmParser::Delimiter::LessGreater: |
103 | return '>'; |
104 | case AsmParser::Delimiter::Square: |
105 | return ']'; |
106 | case AsmParser::Delimiter::Braces: |
107 | return '}'; |
108 | default: |
109 | llvm_unreachable("unsupported delimiter" ); |
110 | } |
111 | } |
112 | |
113 | void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, |
114 | OperandRange values, |
115 | ArrayRef<int64_t> integers, |
116 | TypeRange valueTypes, ArrayRef<bool> scalables, |
117 | AsmParser::Delimiter delimiter) { |
118 | char leftDelimiter = getLeftDelimiter(delimiter); |
119 | char rightDelimiter = getRightDelimiter(delimiter); |
120 | printer << leftDelimiter; |
121 | if (integers.empty()) { |
122 | printer << rightDelimiter; |
123 | return; |
124 | } |
125 | |
126 | unsigned dynamicValIdx = 0; |
127 | unsigned scalableIndexIdx = 0; |
128 | llvm::interleaveComma(c: integers, os&: printer, each_fn: [&](int64_t integer) { |
129 | if (!scalables.empty() && scalables[scalableIndexIdx]) |
130 | printer << "[" ; |
131 | if (ShapedType::isDynamic(integer)) { |
132 | printer << values[dynamicValIdx]; |
133 | if (!valueTypes.empty()) |
134 | printer << " : " << valueTypes[dynamicValIdx]; |
135 | ++dynamicValIdx; |
136 | } else { |
137 | printer << integer; |
138 | } |
139 | if (!scalables.empty() && scalables[scalableIndexIdx]) |
140 | printer << "]" ; |
141 | |
142 | scalableIndexIdx++; |
143 | }); |
144 | |
145 | printer << rightDelimiter; |
146 | } |
147 | |
148 | ParseResult mlir::parseDynamicIndexList( |
149 | OpAsmParser &parser, |
150 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values, |
151 | DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, |
152 | SmallVectorImpl<Type> *valueTypes, AsmParser::Delimiter delimiter) { |
153 | |
154 | SmallVector<int64_t, 4> integerVals; |
155 | SmallVector<bool, 4> scalableVals; |
156 | auto parseIntegerOrValue = [&]() { |
157 | OpAsmParser::UnresolvedOperand operand; |
158 | auto res = parser.parseOptionalOperand(result&: operand); |
159 | |
160 | // When encountering `[`, assume that this is a scalable index. |
161 | scalableVals.push_back(Elt: parser.parseOptionalLSquare().succeeded()); |
162 | |
163 | if (res.has_value() && succeeded(result: res.value())) { |
164 | values.push_back(Elt: operand); |
165 | integerVals.push_back(ShapedType::kDynamic); |
166 | if (valueTypes && parser.parseColonType(result&: valueTypes->emplace_back())) |
167 | return failure(); |
168 | } else { |
169 | int64_t integer; |
170 | if (failed(result: parser.parseInteger(result&: integer))) |
171 | return failure(); |
172 | integerVals.push_back(Elt: integer); |
173 | } |
174 | |
175 | // If this is assumed to be a scalable index, verify that there's a closing |
176 | // `]`. |
177 | if (scalableVals.back() && parser.parseOptionalRSquare().failed()) |
178 | return failure(); |
179 | return success(); |
180 | }; |
181 | if (parser.parseCommaSeparatedList(delimiter, parseElementFn: parseIntegerOrValue, |
182 | contextMessage: " in dynamic index list" )) |
183 | return parser.emitError(loc: parser.getNameLoc()) |
184 | << "expected SSA value or integer" ; |
185 | integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); |
186 | scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); |
187 | return success(); |
188 | } |
189 | |
190 | bool mlir::detail::sameOffsetsSizesAndStrides( |
191 | OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, |
192 | llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) { |
193 | if (a.getStaticOffsets().size() != b.getStaticOffsets().size()) |
194 | return false; |
195 | if (a.getStaticSizes().size() != b.getStaticSizes().size()) |
196 | return false; |
197 | if (a.getStaticStrides().size() != b.getStaticStrides().size()) |
198 | return false; |
199 | for (auto it : llvm::zip(a.getMixedOffsets(), b.getMixedOffsets())) |
200 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
201 | return false; |
202 | for (auto it : llvm::zip(a.getMixedSizes(), b.getMixedSizes())) |
203 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
204 | return false; |
205 | for (auto it : llvm::zip(a.getMixedStrides(), b.getMixedStrides())) |
206 | if (!cmp(std::get<0>(it), std::get<1>(it))) |
207 | return false; |
208 | return true; |
209 | } |
210 | |
211 | unsigned mlir::detail::getNumDynamicEntriesUpToIdx(ArrayRef<int64_t> staticVals, |
212 | unsigned idx) { |
213 | return std::count_if(first: staticVals.begin(), last: staticVals.begin() + idx, |
214 | pred: [&](int64_t val) { return ShapedType::isDynamic(val); }); |
215 | } |
216 | |