1//===- Syntax.cpp - Custom syntax for Linalg transform ops ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
10#include "mlir/IR/OpImplementation.h"
11
12using namespace mlir;
13
14ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
15 Type &resultType) {
16 argumentType = resultType = nullptr;
17 bool hasLParen = parser.parseOptionalLParen().succeeded();
18 if (parser.parseType(result&: argumentType).failed())
19 return failure();
20 if (!hasLParen)
21 return success();
22
23 return failure(isFailure: parser.parseRParen().failed() ||
24 parser.parseArrow().failed() ||
25 parser.parseType(result&: resultType).failed());
26}
27
28ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
29 SmallVectorImpl<Type> &resultTypes) {
30 argumentType = nullptr;
31 bool hasLParen = parser.parseOptionalLParen().succeeded();
32 if (parser.parseType(result&: argumentType).failed())
33 return failure();
34 if (!hasLParen)
35 return success();
36
37 if (parser.parseRParen().failed() || parser.parseArrow().failed())
38 return failure();
39
40 if (parser.parseOptionalLParen().failed()) {
41 Type type;
42 if (parser.parseType(result&: type).failed())
43 return failure();
44 resultTypes.push_back(Elt: type);
45 return success();
46 }
47 if (parser.parseTypeList(result&: resultTypes).failed() ||
48 parser.parseRParen().failed()) {
49 resultTypes.clear();
50 return failure();
51 }
52 return success();
53}
54
55void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
56 Type argumentType, TypeRange resultType) {
57 if (!resultType.empty())
58 printer << "(";
59 printer << argumentType;
60 if (resultType.empty())
61 return;
62 printer << ") -> ";
63
64 if (resultType.size() > 1)
65 printer << "(";
66 llvm::interleaveComma(c: resultType, os&: printer.getStream());
67 if (resultType.size() > 1)
68 printer << ")";
69}
70
71void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
72 Type argumentType, Type resultType) {
73 return printSemiFunctionType(printer, op, argumentType,
74 resultType: resultType ? TypeRange(resultType)
75 : TypeRange());
76}
77

source code of mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp