1 | //===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestOpsSyntax.h" |
10 | #include "TestDialect.h" |
11 | #include "TestOps.h" |
12 | #include "mlir/IR/OpImplementation.h" |
13 | #include "llvm/Support/Base64.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace test; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // Test Format* operations |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | //===----------------------------------------------------------------------===// |
23 | // Parsing |
24 | //===----------------------------------------------------------------------===// |
25 | |
26 | static ParseResult parseCustomOptionalOperand( |
27 | OpAsmParser &parser, |
28 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { |
29 | if (succeeded(Result: parser.parseOptionalLParen())) { |
30 | optOperand.emplace(); |
31 | if (parser.parseOperand(result&: *optOperand) || parser.parseRParen()) |
32 | return failure(); |
33 | } |
34 | return success(); |
35 | } |
36 | |
37 | static ParseResult parseCustomDirectiveOperands( |
38 | OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, |
39 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand, |
40 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) { |
41 | if (parser.parseOperand(result&: operand)) |
42 | return failure(); |
43 | if (succeeded(Result: parser.parseOptionalComma())) { |
44 | optOperand.emplace(); |
45 | if (parser.parseOperand(result&: *optOperand)) |
46 | return failure(); |
47 | } |
48 | if (parser.parseArrow() || parser.parseLParen() || |
49 | parser.parseOperandList(result&: varOperands) || parser.parseRParen()) |
50 | return failure(); |
51 | return success(); |
52 | } |
53 | static ParseResult |
54 | parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, |
55 | Type &optOperandType, |
56 | SmallVectorImpl<Type> &varOperandTypes) { |
57 | if (parser.parseColon()) |
58 | return failure(); |
59 | |
60 | if (parser.parseType(result&: operandType)) |
61 | return failure(); |
62 | if (succeeded(Result: parser.parseOptionalComma())) { |
63 | if (parser.parseType(result&: optOperandType)) |
64 | return failure(); |
65 | } |
66 | if (parser.parseArrow() || parser.parseLParen() || |
67 | parser.parseTypeList(result&: varOperandTypes) || parser.parseRParen()) |
68 | return failure(); |
69 | return success(); |
70 | } |
71 | static ParseResult |
72 | parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, |
73 | Type optOperandType, |
74 | const SmallVectorImpl<Type> &varOperandTypes) { |
75 | if (parser.parseKeyword(keyword: "type_refs_capture")) |
76 | return failure(); |
77 | |
78 | Type operandType2, optOperandType2; |
79 | SmallVector<Type, 1> varOperandTypes2; |
80 | if (parseCustomDirectiveResults(parser, operandType&: operandType2, optOperandType&: optOperandType2, |
81 | varOperandTypes&: varOperandTypes2)) |
82 | return failure(); |
83 | |
84 | if (operandType != operandType2 || optOperandType != optOperandType2 || |
85 | varOperandTypes != varOperandTypes2) |
86 | return failure(); |
87 | |
88 | return success(); |
89 | } |
90 | static ParseResult parseCustomDirectiveOperandsAndTypes( |
91 | OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, |
92 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand, |
93 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands, |
94 | Type &operandType, Type &optOperandType, |
95 | SmallVectorImpl<Type> &varOperandTypes) { |
96 | if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || |
97 | parseCustomDirectiveResults(parser, operandType, optOperandType, |
98 | varOperandTypes)) |
99 | return failure(); |
100 | return success(); |
101 | } |
102 | static ParseResult parseCustomDirectiveRegions( |
103 | OpAsmParser &parser, Region ®ion, |
104 | SmallVectorImpl<std::unique_ptr<Region>> &varRegions) { |
105 | if (parser.parseRegion(region)) |
106 | return failure(); |
107 | if (failed(Result: parser.parseOptionalComma())) |
108 | return success(); |
109 | std::unique_ptr<Region> varRegion = std::make_unique<Region>(); |
110 | if (parser.parseRegion(region&: *varRegion)) |
111 | return failure(); |
112 | varRegions.emplace_back(Args: std::move(varRegion)); |
113 | return success(); |
114 | } |
115 | static ParseResult |
116 | parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, |
117 | SmallVectorImpl<Block *> &varSuccessors) { |
118 | if (parser.parseSuccessor(dest&: successor)) |
119 | return failure(); |
120 | if (failed(Result: parser.parseOptionalComma())) |
121 | return success(); |
122 | Block *varSuccessor; |
123 | if (parser.parseSuccessor(dest&: varSuccessor)) |
124 | return failure(); |
125 | varSuccessors.append(NumInputs: 2, Elt: varSuccessor); |
126 | return success(); |
127 | } |
128 | static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, |
129 | IntegerAttr &attr, |
130 | IntegerAttr &optAttr) { |
131 | if (parser.parseAttribute(result&: attr)) |
132 | return failure(); |
133 | if (succeeded(Result: parser.parseOptionalComma())) { |
134 | if (parser.parseAttribute(result&: optAttr)) |
135 | return failure(); |
136 | } |
137 | return success(); |
138 | } |
139 | static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser, |
140 | mlir::StringAttr &attr) { |
141 | return parser.parseAttribute(result&: attr); |
142 | } |
143 | static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, |
144 | NamedAttrList &attrs) { |
145 | return parser.parseOptionalAttrDict(result&: attrs); |
146 | } |
147 | static ParseResult parseCustomDirectiveOptionalOperandRef( |
148 | OpAsmParser &parser, |
149 | std::optional<OpAsmParser::UnresolvedOperand> &optOperand) { |
150 | int64_t operandCount = 0; |
151 | if (parser.parseInteger(result&: operandCount)) |
152 | return failure(); |
153 | bool expectedOptionalOperand = operandCount == 0; |
154 | return success(IsSuccess: expectedOptionalOperand != optOperand.has_value()); |
155 | } |
156 | |
157 | //===----------------------------------------------------------------------===// |
158 | // Printing |
159 | //===----------------------------------------------------------------------===// |
160 | |
161 | static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, |
162 | Value optOperand) { |
163 | if (optOperand) |
164 | printer << "("<< optOperand << ") "; |
165 | } |
166 | |
167 | static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, |
168 | Value operand, Value optOperand, |
169 | OperandRange varOperands) { |
170 | printer << operand; |
171 | if (optOperand) |
172 | printer << ", "<< optOperand; |
173 | printer << " -> ("<< varOperands << ")"; |
174 | } |
175 | static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, |
176 | Type operandType, Type optOperandType, |
177 | TypeRange varOperandTypes) { |
178 | printer << " : "<< operandType; |
179 | if (optOperandType) |
180 | printer << ", "<< optOperandType; |
181 | printer << " -> ("<< varOperandTypes << ")"; |
182 | } |
183 | static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, |
184 | Operation *op, Type operandType, |
185 | Type optOperandType, |
186 | TypeRange varOperandTypes) { |
187 | printer << " type_refs_capture "; |
188 | printCustomDirectiveResults(printer, op, operandType, optOperandType, |
189 | varOperandTypes); |
190 | } |
191 | static void printCustomDirectiveOperandsAndTypes( |
192 | OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, |
193 | OperandRange varOperands, Type operandType, Type optOperandType, |
194 | TypeRange varOperandTypes) { |
195 | printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); |
196 | printCustomDirectiveResults(printer, op, operandType, optOperandType, |
197 | varOperandTypes); |
198 | } |
199 | static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, |
200 | Region ®ion, |
201 | MutableArrayRef<Region> varRegions) { |
202 | printer.printRegion(blocks&: region); |
203 | if (!varRegions.empty()) { |
204 | printer << ", "; |
205 | for (Region ®ion : varRegions) |
206 | printer.printRegion(blocks&: region); |
207 | } |
208 | } |
209 | static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, |
210 | Block *successor, |
211 | SuccessorRange varSuccessors) { |
212 | printer << successor; |
213 | if (!varSuccessors.empty()) |
214 | printer << ", "<< varSuccessors.front(); |
215 | } |
216 | static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, |
217 | Attribute attribute, |
218 | Attribute optAttribute) { |
219 | printer << attribute; |
220 | if (optAttribute) |
221 | printer << ", "<< optAttribute; |
222 | } |
223 | static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op, |
224 | Attribute attribute) { |
225 | printer << attribute; |
226 | } |
227 | static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, |
228 | DictionaryAttr attrs) { |
229 | printer.printOptionalAttrDict(attrs: attrs.getValue()); |
230 | } |
231 | |
232 | static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, |
233 | Operation *op, |
234 | Value optOperand) { |
235 | printer << (optOperand ? "1": "0"); |
236 | } |
237 | //===----------------------------------------------------------------------===// |
238 | // Test parser. |
239 | //===----------------------------------------------------------------------===// |
240 | |
241 | ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, |
242 | OperationState &result) { |
243 | if (parser.parseOptionalColon()) |
244 | return success(); |
245 | uint64_t numResults; |
246 | if (parser.parseInteger(numResults)) |
247 | return failure(); |
248 | |
249 | IndexType type = parser.getBuilder().getIndexType(); |
250 | for (unsigned i = 0; i < numResults; ++i) |
251 | result.addTypes(type); |
252 | return success(); |
253 | } |
254 | |
255 | void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { |
256 | if (unsigned numResults = getNumResults()) |
257 | p << " : "<< numResults; |
258 | } |
259 | |
260 | ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, |
261 | OperationState &result) { |
262 | StringRef keyword; |
263 | if (parser.parseKeyword(&keyword)) |
264 | return failure(); |
265 | result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); |
266 | return success(); |
267 | } |
268 | |
269 | void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " "<< getKeyword(); } |
270 | |
271 | ParseResult ParseB64BytesOp::parse(OpAsmParser &parser, |
272 | OperationState &result) { |
273 | std::vector<char> bytes; |
274 | if (parser.parseBase64Bytes(&bytes)) |
275 | return failure(); |
276 | result.addAttribute("b64", parser.getBuilder().getStringAttr( |
277 | StringRef(&bytes.front(), bytes.size()))); |
278 | return success(); |
279 | } |
280 | |
281 | void ParseB64BytesOp::print(OpAsmPrinter &p) { |
282 | p << " \""<< llvm::encodeBase64(getB64()) << "\""; |
283 | } |
284 | |
285 | ::llvm::LogicalResult FormatInferType2Op::inferReturnTypes( |
286 | ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, |
287 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, |
288 | OpaqueProperties properties, ::mlir::RegionRange regions, |
289 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { |
290 | inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); |
291 | return ::mlir::success(); |
292 | } |
293 | |
294 | //===----------------------------------------------------------------------===// |
295 | // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. |
296 | //===----------------------------------------------------------------------===// |
297 | |
298 | ParseResult WrappingRegionOp::parse(OpAsmParser &parser, |
299 | OperationState &result) { |
300 | if (parser.parseKeyword("wraps")) |
301 | return failure(); |
302 | |
303 | // Parse the wrapped op in a region |
304 | Region &body = *result.addRegion(); |
305 | body.push_back(new Block); |
306 | Block &block = body.back(); |
307 | Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); |
308 | if (!wrappedOp) |
309 | return failure(); |
310 | |
311 | // Create a return terminator in the inner region, pass as operand to the |
312 | // terminator the returned values from the wrapped operation. |
313 | SmallVector<Value, 8> returnOperands(wrappedOp->getResults()); |
314 | OpBuilder builder(parser.getContext()); |
315 | builder.setInsertionPointToEnd(&block); |
316 | builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands); |
317 | |
318 | // Get the results type for the wrapping op from the terminator operands. |
319 | Operation &returnOp = body.back().back(); |
320 | result.types.append(returnOp.operand_type_begin(), |
321 | returnOp.operand_type_end()); |
322 | |
323 | // Use the location of the wrapped op for the "test.wrapping_region" op. |
324 | result.location = wrappedOp->getLoc(); |
325 | |
326 | return success(); |
327 | } |
328 | |
329 | void WrappingRegionOp::print(OpAsmPrinter &p) { |
330 | p << " wraps "; |
331 | p.printGenericOp(&getRegion().front().front()); |
332 | } |
333 | |
334 | //===----------------------------------------------------------------------===// |
335 | // Test PrettyPrintedRegionOp - exercising the following parser APIs |
336 | // parseGenericOperationAfterOpName |
337 | // parseCustomOperationName |
338 | //===----------------------------------------------------------------------===// |
339 | |
340 | ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, |
341 | OperationState &result) { |
342 | |
343 | SMLoc loc = parser.getCurrentLocation(); |
344 | Location currLocation = parser.getEncodedSourceLoc(loc); |
345 | |
346 | // Parse the operands. |
347 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; |
348 | if (parser.parseOperandList(operands)) |
349 | return failure(); |
350 | |
351 | // Check if we are parsing the pretty-printed version |
352 | // test.pretty_printed_region start <inner-op> end : <functional-type> |
353 | // Else fallback to parsing the "non pretty-printed" version. |
354 | if (!succeeded(parser.parseOptionalKeyword("start"))) |
355 | return parser.parseGenericOperationAfterOpName(result, |
356 | llvm::ArrayRef(operands)); |
357 | |
358 | FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName(); |
359 | if (failed(parseOpNameInfo)) |
360 | return failure(); |
361 | |
362 | StringAttr innerOpName = parseOpNameInfo->getIdentifier(); |
363 | |
364 | FunctionType opFntype; |
365 | std::optional<Location> explicitLoc; |
366 | if (parser.parseKeyword("end") || parser.parseColon() || |
367 | parser.parseType(opFntype) || |
368 | parser.parseOptionalLocationSpecifier(explicitLoc)) |
369 | return failure(); |
370 | |
371 | // If location of the op is explicitly provided, then use it; Else use |
372 | // the parser's current location. |
373 | Location opLoc = explicitLoc.value_or(currLocation); |
374 | |
375 | // Derive the SSA-values for op's operands. |
376 | if (parser.resolveOperands(operands, opFntype.getInputs(), loc, |
377 | result.operands)) |
378 | return failure(); |
379 | |
380 | // Add a region for op. |
381 | Region ®ion = *result.addRegion(); |
382 | |
383 | // Create a basic-block inside op's region. |
384 | Block &block = region.emplaceBlock(); |
385 | |
386 | // Create and insert an "inner-op" operation in the block. |
387 | // Just for testing purposes, we can assume that inner op is a binary op with |
388 | // result and operand types all same as the test-op's first operand. |
389 | Type innerOpType = opFntype.getInput(0); |
390 | Value lhs = block.addArgument(innerOpType, opLoc); |
391 | Value rhs = block.addArgument(innerOpType, opLoc); |
392 | |
393 | OpBuilder builder(parser.getBuilder().getContext()); |
394 | builder.setInsertionPointToStart(&block); |
395 | |
396 | Operation *innerOp = |
397 | builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); |
398 | |
399 | // Insert a return statement in the block returning the inner-op's result. |
400 | builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults()); |
401 | |
402 | // Populate the op operation-state with result-type and location. |
403 | result.addTypes(opFntype.getResults()); |
404 | result.location = innerOp->getLoc(); |
405 | |
406 | return success(); |
407 | } |
408 | |
409 | void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { |
410 | p << ' '; |
411 | p.printOperands(getOperands()); |
412 | |
413 | Operation &innerOp = getRegion().front().front(); |
414 | // Assuming that region has a single non-terminator inner-op, if the inner-op |
415 | // meets some criteria (which in this case is a simple one based on the name |
416 | // of inner-op), then we can print the entire region in a succinct way. |
417 | // Here we assume that the prototype of "test.special.op" can be trivially |
418 | // derived while parsing it back. |
419 | if (innerOp.getName().getStringRef() == "test.special.op") { |
420 | p << " start test.special.op end"; |
421 | } else { |
422 | p << " ("; |
423 | p.printRegion(getRegion()); |
424 | p << ")"; |
425 | } |
426 | |
427 | p << " : "; |
428 | p.printFunctionalType(*this); |
429 | } |
430 | |
431 | //===----------------------------------------------------------------------===// |
432 | // Test PolyForOp - parse list of region arguments. |
433 | //===----------------------------------------------------------------------===// |
434 | |
435 | ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { |
436 | SmallVector<OpAsmParser::Argument, 4> ivsInfo; |
437 | // Parse list of region arguments without a delimiter. |
438 | if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) |
439 | return failure(); |
440 | |
441 | // Parse the body region. |
442 | Region *body = result.addRegion(); |
443 | for (auto &iv : ivsInfo) |
444 | iv.type = parser.getBuilder().getIndexType(); |
445 | return parser.parseRegion(*body, ivsInfo); |
446 | } |
447 | |
448 | void PolyForOp::print(OpAsmPrinter &p) { |
449 | p << " "; |
450 | llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) { |
451 | p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true); |
452 | }); |
453 | p << " "; |
454 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
455 | } |
456 | |
457 | void PolyForOp::getAsmBlockArgumentNames(Region ®ion, |
458 | OpAsmSetValueNameFn setNameFn) { |
459 | auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names"); |
460 | if (!arrayAttr) |
461 | return; |
462 | auto args = getRegion().front().getArguments(); |
463 | auto e = std::min(arrayAttr.size(), args.size()); |
464 | for (unsigned i = 0; i < e; ++i) { |
465 | if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i])) |
466 | setNameFn(args[i], strAttr.getValue()); |
467 | } |
468 | } |
469 | |
470 | //===----------------------------------------------------------------------===// |
471 | // TestAttrWithLoc - parse/printOptionalLocationSpecifier |
472 | //===----------------------------------------------------------------------===// |
473 | |
474 | static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) { |
475 | std::optional<Location> result; |
476 | SMLoc sourceLoc = p.getCurrentLocation(); |
477 | if (p.parseOptionalLocationSpecifier(result)) |
478 | return failure(); |
479 | if (result) |
480 | loc = *result; |
481 | else |
482 | loc = p.getEncodedSourceLoc(loc: sourceLoc); |
483 | return success(); |
484 | } |
485 | |
486 | static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) { |
487 | p.printOptionalLocationSpecifier(loc: cast<LocationAttr>(Val&: loc)); |
488 | } |
489 | |
490 | //===----------------------------------------------------------------------===// |
491 | // ParseCustomOperationNameAPI |
492 | //===----------------------------------------------------------------------===// |
493 | |
494 | static ParseResult parseCustomOperationNameEntry(OpAsmParser &p, |
495 | Attribute &name) { |
496 | FailureOr<OperationName> opName = p.parseCustomOperationName(); |
497 | if (failed(Result: opName)) |
498 | return ParseResult::failure(); |
499 | |
500 | name = p.getBuilder().getStringAttr(opName->getStringRef()); |
501 | return ParseResult::success(); |
502 | } |
503 | |
504 | static void printCustomOperationNameEntry(OpAsmPrinter &p, Operation *op, |
505 | Attribute name) { |
506 | p << cast<StringAttr>(name).getValue(); |
507 | } |
508 | |
509 | #define GET_OP_CLASSES |
510 | #include "TestOpsSyntax.cpp.inc" |
511 | |
512 | void TestDialect::registerOpsSyntax() { |
513 | addOperations< |
514 | #define GET_OP_LIST |
515 | #include "TestOpsSyntax.cpp.inc" |
516 | >(); |
517 | } |
518 |
Definitions
- parseCustomOptionalOperand
- parseCustomDirectiveOperands
- parseCustomDirectiveResults
- parseCustomDirectiveWithTypeRefs
- parseCustomDirectiveOperandsAndTypes
- parseCustomDirectiveRegions
- parseCustomDirectiveSuccessors
- parseCustomDirectiveAttributes
- parseCustomDirectiveSpacing
- parseCustomDirectiveAttrDict
- parseCustomDirectiveOptionalOperandRef
- printCustomOptionalOperand
- printCustomDirectiveOperands
- printCustomDirectiveResults
- printCustomDirectiveWithTypeRefs
- printCustomDirectiveOperandsAndTypes
- printCustomDirectiveRegions
- printCustomDirectiveSuccessors
- printCustomDirectiveAttributes
- printCustomDirectiveSpacing
- printCustomDirectiveAttrDict
- printCustomDirectiveOptionalOperandRef
- parseOptionalLoc
- printOptionalLoc
- parseCustomOperationNameEntry
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more