1 | //===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestFormatUtils.h" |
10 | #include "mlir/IR/Builders.h" |
11 | |
12 | using namespace mlir; |
13 | using namespace test; |
14 | |
15 | //===----------------------------------------------------------------------===// |
16 | // CustomDirectiveOperands |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | ParseResult test::parseCustomDirectiveOperands( |
20 | OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, |
21 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand, |
22 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { |
23 | if (parser.parseOperand(result&: operand)) |
24 | return failure(); |
25 | if (succeeded(result: parser.parseOptionalComma())) { |
26 | optOperand.emplace(); |
27 | if (parser.parseOperand(result&: *optOperand)) |
28 | return failure(); |
29 | } |
30 | if (parser.parseArrow() || parser.parseLParen() || |
31 | parser.parseOperandList(result&: varOperands) || parser.parseRParen()) |
32 | return failure(); |
33 | return success(); |
34 | } |
35 | |
36 | void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, |
37 | Value operand, Value optOperand, |
38 | OperandRange varOperands) { |
39 | printer << operand; |
40 | if (optOperand) |
41 | printer << ", " << optOperand; |
42 | printer << " -> (" << varOperands << ")" ; |
43 | } |
44 | |
45 | //===----------------------------------------------------------------------===// |
46 | // CustomDirectiveResults |
47 | //===----------------------------------------------------------------------===// |
48 | |
49 | ParseResult |
50 | test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, |
51 | Type &optOperandType, |
52 | SmallVectorImpl<Type> &varOperandTypes) { |
53 | if (parser.parseColon()) |
54 | return failure(); |
55 | |
56 | if (parser.parseType(result&: operandType)) |
57 | return failure(); |
58 | if (succeeded(result: parser.parseOptionalComma())) |
59 | if (parser.parseType(result&: optOperandType)) |
60 | return failure(); |
61 | if (parser.parseArrow() || parser.parseLParen() || |
62 | parser.parseTypeList(result&: varOperandTypes) || parser.parseRParen()) |
63 | return failure(); |
64 | return success(); |
65 | } |
66 | |
67 | void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, |
68 | Type operandType, Type optOperandType, |
69 | TypeRange varOperandTypes) { |
70 | printer << " : " << operandType; |
71 | if (optOperandType) |
72 | printer << ", " << optOperandType; |
73 | printer << " -> (" << varOperandTypes << ")" ; |
74 | } |
75 | |
76 | //===----------------------------------------------------------------------===// |
77 | // CustomDirectiveWithTypeRefs |
78 | //===----------------------------------------------------------------------===// |
79 | |
80 | ParseResult test::parseCustomDirectiveWithTypeRefs( |
81 | OpAsmParser &parser, Type operandType, Type optOperandType, |
82 | const SmallVectorImpl<Type> &varOperandTypes) { |
83 | if (parser.parseKeyword(keyword: "type_refs_capture" )) |
84 | return failure(); |
85 | |
86 | Type operandType2, optOperandType2; |
87 | SmallVector<Type, 1> varOperandTypes2; |
88 | if (parseCustomDirectiveResults(parser, operandType&: operandType2, optOperandType&: optOperandType2, |
89 | varOperandTypes&: varOperandTypes2)) |
90 | return failure(); |
91 | |
92 | if (operandType != operandType2 || optOperandType != optOperandType2 || |
93 | varOperandTypes != varOperandTypes2) |
94 | return failure(); |
95 | |
96 | return success(); |
97 | } |
98 | |
99 | void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, |
100 | Operation *op, Type operandType, |
101 | Type optOperandType, |
102 | TypeRange varOperandTypes) { |
103 | printer << " type_refs_capture " ; |
104 | printCustomDirectiveResults(printer, op, operandType, optOperandType, |
105 | varOperandTypes); |
106 | } |
107 | |
108 | //===----------------------------------------------------------------------===// |
109 | // CustomDirectiveOperandsAndTypes |
110 | //===----------------------------------------------------------------------===// |
111 | |
112 | ParseResult test::parseCustomDirectiveOperandsAndTypes( |
113 | OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, |
114 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand, |
115 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, |
116 | Type &operandType, Type &optOperandType, |
117 | SmallVectorImpl<Type> &varOperandTypes) { |
118 | if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || |
119 | parseCustomDirectiveResults(parser, operandType, optOperandType, |
120 | varOperandTypes)) |
121 | return failure(); |
122 | return success(); |
123 | } |
124 | |
125 | void test::printCustomDirectiveOperandsAndTypes( |
126 | OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, |
127 | OperandRange varOperands, Type operandType, Type optOperandType, |
128 | TypeRange varOperandTypes) { |
129 | printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); |
130 | printCustomDirectiveResults(printer, op, operandType, optOperandType, |
131 | varOperandTypes); |
132 | } |
133 | |
134 | //===----------------------------------------------------------------------===// |
135 | // CustomDirectiveRegions |
136 | //===----------------------------------------------------------------------===// |
137 | |
138 | ParseResult test::parseCustomDirectiveRegions( |
139 | OpAsmParser &parser, Region ®ion, |
140 | SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { |
141 | if (parser.parseRegion(region)) |
142 | return failure(); |
143 | if (failed(result: parser.parseOptionalComma())) |
144 | return success(); |
145 | std::unique_ptr<Region> varRegion = std::make_unique<Region>(); |
146 | if (parser.parseRegion(region&: *varRegion)) |
147 | return failure(); |
148 | varRegions.emplace_back(Args: std::move(varRegion)); |
149 | return success(); |
150 | } |
151 | |
152 | void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, |
153 | Region ®ion, |
154 | MutableArrayRef<Region> varRegions) { |
155 | printer.printRegion(blocks&: region); |
156 | if (!varRegions.empty()) { |
157 | printer << ", " ; |
158 | for (Region ®ion : varRegions) |
159 | printer.printRegion(blocks&: region); |
160 | } |
161 | } |
162 | |
163 | //===----------------------------------------------------------------------===// |
164 | // CustomDirectiveSuccessors |
165 | //===----------------------------------------------------------------------===// |
166 | |
167 | ParseResult |
168 | test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, |
169 | SmallVectorImpl<Block *> &varSuccessors) { |
170 | if (parser.parseSuccessor(dest&: successor)) |
171 | return failure(); |
172 | if (failed(result: parser.parseOptionalComma())) |
173 | return success(); |
174 | Block *varSuccessor; |
175 | if (parser.parseSuccessor(dest&: varSuccessor)) |
176 | return failure(); |
177 | varSuccessors.append(NumInputs: 2, Elt: varSuccessor); |
178 | return success(); |
179 | } |
180 | |
181 | void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, |
182 | Block *successor, |
183 | SuccessorRange varSuccessors) { |
184 | printer << successor; |
185 | if (!varSuccessors.empty()) |
186 | printer << ", " << varSuccessors.front(); |
187 | } |
188 | |
189 | //===----------------------------------------------------------------------===// |
190 | // CustomDirectiveAttributes |
191 | //===----------------------------------------------------------------------===// |
192 | |
193 | ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser, |
194 | IntegerAttr &attr, |
195 | IntegerAttr &optAttr) { |
196 | if (parser.parseAttribute(result&: attr)) |
197 | return failure(); |
198 | if (succeeded(result: parser.parseOptionalComma())) { |
199 | if (parser.parseAttribute(result&: optAttr)) |
200 | return failure(); |
201 | } |
202 | return success(); |
203 | } |
204 | |
205 | void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, |
206 | Attribute attribute, |
207 | Attribute optAttribute) { |
208 | printer << attribute; |
209 | if (optAttribute) |
210 | printer << ", " << optAttribute; |
211 | } |
212 | |
213 | //===----------------------------------------------------------------------===// |
214 | // CustomDirectiveAttrDict |
215 | //===----------------------------------------------------------------------===// |
216 | |
217 | ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser, |
218 | NamedAttrList &attrs) { |
219 | return parser.parseOptionalAttrDict(result&: attrs); |
220 | } |
221 | |
222 | void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, |
223 | DictionaryAttr attrs) { |
224 | printer.printOptionalAttrDict(attrs: attrs.getValue()); |
225 | } |
226 | |
227 | //===----------------------------------------------------------------------===// |
228 | // CustomDirectiveOptionalOperandRef |
229 | //===----------------------------------------------------------------------===// |
230 | |
231 | ParseResult test::parseCustomDirectiveOptionalOperandRef( |
232 | OpAsmParser &parser, |
233 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { |
234 | int64_t operandCount = 0; |
235 | if (parser.parseInteger(result&: operandCount)) |
236 | return failure(); |
237 | bool expectedOptionalOperand = operandCount == 0; |
238 | return success(isSuccess: expectedOptionalOperand != !!optOperand); |
239 | } |
240 | |
241 | void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, |
242 | Operation *op, |
243 | Value optOperand) { |
244 | printer << (optOperand ? "1" : "0" ); |
245 | } |
246 | |
247 | //===----------------------------------------------------------------------===// |
248 | // CustomDirectiveOptionalOperand |
249 | //===----------------------------------------------------------------------===// |
250 | |
251 | ParseResult test::parseCustomOptionalOperand( |
252 | OpAsmParser &parser, |
253 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { |
254 | if (succeeded(result: parser.parseOptionalLParen())) { |
255 | optOperand.emplace(); |
256 | if (parser.parseOperand(result&: *optOperand) || parser.parseRParen()) |
257 | return failure(); |
258 | } |
259 | return success(); |
260 | } |
261 | |
262 | void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, |
263 | Value optOperand) { |
264 | if (optOperand) |
265 | printer << "(" << optOperand << ") " ; |
266 | } |
267 | |
268 | //===----------------------------------------------------------------------===// |
269 | // CustomDirectiveSwitchCases |
270 | //===----------------------------------------------------------------------===// |
271 | |
272 | ParseResult |
273 | test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, |
274 | SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) { |
275 | SmallVector<int64_t> caseValues; |
276 | while (succeeded(result: p.parseOptionalKeyword(keyword: "case" ))) { |
277 | int64_t value; |
278 | Region ®ion = *caseRegions.emplace_back(Args: std::make_unique<Region>()); |
279 | if (p.parseInteger(result&: value) || p.parseRegion(region, /*arguments=*/{})) |
280 | return failure(); |
281 | caseValues.push_back(Elt: value); |
282 | } |
283 | cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); |
284 | return success(); |
285 | } |
286 | |
287 | void test::printSwitchCases(OpAsmPrinter &p, Operation *op, |
288 | DenseI64ArrayAttr cases, RegionRange caseRegions) { |
289 | for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { |
290 | p.printNewline(); |
291 | p << "case " << value << ' '; |
292 | p.printRegion(*region, /*printEntryBlockArgs=*/false); |
293 | } |
294 | } |
295 | |
296 | //===----------------------------------------------------------------------===// |
297 | // CustomUsingPropertyInCustom |
298 | //===----------------------------------------------------------------------===// |
299 | |
300 | bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { |
301 | return parser.parseLSquare() || parser.parseInteger(result&: value[0]) || |
302 | parser.parseComma() || parser.parseInteger(result&: value[1]) || |
303 | parser.parseComma() || parser.parseInteger(result&: value[2]) || |
304 | parser.parseRSquare(); |
305 | } |
306 | |
307 | void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, |
308 | ArrayRef<int64_t> value) { |
309 | printer << '[' << value << ']'; |
310 | } |
311 | |
312 | //===----------------------------------------------------------------------===// |
313 | // CustomDirectiveIntProperty |
314 | //===----------------------------------------------------------------------===// |
315 | |
316 | bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) { |
317 | return failed(result: parser.parseInteger(result&: value)); |
318 | } |
319 | |
320 | void test::printIntProperty(OpAsmPrinter &printer, Operation *op, |
321 | int64_t value) { |
322 | printer << value; |
323 | } |
324 | |
325 | //===----------------------------------------------------------------------===// |
326 | // CustomDirectiveSumProperty |
327 | //===----------------------------------------------------------------------===// |
328 | |
329 | bool test::parseSumProperty(OpAsmParser &parser, int64_t &second, |
330 | int64_t first) { |
331 | int64_t sum; |
332 | auto loc = parser.getCurrentLocation(); |
333 | if (parser.parseInteger(result&: second) || parser.parseEqual() || |
334 | parser.parseInteger(result&: sum)) |
335 | return true; |
336 | if (sum != second + first) { |
337 | parser.emitError(loc, message: "Expected sum to equal first + second" ); |
338 | return true; |
339 | } |
340 | return false; |
341 | } |
342 | |
343 | void test::printSumProperty(OpAsmPrinter &printer, Operation *op, |
344 | int64_t second, int64_t first) { |
345 | printer << second << " = " << (second + first); |
346 | } |
347 | |
348 | //===----------------------------------------------------------------------===// |
349 | // CustomDirectiveOptionalCustomParser |
350 | //===----------------------------------------------------------------------===// |
351 | |
352 | OptionalParseResult test::parseOptionalCustomParser(AsmParser &p, |
353 | IntegerAttr &result) { |
354 | if (succeeded(result: p.parseOptionalKeyword(keyword: "foo" ))) |
355 | return p.parseAttribute(result); |
356 | return {}; |
357 | } |
358 | |
359 | void test::printOptionalCustomParser(AsmPrinter &p, Operation *, |
360 | IntegerAttr result) { |
361 | p << "foo " ; |
362 | p.printAttribute(attr: result); |
363 | } |
364 | |
365 | //===----------------------------------------------------------------------===// |
366 | // CustomDirectiveAttrElideType |
367 | //===----------------------------------------------------------------------===// |
368 | |
369 | ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type, |
370 | Attribute &attr) { |
371 | return parser.parseAttribute(attr, type.getValue()); |
372 | } |
373 | |
374 | void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type, |
375 | Attribute attr) { |
376 | printer.printAttributeWithoutType(attr); |
377 | } |
378 | |