1 | //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
10 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/DialectImplementation.h" |
13 | #include "mlir/Interfaces/FunctionImplementation.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::pdl_interp; |
17 | |
18 | #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc" |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // PDLInterp Dialect |
22 | //===----------------------------------------------------------------------===// |
23 | |
/// Register every pdl_interp operation class with the dialect. The list of
/// operation classes is expanded from the TableGen-generated
/// PDLInterpOps.cpp.inc when GET_OP_LIST is defined.
void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
      >();
}
30 | |
31 | template <typename OpT> |
32 | static LogicalResult verifySwitchOp(OpT op) { |
33 | // Verify that the number of case destinations matches the number of case |
34 | // values. |
35 | size_t numDests = op.getCases().size(); |
36 | size_t numValues = op.getCaseValues().size(); |
37 | if (numDests != numValues) { |
38 | return op.emitOpError( |
39 | "expected number of cases to match the number of case " |
40 | "values, got " ) |
41 | << numDests << " but expected " << numValues; |
42 | } |
43 | return success(); |
44 | } |
45 | |
46 | //===----------------------------------------------------------------------===// |
47 | // pdl_interp::CreateOperationOp |
48 | //===----------------------------------------------------------------------===// |
49 | |
50 | LogicalResult CreateOperationOp::verify() { |
51 | if (!getInferredResultTypes()) |
52 | return success(); |
53 | if (!getInputResultTypes().empty()) { |
54 | return emitOpError("with inferred results cannot also have " |
55 | "explicit result types" ); |
56 | } |
57 | OperationName opName(getName(), getContext()); |
58 | if (!opName.hasInterface<InferTypeOpInterface>()) { |
59 | return emitOpError() |
60 | << "has inferred results, but the created operation '" << opName |
61 | << "' does not support result type inference (or is not " |
62 | "registered)" ; |
63 | } |
64 | return success(); |
65 | } |
66 | |
67 | static ParseResult parseCreateOperationOpAttributes( |
68 | OpAsmParser &p, |
69 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands, |
70 | ArrayAttr &attrNamesAttr) { |
71 | Builder &builder = p.getBuilder(); |
72 | SmallVector<Attribute, 4> attrNames; |
73 | if (succeeded(result: p.parseOptionalLBrace())) { |
74 | auto parseOperands = [&]() { |
75 | StringAttr nameAttr; |
76 | OpAsmParser::UnresolvedOperand operand; |
77 | if (p.parseAttribute(nameAttr) || p.parseEqual() || |
78 | p.parseOperand(result&: operand)) |
79 | return failure(); |
80 | attrNames.push_back(Elt: nameAttr); |
81 | attrOperands.push_back(Elt: operand); |
82 | return success(); |
83 | }; |
84 | if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace()) |
85 | return failure(); |
86 | } |
87 | attrNamesAttr = builder.getArrayAttr(attrNames); |
88 | return success(); |
89 | } |
90 | |
91 | static void printCreateOperationOpAttributes(OpAsmPrinter &p, |
92 | CreateOperationOp op, |
93 | OperandRange attrArgs, |
94 | ArrayAttr attrNames) { |
95 | if (attrNames.empty()) |
96 | return; |
97 | p << " {" ; |
98 | interleaveComma(llvm::seq<int>(0, attrNames.size()), p, |
99 | [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); |
100 | p << '}'; |
101 | } |
102 | |
103 | static ParseResult parseCreateOperationOpResults( |
104 | OpAsmParser &p, |
105 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &resultOperands, |
106 | SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) { |
107 | if (failed(result: p.parseOptionalArrow())) |
108 | return success(); |
109 | |
110 | // Handle the case of inferred results. |
111 | if (succeeded(result: p.parseOptionalLess())) { |
112 | if (p.parseKeyword(keyword: "inferred" ) || p.parseGreater()) |
113 | return failure(); |
114 | inferredResultTypes = p.getBuilder().getUnitAttr(); |
115 | return success(); |
116 | } |
117 | |
118 | // Otherwise, parse the explicit results. |
119 | return failure(isFailure: p.parseLParen() || p.parseOperandList(result&: resultOperands) || |
120 | p.parseColonTypeList(result&: resultTypes) || p.parseRParen()); |
121 | } |
122 | |
123 | static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op, |
124 | OperandRange resultOperands, |
125 | TypeRange resultTypes, |
126 | UnitAttr inferredResultTypes) { |
127 | // Handle the case of inferred results. |
128 | if (inferredResultTypes) { |
129 | p << " -> <inferred>" ; |
130 | return; |
131 | } |
132 | |
133 | // Otherwise, handle the explicit results. |
134 | if (!resultTypes.empty()) |
135 | p << " -> (" << resultOperands << " : " << resultTypes << ")" ; |
136 | } |
137 | |
138 | //===----------------------------------------------------------------------===// |
139 | // pdl_interp::ForEachOp |
140 | //===----------------------------------------------------------------------===// |
141 | |
142 | void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, |
143 | Value range, Block *successor, bool initLoop) { |
144 | build(builder, state, range, successor); |
145 | if (initLoop) { |
146 | // Create the block and the loop variable. |
147 | // FIXME: Allow passing in a proper location for the loop variable. |
148 | auto rangeType = llvm::cast<pdl::RangeType>(range.getType()); |
149 | state.regions.front()->emplaceBlock(); |
150 | state.regions.front()->addArgument(rangeType.getElementType(), |
151 | state.location); |
152 | } |
153 | } |
154 | |
155 | ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) { |
156 | // Parse the loop variable followed by type. |
157 | OpAsmParser::Argument loopVariable; |
158 | OpAsmParser::UnresolvedOperand operandInfo; |
159 | if (parser.parseArgument(loopVariable, /*allowType=*/true) || |
160 | parser.parseKeyword("in" , " after loop variable" ) || |
161 | // Parse the operand (value range). |
162 | parser.parseOperand(operandInfo)) |
163 | return failure(); |
164 | |
165 | // Resolve the operand. |
166 | Type rangeType = pdl::RangeType::get(loopVariable.type); |
167 | if (parser.resolveOperand(operandInfo, rangeType, result.operands)) |
168 | return failure(); |
169 | |
170 | // Parse the body region. |
171 | Region *body = result.addRegion(); |
172 | Block *successor; |
173 | if (parser.parseRegion(*body, loopVariable) || |
174 | parser.parseOptionalAttrDict(result.attributes) || |
175 | // Parse the successor. |
176 | parser.parseArrow() || parser.parseSuccessor(successor)) |
177 | return failure(); |
178 | |
179 | result.addSuccessors(successor); |
180 | return success(); |
181 | } |
182 | |
183 | void ForEachOp::print(OpAsmPrinter &p) { |
184 | BlockArgument arg = getLoopVariable(); |
185 | p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' '; |
186 | p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
187 | p.printOptionalAttrDict((*this)->getAttrs()); |
188 | p << " -> " ; |
189 | p.printSuccessor(getSuccessor()); |
190 | } |
191 | |
192 | LogicalResult ForEachOp::verify() { |
193 | // Verify that the operation has exactly one argument. |
194 | if (getRegion().getNumArguments() != 1) |
195 | return emitOpError("requires exactly one argument" ); |
196 | |
197 | // Verify that the loop variable and the operand (value range) |
198 | // have compatible types. |
199 | BlockArgument arg = getLoopVariable(); |
200 | Type rangeType = pdl::RangeType::get(arg.getType()); |
201 | if (rangeType != getValues().getType()) |
202 | return emitOpError("operand must be a range of loop variable type" ); |
203 | |
204 | return success(); |
205 | } |
206 | |
207 | //===----------------------------------------------------------------------===// |
208 | // pdl_interp::FuncOp |
209 | //===----------------------------------------------------------------------===// |
210 | |
211 | void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, |
212 | FunctionType type, ArrayRef<NamedAttribute> attrs) { |
213 | buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); |
214 | } |
215 | |
216 | ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { |
217 | auto buildFuncType = |
218 | [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, |
219 | function_interface_impl::VariadicFlag, |
220 | std::string &) { return builder.getFunctionType(argTypes, results); }; |
221 | |
222 | return function_interface_impl::parseFunctionOp( |
223 | parser, result, /*allowVariadic=*/false, |
224 | getFunctionTypeAttrName(result.name), buildFuncType, |
225 | getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); |
226 | } |
227 | |
228 | void FuncOp::print(OpAsmPrinter &p) { |
229 | function_interface_impl::printFunctionOp( |
230 | p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), |
231 | getArgAttrsAttrName(), getResAttrsAttrName()); |
232 | } |
233 | |
234 | //===----------------------------------------------------------------------===// |
235 | // pdl_interp::GetValueTypeOp |
236 | //===----------------------------------------------------------------------===// |
237 | |
238 | /// Given the result type of a `GetValueTypeOp`, return the expected input type. |
239 | static Type getGetValueTypeOpValueType(Type type) { |
240 | Type valueTy = pdl::ValueType::get(type.getContext()); |
241 | return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy) |
242 | : valueTy; |
243 | } |
244 | |
245 | //===----------------------------------------------------------------------===// |
// pdl_interp::CreateRangeOp
247 | //===----------------------------------------------------------------------===// |
248 | |
249 | static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, |
250 | Type &resultType) { |
251 | // If arguments were provided, infer the result type from the argument list. |
252 | if (!argumentTypes.empty()) { |
253 | resultType = |
254 | pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0])); |
255 | return success(); |
256 | } |
257 | // Otherwise, parse the type as a trailing type. |
258 | return p.parseColonType(result&: resultType); |
259 | } |
260 | |
261 | static void printRangeType(OpAsmPrinter &p, CreateRangeOp op, |
262 | TypeRange argumentTypes, Type resultType) { |
263 | if (argumentTypes.empty()) |
264 | p << ": " << resultType; |
265 | } |
266 | |
267 | LogicalResult CreateRangeOp::verify() { |
268 | Type elementType = getType().getElementType(); |
269 | for (Type operandType : getOperandTypes()) { |
270 | Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType); |
271 | if (operandElementType != elementType) { |
272 | return emitOpError("expected operand to have element type " ) |
273 | << elementType << ", but got " << operandElementType; |
274 | } |
275 | } |
276 | return success(); |
277 | } |
278 | |
279 | //===----------------------------------------------------------------------===// |
280 | // pdl_interp::SwitchAttributeOp |
281 | //===----------------------------------------------------------------------===// |
282 | |
283 | LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); } |
284 | |
285 | //===----------------------------------------------------------------------===// |
286 | // pdl_interp::SwitchOperandCountOp |
287 | //===----------------------------------------------------------------------===// |
288 | |
289 | LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); } |
290 | |
291 | //===----------------------------------------------------------------------===// |
292 | // pdl_interp::SwitchOperationNameOp |
293 | //===----------------------------------------------------------------------===// |
294 | |
295 | LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); } |
296 | |
297 | //===----------------------------------------------------------------------===// |
298 | // pdl_interp::SwitchResultCountOp |
299 | //===----------------------------------------------------------------------===// |
300 | |
301 | LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); } |
302 | |
303 | //===----------------------------------------------------------------------===// |
304 | // pdl_interp::SwitchTypeOp |
305 | //===----------------------------------------------------------------------===// |
306 | |
307 | LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); } |
308 | |
309 | //===----------------------------------------------------------------------===// |
310 | // pdl_interp::SwitchTypesOp |
311 | //===----------------------------------------------------------------------===// |
312 | |
313 | LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); } |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // TableGen Auto-Generated Op and Interface Definitions |
317 | //===----------------------------------------------------------------------===// |
318 | |
319 | #define GET_OP_CLASSES |
320 | #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc" |
321 | |