1//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestFormatUtils.h"
10#include "mlir/IR/Builders.h"
11
12using namespace mlir;
13using namespace test;
14
15//===----------------------------------------------------------------------===//
16// CustomDirectiveOperands
17//===----------------------------------------------------------------------===//
18
19ParseResult test::parseCustomDirectiveOperands(
20 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
21 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
22 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
23 if (parser.parseOperand(result&: operand))
24 return failure();
25 if (succeeded(result: parser.parseOptionalComma())) {
26 optOperand.emplace();
27 if (parser.parseOperand(result&: *optOperand))
28 return failure();
29 }
30 if (parser.parseArrow() || parser.parseLParen() ||
31 parser.parseOperandList(result&: varOperands) || parser.parseRParen())
32 return failure();
33 return success();
34}
35
36void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
37 Value operand, Value optOperand,
38 OperandRange varOperands) {
39 printer << operand;
40 if (optOperand)
41 printer << ", " << optOperand;
42 printer << " -> (" << varOperands << ")";
43}
44
45//===----------------------------------------------------------------------===//
46// CustomDirectiveResults
47//===----------------------------------------------------------------------===//
48
49ParseResult
50test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
51 Type &optOperandType,
52 SmallVectorImpl<Type> &varOperandTypes) {
53 if (parser.parseColon())
54 return failure();
55
56 if (parser.parseType(result&: operandType))
57 return failure();
58 if (succeeded(result: parser.parseOptionalComma()))
59 if (parser.parseType(result&: optOperandType))
60 return failure();
61 if (parser.parseArrow() || parser.parseLParen() ||
62 parser.parseTypeList(result&: varOperandTypes) || parser.parseRParen())
63 return failure();
64 return success();
65}
66
67void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
68 Type operandType, Type optOperandType,
69 TypeRange varOperandTypes) {
70 printer << " : " << operandType;
71 if (optOperandType)
72 printer << ", " << optOperandType;
73 printer << " -> (" << varOperandTypes << ")";
74}
75
76//===----------------------------------------------------------------------===//
77// CustomDirectiveWithTypeRefs
78//===----------------------------------------------------------------------===//
79
80ParseResult test::parseCustomDirectiveWithTypeRefs(
81 OpAsmParser &parser, Type operandType, Type optOperandType,
82 const SmallVectorImpl<Type> &varOperandTypes) {
83 if (parser.parseKeyword(keyword: "type_refs_capture"))
84 return failure();
85
86 Type operandType2, optOperandType2;
87 SmallVector<Type, 1> varOperandTypes2;
88 if (parseCustomDirectiveResults(parser, operandType&: operandType2, optOperandType&: optOperandType2,
89 varOperandTypes&: varOperandTypes2))
90 return failure();
91
92 if (operandType != operandType2 || optOperandType != optOperandType2 ||
93 varOperandTypes != varOperandTypes2)
94 return failure();
95
96 return success();
97}
98
99void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
100 Operation *op, Type operandType,
101 Type optOperandType,
102 TypeRange varOperandTypes) {
103 printer << " type_refs_capture ";
104 printCustomDirectiveResults(printer, op, operandType, optOperandType,
105 varOperandTypes);
106}
107
108//===----------------------------------------------------------------------===//
109// CustomDirectiveOperandsAndTypes
110//===----------------------------------------------------------------------===//
111
112ParseResult test::parseCustomDirectiveOperandsAndTypes(
113 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
114 std::optional<OpAsmParser::UnresolvedOperand> &optOperand,
115 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
116 Type &operandType, Type &optOperandType,
117 SmallVectorImpl<Type> &varOperandTypes) {
118 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
119 parseCustomDirectiveResults(parser, operandType, optOperandType,
120 varOperandTypes))
121 return failure();
122 return success();
123}
124
125void test::printCustomDirectiveOperandsAndTypes(
126 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
127 OperandRange varOperands, Type operandType, Type optOperandType,
128 TypeRange varOperandTypes) {
129 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
130 printCustomDirectiveResults(printer, op, operandType, optOperandType,
131 varOperandTypes);
132}
133
134//===----------------------------------------------------------------------===//
135// CustomDirectiveRegions
136//===----------------------------------------------------------------------===//
137
138ParseResult test::parseCustomDirectiveRegions(
139 OpAsmParser &parser, Region &region,
140 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
141 if (parser.parseRegion(region))
142 return failure();
143 if (failed(result: parser.parseOptionalComma()))
144 return success();
145 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
146 if (parser.parseRegion(region&: *varRegion))
147 return failure();
148 varRegions.emplace_back(Args: std::move(varRegion));
149 return success();
150}
151
152void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
153 Region &region,
154 MutableArrayRef<Region> varRegions) {
155 printer.printRegion(blocks&: region);
156 if (!varRegions.empty()) {
157 printer << ", ";
158 for (Region &region : varRegions)
159 printer.printRegion(blocks&: region);
160 }
161}
162
163//===----------------------------------------------------------------------===//
164// CustomDirectiveSuccessors
165//===----------------------------------------------------------------------===//
166
167ParseResult
168test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
169 SmallVectorImpl<Block *> &varSuccessors) {
170 if (parser.parseSuccessor(dest&: successor))
171 return failure();
172 if (failed(result: parser.parseOptionalComma()))
173 return success();
174 Block *varSuccessor;
175 if (parser.parseSuccessor(dest&: varSuccessor))
176 return failure();
177 varSuccessors.append(NumInputs: 2, Elt: varSuccessor);
178 return success();
179}
180
181void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
182 Block *successor,
183 SuccessorRange varSuccessors) {
184 printer << successor;
185 if (!varSuccessors.empty())
186 printer << ", " << varSuccessors.front();
187}
188
189//===----------------------------------------------------------------------===//
190// CustomDirectiveAttributes
191//===----------------------------------------------------------------------===//
192
193ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser,
194 IntegerAttr &attr,
195 IntegerAttr &optAttr) {
196 if (parser.parseAttribute(result&: attr))
197 return failure();
198 if (succeeded(result: parser.parseOptionalComma())) {
199 if (parser.parseAttribute(result&: optAttr))
200 return failure();
201 }
202 return success();
203}
204
205void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
206 Attribute attribute,
207 Attribute optAttribute) {
208 printer << attribute;
209 if (optAttribute)
210 printer << ", " << optAttribute;
211}
212
213//===----------------------------------------------------------------------===//
214// CustomDirectiveAttrDict
215//===----------------------------------------------------------------------===//
216
217ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser,
218 NamedAttrList &attrs) {
219 return parser.parseOptionalAttrDict(result&: attrs);
220}
221
222void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
223 DictionaryAttr attrs) {
224 printer.printOptionalAttrDict(attrs: attrs.getValue());
225}
226
227//===----------------------------------------------------------------------===//
228// CustomDirectiveOptionalOperandRef
229//===----------------------------------------------------------------------===//
230
231ParseResult test::parseCustomDirectiveOptionalOperandRef(
232 OpAsmParser &parser,
233 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
234 int64_t operandCount = 0;
235 if (parser.parseInteger(result&: operandCount))
236 return failure();
237 bool expectedOptionalOperand = operandCount == 0;
238 return success(isSuccess: expectedOptionalOperand != !!optOperand);
239}
240
241void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
242 Operation *op,
243 Value optOperand) {
244 printer << (optOperand ? "1" : "0");
245}
246
247//===----------------------------------------------------------------------===//
248// CustomDirectiveOptionalOperand
249//===----------------------------------------------------------------------===//
250
251ParseResult test::parseCustomOptionalOperand(
252 OpAsmParser &parser,
253 std::optional<OpAsmParser::UnresolvedOperand> &optOperand) {
254 if (succeeded(result: parser.parseOptionalLParen())) {
255 optOperand.emplace();
256 if (parser.parseOperand(result&: *optOperand) || parser.parseRParen())
257 return failure();
258 }
259 return success();
260}
261
262void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
263 Value optOperand) {
264 if (optOperand)
265 printer << "(" << optOperand << ") ";
266}
267
268//===----------------------------------------------------------------------===//
269// CustomDirectiveSwitchCases
270//===----------------------------------------------------------------------===//
271
272ParseResult
273test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
274 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
275 SmallVector<int64_t> caseValues;
276 while (succeeded(result: p.parseOptionalKeyword(keyword: "case"))) {
277 int64_t value;
278 Region &region = *caseRegions.emplace_back(Args: std::make_unique<Region>());
279 if (p.parseInteger(result&: value) || p.parseRegion(region, /*arguments=*/{}))
280 return failure();
281 caseValues.push_back(Elt: value);
282 }
283 cases = p.getBuilder().getDenseI64ArrayAttr(caseValues);
284 return success();
285}
286
287void test::printSwitchCases(OpAsmPrinter &p, Operation *op,
288 DenseI64ArrayAttr cases, RegionRange caseRegions) {
289 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
290 p.printNewline();
291 p << "case " << value << ' ';
292 p.printRegion(*region, /*printEntryBlockArgs=*/false);
293 }
294}
295
296//===----------------------------------------------------------------------===//
297// CustomUsingPropertyInCustom
298//===----------------------------------------------------------------------===//
299
300bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) {
301 return parser.parseLSquare() || parser.parseInteger(result&: value[0]) ||
302 parser.parseComma() || parser.parseInteger(result&: value[1]) ||
303 parser.parseComma() || parser.parseInteger(result&: value[2]) ||
304 parser.parseRSquare();
305}
306
307void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op,
308 ArrayRef<int64_t> value) {
309 printer << '[' << value << ']';
310}
311
312//===----------------------------------------------------------------------===//
313// CustomDirectiveIntProperty
314//===----------------------------------------------------------------------===//
315
316bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) {
317 return failed(result: parser.parseInteger(result&: value));
318}
319
320void test::printIntProperty(OpAsmPrinter &printer, Operation *op,
321 int64_t value) {
322 printer << value;
323}
324
325//===----------------------------------------------------------------------===//
326// CustomDirectiveSumProperty
327//===----------------------------------------------------------------------===//
328
329bool test::parseSumProperty(OpAsmParser &parser, int64_t &second,
330 int64_t first) {
331 int64_t sum;
332 auto loc = parser.getCurrentLocation();
333 if (parser.parseInteger(result&: second) || parser.parseEqual() ||
334 parser.parseInteger(result&: sum))
335 return true;
336 if (sum != second + first) {
337 parser.emitError(loc, message: "Expected sum to equal first + second");
338 return true;
339 }
340 return false;
341}
342
343void test::printSumProperty(OpAsmPrinter &printer, Operation *op,
344 int64_t second, int64_t first) {
345 printer << second << " = " << (second + first);
346}
347
348//===----------------------------------------------------------------------===//
349// CustomDirectiveOptionalCustomParser
350//===----------------------------------------------------------------------===//
351
352OptionalParseResult test::parseOptionalCustomParser(AsmParser &p,
353 IntegerAttr &result) {
354 if (succeeded(result: p.parseOptionalKeyword(keyword: "foo")))
355 return p.parseAttribute(result);
356 return {};
357}
358
359void test::printOptionalCustomParser(AsmPrinter &p, Operation *,
360 IntegerAttr result) {
361 p << "foo ";
362 p.printAttribute(attr: result);
363}
364
365//===----------------------------------------------------------------------===//
366// CustomDirectiveAttrElideType
367//===----------------------------------------------------------------------===//
368
369ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type,
370 Attribute &attr) {
371 return parser.parseAttribute(attr, type.getValue());
372}
373
374void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type,
375 Attribute attr) {
376 printer.printAttributeWithoutType(attr);
377}
378

source code of mlir/test/lib/Dialect/Test/TestFormatUtils.cpp