//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10#include "mlir/Dialect/PDL/IR/PDLTypes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/Interfaces/FunctionImplementation.h"
14
15using namespace mlir;
16using namespace mlir::pdl_interp;
17
18#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
19
20//===----------------------------------------------------------------------===//
21// PDLInterp Dialect
22//===----------------------------------------------------------------------===//
23
24void PDLInterpDialect::initialize() {
25 addOperations<
26#define GET_OP_LIST
27#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
28 >();
29}
30
31template <typename OpT>
32static LogicalResult verifySwitchOp(OpT op) {
33 // Verify that the number of case destinations matches the number of case
34 // values.
35 size_t numDests = op.getCases().size();
36 size_t numValues = op.getCaseValues().size();
37 if (numDests != numValues) {
38 return op.emitOpError(
39 "expected number of cases to match the number of case "
40 "values, got ")
41 << numDests << " but expected " << numValues;
42 }
43 return success();
44}
45
46//===----------------------------------------------------------------------===//
47// pdl_interp::CreateOperationOp
48//===----------------------------------------------------------------------===//
49
50LogicalResult CreateOperationOp::verify() {
51 if (!getInferredResultTypes())
52 return success();
53 if (!getInputResultTypes().empty()) {
54 return emitOpError("with inferred results cannot also have "
55 "explicit result types");
56 }
57 OperationName opName(getName(), getContext());
58 if (!opName.hasInterface<InferTypeOpInterface>()) {
59 return emitOpError()
60 << "has inferred results, but the created operation '" << opName
61 << "' does not support result type inference (or is not "
62 "registered)";
63 }
64 return success();
65}
66
67static ParseResult parseCreateOperationOpAttributes(
68 OpAsmParser &p,
69 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
70 ArrayAttr &attrNamesAttr) {
71 Builder &builder = p.getBuilder();
72 SmallVector<Attribute, 4> attrNames;
73 if (succeeded(result: p.parseOptionalLBrace())) {
74 auto parseOperands = [&]() {
75 StringAttr nameAttr;
76 OpAsmParser::UnresolvedOperand operand;
77 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
78 p.parseOperand(result&: operand))
79 return failure();
80 attrNames.push_back(Elt: nameAttr);
81 attrOperands.push_back(Elt: operand);
82 return success();
83 };
84 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
85 return failure();
86 }
87 attrNamesAttr = builder.getArrayAttr(attrNames);
88 return success();
89}
90
91static void printCreateOperationOpAttributes(OpAsmPrinter &p,
92 CreateOperationOp op,
93 OperandRange attrArgs,
94 ArrayAttr attrNames) {
95 if (attrNames.empty())
96 return;
97 p << " {";
98 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
99 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
100 p << '}';
101}
102
103static ParseResult parseCreateOperationOpResults(
104 OpAsmParser &p,
105 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
106 SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
107 if (failed(result: p.parseOptionalArrow()))
108 return success();
109
110 // Handle the case of inferred results.
111 if (succeeded(result: p.parseOptionalLess())) {
112 if (p.parseKeyword(keyword: "inferred") || p.parseGreater())
113 return failure();
114 inferredResultTypes = p.getBuilder().getUnitAttr();
115 return success();
116 }
117
118 // Otherwise, parse the explicit results.
119 return failure(isFailure: p.parseLParen() || p.parseOperandList(result&: resultOperands) ||
120 p.parseColonTypeList(result&: resultTypes) || p.parseRParen());
121}
122
123static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
124 OperandRange resultOperands,
125 TypeRange resultTypes,
126 UnitAttr inferredResultTypes) {
127 // Handle the case of inferred results.
128 if (inferredResultTypes) {
129 p << " -> <inferred>";
130 return;
131 }
132
133 // Otherwise, handle the explicit results.
134 if (!resultTypes.empty())
135 p << " -> (" << resultOperands << " : " << resultTypes << ")";
136}
137
138//===----------------------------------------------------------------------===//
139// pdl_interp::ForEachOp
140//===----------------------------------------------------------------------===//
141
142void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
143 Value range, Block *successor, bool initLoop) {
144 build(builder, state, range, successor);
145 if (initLoop) {
146 // Create the block and the loop variable.
147 // FIXME: Allow passing in a proper location for the loop variable.
148 auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
149 state.regions.front()->emplaceBlock();
150 state.regions.front()->addArgument(rangeType.getElementType(),
151 state.location);
152 }
153}
154
155ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
156 // Parse the loop variable followed by type.
157 OpAsmParser::Argument loopVariable;
158 OpAsmParser::UnresolvedOperand operandInfo;
159 if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
160 parser.parseKeyword("in", " after loop variable") ||
161 // Parse the operand (value range).
162 parser.parseOperand(operandInfo))
163 return failure();
164
165 // Resolve the operand.
166 Type rangeType = pdl::RangeType::get(loopVariable.type);
167 if (parser.resolveOperand(operandInfo, rangeType, result.operands))
168 return failure();
169
170 // Parse the body region.
171 Region *body = result.addRegion();
172 Block *successor;
173 if (parser.parseRegion(*body, loopVariable) ||
174 parser.parseOptionalAttrDict(result.attributes) ||
175 // Parse the successor.
176 parser.parseArrow() || parser.parseSuccessor(successor))
177 return failure();
178
179 result.addSuccessors(successor);
180 return success();
181}
182
183void ForEachOp::print(OpAsmPrinter &p) {
184 BlockArgument arg = getLoopVariable();
185 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
186 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
187 p.printOptionalAttrDict((*this)->getAttrs());
188 p << " -> ";
189 p.printSuccessor(getSuccessor());
190}
191
192LogicalResult ForEachOp::verify() {
193 // Verify that the operation has exactly one argument.
194 if (getRegion().getNumArguments() != 1)
195 return emitOpError("requires exactly one argument");
196
197 // Verify that the loop variable and the operand (value range)
198 // have compatible types.
199 BlockArgument arg = getLoopVariable();
200 Type rangeType = pdl::RangeType::get(arg.getType());
201 if (rangeType != getValues().getType())
202 return emitOpError("operand must be a range of loop variable type");
203
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// pdl_interp::FuncOp
209//===----------------------------------------------------------------------===//
210
211void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
212 FunctionType type, ArrayRef<NamedAttribute> attrs) {
213 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
214}
215
216ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
217 auto buildFuncType =
218 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
219 function_interface_impl::VariadicFlag,
220 std::string &) { return builder.getFunctionType(argTypes, results); };
221
222 return function_interface_impl::parseFunctionOp(
223 parser, result, /*allowVariadic=*/false,
224 getFunctionTypeAttrName(result.name), buildFuncType,
225 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
226}
227
228void FuncOp::print(OpAsmPrinter &p) {
229 function_interface_impl::printFunctionOp(
230 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
231 getArgAttrsAttrName(), getResAttrsAttrName());
232}
233
234//===----------------------------------------------------------------------===//
235// pdl_interp::GetValueTypeOp
236//===----------------------------------------------------------------------===//
237
238/// Given the result type of a `GetValueTypeOp`, return the expected input type.
239static Type getGetValueTypeOpValueType(Type type) {
240 Type valueTy = pdl::ValueType::get(type.getContext());
241 return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
242 : valueTy;
243}
244
//===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
//===----------------------------------------------------------------------===//
248
249static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
250 Type &resultType) {
251 // If arguments were provided, infer the result type from the argument list.
252 if (!argumentTypes.empty()) {
253 resultType =
254 pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
255 return success();
256 }
257 // Otherwise, parse the type as a trailing type.
258 return p.parseColonType(result&: resultType);
259}
260
261static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
262 TypeRange argumentTypes, Type resultType) {
263 if (argumentTypes.empty())
264 p << ": " << resultType;
265}
266
267LogicalResult CreateRangeOp::verify() {
268 Type elementType = getType().getElementType();
269 for (Type operandType : getOperandTypes()) {
270 Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
271 if (operandElementType != elementType) {
272 return emitOpError("expected operand to have element type ")
273 << elementType << ", but got " << operandElementType;
274 }
275 }
276 return success();
277}
278
279//===----------------------------------------------------------------------===//
280// pdl_interp::SwitchAttributeOp
281//===----------------------------------------------------------------------===//
282
283LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
284
285//===----------------------------------------------------------------------===//
286// pdl_interp::SwitchOperandCountOp
287//===----------------------------------------------------------------------===//
288
289LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
290
291//===----------------------------------------------------------------------===//
292// pdl_interp::SwitchOperationNameOp
293//===----------------------------------------------------------------------===//
294
295LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
296
297//===----------------------------------------------------------------------===//
298// pdl_interp::SwitchResultCountOp
299//===----------------------------------------------------------------------===//
300
301LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
302
303//===----------------------------------------------------------------------===//
304// pdl_interp::SwitchTypeOp
305//===----------------------------------------------------------------------===//
306
307LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
308
309//===----------------------------------------------------------------------===//
310// pdl_interp::SwitchTypesOp
311//===----------------------------------------------------------------------===//
312
313LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
314
315//===----------------------------------------------------------------------===//
316// TableGen Auto-Generated Op and Interface Definitions
317//===----------------------------------------------------------------------===//
318
319#define GET_OP_CLASSES
320#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
321

// source code of mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp