1//===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestOpsSyntax.h"
10#include "TestDialect.h"
11#include "TestOps.h"
12#include "mlir/IR/OpImplementation.h"
13#include "llvm/Support/Base64.h"
14
15using namespace mlir;
16using namespace test;
17
18//===----------------------------------------------------------------------===//
19// Test Format* operations
20//===----------------------------------------------------------------------===//
21
22//===----------------------------------------------------------------------===//
23// Parsing
24//===----------------------------------------------------------------------===//
25
26static ParseResult parseCustomOptionalOperand(
27 OpAsmParser &parser,
28 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
29 if (succeeded(Result: parser.parseOptionalLParen())) {
30 optOperand.emplace();
31 if (parser.parseOperand(result&: *optOperand) || parser.parseRParen())
32 return failure();
33 }
34 return success();
35}
36
37static ParseResult parseCustomDirectiveOperands(
38 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
39 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
40 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
41 if (parser.parseOperand(result&: operand))
42 return failure();
43 if (succeeded(Result: parser.parseOptionalComma())) {
44 optOperand.emplace();
45 if (parser.parseOperand(result&: *optOperand))
46 return failure();
47 }
48 if (parser.parseArrow() || parser.parseLParen() ||
49 parser.parseOperandList(result&: varOperands) || parser.parseRParen())
50 return failure();
51 return success();
52}
53static ParseResult
54parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
55 Type &optOperandType,
56 SmallVectorImpl<Type> &varOperandTypes) {
57 if (parser.parseColon())
58 return failure();
59
60 if (parser.parseType(result&: operandType))
61 return failure();
62 if (succeeded(Result: parser.parseOptionalComma())) {
63 if (parser.parseType(result&: optOperandType))
64 return failure();
65 }
66 if (parser.parseArrow() || parser.parseLParen() ||
67 parser.parseTypeList(result&: varOperandTypes) || parser.parseRParen())
68 return failure();
69 return success();
70}
71static ParseResult
72parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
73 Type optOperandType,
74 const SmallVectorImpl<Type> &varOperandTypes) {
75 if (parser.parseKeyword(keyword: "type_refs_capture"))
76 return failure();
77
78 Type operandType2, optOperandType2;
79 SmallVector<Type, 1> varOperandTypes2;
80 if (parseCustomDirectiveResults(parser, operandType&: operandType2, optOperandType&: optOperandType2,
81 varOperandTypes&: varOperandTypes2))
82 return failure();
83
84 if (operandType != operandType2 || optOperandType != optOperandType2 ||
85 varOperandTypes != varOperandTypes2)
86 return failure();
87
88 return success();
89}
90static ParseResult parseCustomDirectiveOperandsAndTypes(
91 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
92 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
93 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
94 Type &operandType, Type &optOperandType,
95 SmallVectorImpl<Type> &varOperandTypes) {
96 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
97 parseCustomDirectiveResults(parser, operandType, optOperandType,
98 varOperandTypes))
99 return failure();
100 return success();
101}
102static ParseResult parseCustomDirectiveRegions(
103 OpAsmParser &parser, Region &region,
104 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
105 if (parser.parseRegion(region))
106 return failure();
107 if (failed(Result: parser.parseOptionalComma()))
108 return success();
109 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
110 if (parser.parseRegion(region&: *varRegion))
111 return failure();
112 varRegions.emplace_back(Args: std::move(varRegion));
113 return success();
114}
115static ParseResult
116parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
117 SmallVectorImpl<Block *> &varSuccessors) {
118 if (parser.parseSuccessor(dest&: successor))
119 return failure();
120 if (failed(Result: parser.parseOptionalComma()))
121 return success();
122 Block *varSuccessor;
123 if (parser.parseSuccessor(dest&: varSuccessor))
124 return failure();
125 varSuccessors.append(NumInputs: 2, Elt: varSuccessor);
126 return success();
127}
128static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
129 IntegerAttr &attr,
130 IntegerAttr &optAttr) {
131 if (parser.parseAttribute(result&: attr))
132 return failure();
133 if (succeeded(Result: parser.parseOptionalComma())) {
134 if (parser.parseAttribute(result&: optAttr))
135 return failure();
136 }
137 return success();
138}
139static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser,
140 mlir::StringAttr &attr) {
141 return parser.parseAttribute(result&: attr);
142}
143static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
144 NamedAttrList &attrs) {
145 return parser.parseOptionalAttrDict(result&: attrs);
146}
147static ParseResult parseCustomDirectiveOptionalOperandRef(
148 OpAsmParser &parser,
149 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
150 int64_t operandCount = 0;
151 if (parser.parseInteger(result&: operandCount))
152 return failure();
153 bool expectedOptionalOperand = operandCount == 0;
154 return success(IsSuccess: expectedOptionalOperand != optOperand.has_value());
155}
156
157//===----------------------------------------------------------------------===//
158// Printing
159//===----------------------------------------------------------------------===//
160
161static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
162 Value optOperand) {
163 if (optOperand)
164 printer << "(" << optOperand << ") ";
165}
166
167static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
168 Value operand, Value optOperand,
169 OperandRange varOperands) {
170 printer << operand;
171 if (optOperand)
172 printer << ", " << optOperand;
173 printer << " -> (" << varOperands << ")";
174}
175static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
176 Type operandType, Type optOperandType,
177 TypeRange varOperandTypes) {
178 printer << " : " << operandType;
179 if (optOperandType)
180 printer << ", " << optOperandType;
181 printer << " -> (" << varOperandTypes << ")";
182}
183static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
184 Operation *op, Type operandType,
185 Type optOperandType,
186 TypeRange varOperandTypes) {
187 printer << " type_refs_capture ";
188 printCustomDirectiveResults(printer, op, operandType, optOperandType,
189 varOperandTypes);
190}
191static void printCustomDirectiveOperandsAndTypes(
192 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
193 OperandRange varOperands, Type operandType, Type optOperandType,
194 TypeRange varOperandTypes) {
195 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
196 printCustomDirectiveResults(printer, op, operandType, optOperandType,
197 varOperandTypes);
198}
199static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
200 Region &region,
201 MutableArrayRef<Region> varRegions) {
202 printer.printRegion(blocks&: region);
203 if (!varRegions.empty()) {
204 printer << ", ";
205 for (Region &region : varRegions)
206 printer.printRegion(blocks&: region);
207 }
208}
209static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
210 Block *successor,
211 SuccessorRange varSuccessors) {
212 printer << successor;
213 if (!varSuccessors.empty())
214 printer << ", " << varSuccessors.front();
215}
216static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
217 Attribute attribute,
218 Attribute optAttribute) {
219 printer << attribute;
220 if (optAttribute)
221 printer << ", " << optAttribute;
222}
223static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op,
224 Attribute attribute) {
225 printer << attribute;
226}
227static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
228 DictionaryAttr attrs) {
229 printer.printOptionalAttrDict(attrs: attrs.getValue());
230}
231
232static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
233 Operation *op,
234 Value optOperand) {
235 printer << (optOperand ? "1" : "0");
236}
237//===----------------------------------------------------------------------===//
238// Test parser.
239//===----------------------------------------------------------------------===//
240
241ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
242 OperationState &result) {
243 if (parser.parseOptionalColon())
244 return success();
245 uint64_t numResults;
246 if (parser.parseInteger(numResults))
247 return failure();
248
249 IndexType type = parser.getBuilder().getIndexType();
250 for (unsigned i = 0; i < numResults; ++i)
251 result.addTypes(type);
252 return success();
253}
254
255void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
256 if (unsigned numResults = getNumResults())
257 p << " : " << numResults;
258}
259
260ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
261 OperationState &result) {
262 StringRef keyword;
263 if (parser.parseKeyword(&keyword))
264 return failure();
265 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
266 return success();
267}
268
269void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
270
271ParseResult ParseB64BytesOp::parse(OpAsmParser &parser,
272 OperationState &result) {
273 std::vector<char> bytes;
274 if (parser.parseBase64Bytes(&bytes))
275 return failure();
276 result.addAttribute("b64", parser.getBuilder().getStringAttr(
277 StringRef(&bytes.front(), bytes.size())));
278 return success();
279}
280
281void ParseB64BytesOp::print(OpAsmPrinter &p) {
282 p << " \"" << llvm::encodeBase64(getB64()) << "\"";
283}
284
285::llvm::LogicalResult FormatInferType2Op::inferReturnTypes(
286 ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
287 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
288 OpaqueProperties properties, ::mlir::RegionRange regions,
289 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
290 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
291 return ::mlir::success();
292}
293
294//===----------------------------------------------------------------------===//
295// Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
296//===----------------------------------------------------------------------===//
297
298ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
299 OperationState &result) {
300 if (parser.parseKeyword("wraps"))
301 return failure();
302
303 // Parse the wrapped op in a region
304 Region &body = *result.addRegion();
305 body.push_back(new Block);
306 Block &block = body.back();
307 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
308 if (!wrappedOp)
309 return failure();
310
311 // Create a return terminator in the inner region, pass as operand to the
312 // terminator the returned values from the wrapped operation.
313 SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
314 OpBuilder builder(parser.getContext());
315 builder.setInsertionPointToEnd(&block);
316 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
317
318 // Get the results type for the wrapping op from the terminator operands.
319 Operation &returnOp = body.back().back();
320 result.types.append(returnOp.operand_type_begin(),
321 returnOp.operand_type_end());
322
323 // Use the location of the wrapped op for the "test.wrapping_region" op.
324 result.location = wrappedOp->getLoc();
325
326 return success();
327}
328
329void WrappingRegionOp::print(OpAsmPrinter &p) {
330 p << " wraps ";
331 p.printGenericOp(&getRegion().front().front());
332}
333
334//===----------------------------------------------------------------------===//
335// Test PrettyPrintedRegionOp - exercising the following parser APIs
336// parseGenericOperationAfterOpName
337// parseCustomOperationName
338//===----------------------------------------------------------------------===//
339
340ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
341 OperationState &result) {
342
343 SMLoc loc = parser.getCurrentLocation();
344 Location currLocation = parser.getEncodedSourceLoc(loc);
345
346 // Parse the operands.
347 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
348 if (parser.parseOperandList(operands))
349 return failure();
350
351 // Check if we are parsing the pretty-printed version
352 // test.pretty_printed_region start <inner-op> end : <functional-type>
353 // Else fallback to parsing the "non pretty-printed" version.
354 if (!succeeded(parser.parseOptionalKeyword("start")))
355 return parser.parseGenericOperationAfterOpName(result,
356 llvm::ArrayRef(operands));
357
358 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
359 if (failed(parseOpNameInfo))
360 return failure();
361
362 StringAttr innerOpName = parseOpNameInfo->getIdentifier();
363
364 FunctionType opFntype;
365 std::optional<Location> explicitLoc;
366 if (parser.parseKeyword("end") || parser.parseColon() ||
367 parser.parseType(opFntype) ||
368 parser.parseOptionalLocationSpecifier(explicitLoc))
369 return failure();
370
371 // If location of the op is explicitly provided, then use it; Else use
372 // the parser's current location.
373 Location opLoc = explicitLoc.value_or(currLocation);
374
375 // Derive the SSA-values for op's operands.
376 if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
377 result.operands))
378 return failure();
379
380 // Add a region for op.
381 Region &region = *result.addRegion();
382
383 // Create a basic-block inside op's region.
384 Block &block = region.emplaceBlock();
385
386 // Create and insert an "inner-op" operation in the block.
387 // Just for testing purposes, we can assume that inner op is a binary op with
388 // result and operand types all same as the test-op's first operand.
389 Type innerOpType = opFntype.getInput(0);
390 Value lhs = block.addArgument(innerOpType, opLoc);
391 Value rhs = block.addArgument(innerOpType, opLoc);
392
393 OpBuilder builder(parser.getBuilder().getContext());
394 builder.setInsertionPointToStart(&block);
395
396 Operation *innerOp =
397 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
398
399 // Insert a return statement in the block returning the inner-op's result.
400 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
401
402 // Populate the op operation-state with result-type and location.
403 result.addTypes(opFntype.getResults());
404 result.location = innerOp->getLoc();
405
406 return success();
407}
408
409void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
410 p << ' ';
411 p.printOperands(getOperands());
412
413 Operation &innerOp = getRegion().front().front();
414 // Assuming that region has a single non-terminator inner-op, if the inner-op
415 // meets some criteria (which in this case is a simple one based on the name
416 // of inner-op), then we can print the entire region in a succinct way.
417 // Here we assume that the prototype of "test.special.op" can be trivially
418 // derived while parsing it back.
419 if (innerOp.getName().getStringRef() == "test.special.op") {
420 p << " start test.special.op end";
421 } else {
422 p << " (";
423 p.printRegion(getRegion());
424 p << ")";
425 }
426
427 p << " : ";
428 p.printFunctionalType(*this);
429}
430
431//===----------------------------------------------------------------------===//
432// Test PolyForOp - parse list of region arguments.
433//===----------------------------------------------------------------------===//
434
435ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
436 SmallVector<OpAsmParser::Argument, 4> ivsInfo;
437 // Parse list of region arguments without a delimiter.
438 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
439 return failure();
440
441 // Parse the body region.
442 Region *body = result.addRegion();
443 for (auto &iv : ivsInfo)
444 iv.type = parser.getBuilder().getIndexType();
445 return parser.parseRegion(*body, ivsInfo);
446}
447
448void PolyForOp::print(OpAsmPrinter &p) {
449 p << " ";
450 llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) {
451 p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true);
452 });
453 p << " ";
454 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
455}
456
457void PolyForOp::getAsmBlockArgumentNames(Region &region,
458 OpAsmSetValueNameFn setNameFn) {
459 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
460 if (!arrayAttr)
461 return;
462 auto args = getRegion().front().getArguments();
463 auto e = std::min(arrayAttr.size(), args.size());
464 for (unsigned i = 0; i < e; ++i) {
465 if (auto strAttr = dyn_cast<StringAttr>(arrayAttr[i]))
466 setNameFn(args[i], strAttr.getValue());
467 }
468}
469
470//===----------------------------------------------------------------------===//
471// TestAttrWithLoc - parse/printOptionalLocationSpecifier
472//===----------------------------------------------------------------------===//
473
474static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) {
475 std::optional<Location> result;
476 SMLoc sourceLoc = p.getCurrentLocation();
477 if (p.parseOptionalLocationSpecifier(result))
478 return failure();
479 if (result)
480 loc = *result;
481 else
482 loc = p.getEncodedSourceLoc(loc: sourceLoc);
483 return success();
484}
485
486static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
487 p.printOptionalLocationSpecifier(loc: cast<LocationAttr>(Val&: loc));
488}
489
490//===----------------------------------------------------------------------===//
491// ParseCustomOperationNameAPI
492//===----------------------------------------------------------------------===//
493
494static ParseResult parseCustomOperationNameEntry(OpAsmParser &p,
495 Attribute &name) {
496 FailureOr<OperationName> opName = p.parseCustomOperationName();
497 if (failed(Result: opName))
498 return ParseResult::failure();
499
500 name = p.getBuilder().getStringAttr(opName->getStringRef());
501 return ParseResult::success();
502}
503
504static void printCustomOperationNameEntry(OpAsmPrinter &p, Operation *op,
505 Attribute name) {
506 p << cast<StringAttr>(name).getValue();
507}
508
509#define GET_OP_CLASSES
510#include "TestOpsSyntax.cpp.inc"
511
/// Registers the operations generated into TestOpsSyntax.cpp.inc with the
/// test dialect.
void TestDialect::registerOpsSyntax() {
  // With GET_OP_LIST defined, the generated include presumably expands to the
  // list of op classes used as addOperations' template arguments.
  addOperations<
#define GET_OP_LIST
#include "TestOpsSyntax.cpp.inc"
      >();
}
518

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp