1//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10#include "mlir/Dialect/PDL/IR/PDLTypes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/Interfaces/FunctionImplementation.h"
13
14using namespace mlir;
15using namespace mlir::pdl_interp;
16
17#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
18
19//===----------------------------------------------------------------------===//
20// PDLInterp Dialect
21//===----------------------------------------------------------------------===//
22
23void PDLInterpDialect::initialize() {
24 addOperations<
25#define GET_OP_LIST
26#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
27 >();
28}
29
30template <typename OpT>
31static LogicalResult verifySwitchOp(OpT op) {
32 // Verify that the number of case destinations matches the number of case
33 // values.
34 size_t numDests = op.getCases().size();
35 size_t numValues = op.getCaseValues().size();
36 if (numDests != numValues) {
37 return op.emitOpError(
38 "expected number of cases to match the number of case "
39 "values, got ")
40 << numDests << " but expected " << numValues;
41 }
42 return success();
43}
44
45//===----------------------------------------------------------------------===//
46// pdl_interp::CreateOperationOp
47//===----------------------------------------------------------------------===//
48
49LogicalResult CreateOperationOp::verify() {
50 if (!getInferredResultTypes())
51 return success();
52 if (!getInputResultTypes().empty()) {
53 return emitOpError(message: "with inferred results cannot also have "
54 "explicit result types");
55 }
56 OperationName opName(getName(), getContext());
57 if (!opName.hasInterface<InferTypeOpInterface>()) {
58 return emitOpError()
59 << "has inferred results, but the created operation '" << opName
60 << "' does not support result type inference (or is not "
61 "registered)";
62 }
63 return success();
64}
65
66static ParseResult parseCreateOperationOpAttributes(
67 OpAsmParser &p,
68 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
69 ArrayAttr &attrNamesAttr) {
70 Builder &builder = p.getBuilder();
71 SmallVector<Attribute, 4> attrNames;
72 if (succeeded(Result: p.parseOptionalLBrace())) {
73 auto parseOperands = [&]() {
74 StringAttr nameAttr;
75 OpAsmParser::UnresolvedOperand operand;
76 if (p.parseAttribute(result&: nameAttr) || p.parseEqual() ||
77 p.parseOperand(result&: operand))
78 return failure();
79 attrNames.push_back(Elt: nameAttr);
80 attrOperands.push_back(Elt: operand);
81 return success();
82 };
83 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
84 return failure();
85 }
86 attrNamesAttr = builder.getArrayAttr(value: attrNames);
87 return success();
88}
89
90static void printCreateOperationOpAttributes(OpAsmPrinter &p,
91 CreateOperationOp op,
92 OperandRange attrArgs,
93 ArrayAttr attrNames) {
94 if (attrNames.empty())
95 return;
96 p << " {";
97 interleaveComma(c: llvm::seq<int>(Begin: 0, End: attrNames.size()), os&: p,
98 each_fn: [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
99 p << '}';
100}
101
102static ParseResult parseCreateOperationOpResults(
103 OpAsmParser &p,
104 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands,
105 SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
106 if (failed(Result: p.parseOptionalArrow()))
107 return success();
108
109 // Handle the case of inferred results.
110 if (succeeded(Result: p.parseOptionalLess())) {
111 if (p.parseKeyword(keyword: "inferred") || p.parseGreater())
112 return failure();
113 inferredResultTypes = p.getBuilder().getUnitAttr();
114 return success();
115 }
116
117 // Otherwise, parse the explicit results.
118 return failure(IsFailure: p.parseLParen() || p.parseOperandList(result&: resultOperands) ||
119 p.parseColonTypeList(result&: resultTypes) || p.parseRParen());
120}
121
122static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
123 OperandRange resultOperands,
124 TypeRange resultTypes,
125 UnitAttr inferredResultTypes) {
126 // Handle the case of inferred results.
127 if (inferredResultTypes) {
128 p << " -> <inferred>";
129 return;
130 }
131
132 // Otherwise, handle the explicit results.
133 if (!resultTypes.empty())
134 p << " -> (" << resultOperands << " : " << resultTypes << ")";
135}
136
137//===----------------------------------------------------------------------===//
138// pdl_interp::ForEachOp
139//===----------------------------------------------------------------------===//
140
141void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
142 Value range, Block *successor, bool initLoop) {
143 build(odsBuilder&: builder, odsState&: state, values: range, successor);
144 if (initLoop) {
145 // Create the block and the loop variable.
146 // FIXME: Allow passing in a proper location for the loop variable.
147 auto rangeType = llvm::cast<pdl::RangeType>(Val: range.getType());
148 state.regions.front()->emplaceBlock();
149 state.regions.front()->addArgument(type: rangeType.getElementType(),
150 loc: state.location);
151 }
152}
153
154ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
155 // Parse the loop variable followed by type.
156 OpAsmParser::Argument loopVariable;
157 OpAsmParser::UnresolvedOperand operandInfo;
158 if (parser.parseArgument(result&: loopVariable, /*allowType=*/true) ||
159 parser.parseKeyword(keyword: "in", msg: " after loop variable") ||
160 // Parse the operand (value range).
161 parser.parseOperand(result&: operandInfo))
162 return failure();
163
164 // Resolve the operand.
165 Type rangeType = pdl::RangeType::get(elementType: loopVariable.type);
166 if (parser.resolveOperand(operand: operandInfo, type: rangeType, result&: result.operands))
167 return failure();
168
169 // Parse the body region.
170 Region *body = result.addRegion();
171 Block *successor;
172 if (parser.parseRegion(region&: *body, arguments: loopVariable) ||
173 parser.parseOptionalAttrDict(result&: result.attributes) ||
174 // Parse the successor.
175 parser.parseArrow() || parser.parseSuccessor(dest&: successor))
176 return failure();
177
178 result.addSuccessors(successor);
179 return success();
180}
181
182void ForEachOp::print(OpAsmPrinter &p) {
183 BlockArgument arg = getLoopVariable();
184 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
185 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false);
186 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
187 p << " -> ";
188 p.printSuccessor(successor: getSuccessor());
189}
190
191LogicalResult ForEachOp::verify() {
192 // Verify that the operation has exactly one argument.
193 if (getRegion().getNumArguments() != 1)
194 return emitOpError(message: "requires exactly one argument");
195
196 // Verify that the loop variable and the operand (value range)
197 // have compatible types.
198 BlockArgument arg = getLoopVariable();
199 Type rangeType = pdl::RangeType::get(elementType: arg.getType());
200 if (rangeType != getValues().getType())
201 return emitOpError(message: "operand must be a range of loop variable type");
202
203 return success();
204}
205
206//===----------------------------------------------------------------------===//
207// pdl_interp::FuncOp
208//===----------------------------------------------------------------------===//
209
210void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
211 FunctionType type, ArrayRef<NamedAttribute> attrs) {
212 buildWithEntryBlock(builder, state, name, type, attrs, inputTypes: type.getInputs());
213}
214
215ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
216 auto buildFuncType =
217 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
218 function_interface_impl::VariadicFlag,
219 std::string &) { return builder.getFunctionType(inputs: argTypes, results); };
220
221 return function_interface_impl::parseFunctionOp(
222 parser, result, /*allowVariadic=*/false,
223 typeAttrName: getFunctionTypeAttrName(name: result.name), funcTypeBuilder: buildFuncType,
224 argAttrsName: getArgAttrsAttrName(name: result.name), resAttrsName: getResAttrsAttrName(name: result.name));
225}
226
227void FuncOp::print(OpAsmPrinter &p) {
228 function_interface_impl::printFunctionOp(
229 p, op: *this, /*isVariadic=*/false, typeAttrName: getFunctionTypeAttrName(),
230 argAttrsName: getArgAttrsAttrName(), resAttrsName: getResAttrsAttrName());
231}
232
233//===----------------------------------------------------------------------===//
234// pdl_interp::GetValueTypeOp
235//===----------------------------------------------------------------------===//
236
237/// Given the result type of a `GetValueTypeOp`, return the expected input type.
238static Type getGetValueTypeOpValueType(Type type) {
239 Type valueTy = pdl::ValueType::get(ctx: type.getContext());
240 return llvm::isa<pdl::RangeType>(Val: type) ? pdl::RangeType::get(elementType: valueTy)
241 : valueTy;
242}
243
244//===----------------------------------------------------------------------===//
// pdl_interp::CreateRangeOp
246//===----------------------------------------------------------------------===//
247
248static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
249 Type &resultType) {
250 // If arguments were provided, infer the result type from the argument list.
251 if (!argumentTypes.empty()) {
252 resultType =
253 pdl::RangeType::get(elementType: pdl::getRangeElementTypeOrSelf(type: argumentTypes[0]));
254 return success();
255 }
256 // Otherwise, parse the type as a trailing type.
257 return p.parseColonType(result&: resultType);
258}
259
260static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
261 TypeRange argumentTypes, Type resultType) {
262 if (argumentTypes.empty())
263 p << ": " << resultType;
264}
265
266LogicalResult CreateRangeOp::verify() {
267 Type elementType = getType().getElementType();
268 for (Type operandType : getOperandTypes()) {
269 Type operandElementType = pdl::getRangeElementTypeOrSelf(type: operandType);
270 if (operandElementType != elementType) {
271 return emitOpError(message: "expected operand to have element type ")
272 << elementType << ", but got " << operandElementType;
273 }
274 }
275 return success();
276}
277
278//===----------------------------------------------------------------------===//
279// pdl_interp::SwitchAttributeOp
280//===----------------------------------------------------------------------===//
281
282LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(op: *this); }
283
284//===----------------------------------------------------------------------===//
285// pdl_interp::SwitchOperandCountOp
286//===----------------------------------------------------------------------===//
287
288LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(op: *this); }
289
290//===----------------------------------------------------------------------===//
291// pdl_interp::SwitchOperationNameOp
292//===----------------------------------------------------------------------===//
293
294LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(op: *this); }
295
296//===----------------------------------------------------------------------===//
297// pdl_interp::SwitchResultCountOp
298//===----------------------------------------------------------------------===//
299
300LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(op: *this); }
301
302//===----------------------------------------------------------------------===//
303// pdl_interp::SwitchTypeOp
304//===----------------------------------------------------------------------===//
305
306LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(op: *this); }
307
308//===----------------------------------------------------------------------===//
309// pdl_interp::SwitchTypesOp
310//===----------------------------------------------------------------------===//
311
312LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(op: *this); }
313
314//===----------------------------------------------------------------------===//
315// TableGen Auto-Generated Op and Interface Definitions
316//===----------------------------------------------------------------------===//
317
318#define GET_OP_CLASSES
319#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
320

source code of mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp