1//===- MLIRGen.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
10#include "mlir/AsmParser/AsmParser.h"
11#include "mlir/Dialect/PDL/IR/PDL.h"
12#include "mlir/Dialect/PDL/IR/PDLOps.h"
13#include "mlir/Dialect/PDL/IR/PDLTypes.h"
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/BuiltinOps.h"
16#include "mlir/IR/Verifier.h"
17#include "mlir/Tools/PDLL/AST/Context.h"
18#include "mlir/Tools/PDLL/AST/Nodes.h"
19#include "mlir/Tools/PDLL/AST/Types.h"
20#include "mlir/Tools/PDLL/ODS/Context.h"
21#include "mlir/Tools/PDLL/ODS/Operation.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include <optional>
26
27using namespace mlir;
28using namespace mlir::pdll;
29
30//===----------------------------------------------------------------------===//
31// CodeGen
32//===----------------------------------------------------------------------===//
33
34namespace {
35class CodeGen {
36public:
37 CodeGen(MLIRContext *mlirContext, const ast::Context &context,
38 const llvm::SourceMgr &sourceMgr)
39 : builder(mlirContext), odsContext(context.getODSContext()),
40 sourceMgr(sourceMgr) {
41 // Make sure that the PDL dialect is loaded.
42 mlirContext->loadDialect<pdl::PDLDialect>();
43 }
44
45 OwningOpRef<ModuleOp> generate(const ast::Module &module);
46
47private:
48 /// Generate an MLIR location from the given source location.
49 Location genLoc(llvm::SMLoc loc);
50 Location genLoc(llvm::SMRange loc) { return genLoc(loc: loc.Start); }
51
52 /// Generate an MLIR type from the given source type.
53 Type genType(ast::Type type);
54
55 /// Generate MLIR for the given AST node.
56 void gen(const ast::Node *node);
57
58 //===--------------------------------------------------------------------===//
59 // Statements
60 //===--------------------------------------------------------------------===//
61
62 void genImpl(const ast::CompoundStmt *stmt);
63 void genImpl(const ast::EraseStmt *stmt);
64 void genImpl(const ast::LetStmt *stmt);
65 void genImpl(const ast::ReplaceStmt *stmt);
66 void genImpl(const ast::RewriteStmt *stmt);
67 void genImpl(const ast::ReturnStmt *stmt);
68
69 //===--------------------------------------------------------------------===//
70 // Decls
71 //===--------------------------------------------------------------------===//
72
73 void genImpl(const ast::UserConstraintDecl *decl);
74 void genImpl(const ast::UserRewriteDecl *decl);
75 void genImpl(const ast::PatternDecl *decl);
76
77 /// Generate the set of MLIR values defined for the given variable decl, and
78 /// apply any attached constraints.
79 SmallVector<Value> genVar(const ast::VariableDecl *varDecl);
80
81 /// Generate the value for a variable that does not have an initializer
82 /// expression, i.e. create the PDL value based on the type/constraints of the
83 /// variable.
84 Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);
85
86 /// Apply the constraints of the given variable to `values`, which correspond
87 /// to the MLIR values of the variable.
88 void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);
89
90 //===--------------------------------------------------------------------===//
91 // Expressions
92 //===--------------------------------------------------------------------===//
93
94 Value genSingleExpr(const ast::Expr *expr);
95 SmallVector<Value> genExpr(const ast::Expr *expr);
96 Value genExprImpl(const ast::AttributeExpr *expr);
97 SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
98 SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
99 Value genExprImpl(const ast::MemberAccessExpr *expr);
100 Value genExprImpl(const ast::OperationExpr *expr);
101 Value genExprImpl(const ast::RangeExpr *expr);
102 SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
103 Value genExprImpl(const ast::TypeExpr *expr);
104
105 SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
106 Location loc, ValueRange inputs,
107 bool isNegated = false);
108 SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
109 Location loc, ValueRange inputs);
110 template <typename PDLOpT, typename T>
111 SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
112 ValueRange inputs,
113 bool isNegated = false);
114
115 //===--------------------------------------------------------------------===//
116 // Fields
117 //===--------------------------------------------------------------------===//
118
119 /// The MLIR builder used for building the resultant IR.
120 OpBuilder builder;
121
122 /// A map from variable declarations to the MLIR equivalent.
123 using VariableMapTy =
124 llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
125 VariableMapTy variables;
126
127 /// A reference to the ODS context.
128 const ods::Context &odsContext;
129
130 /// The source manager of the PDLL ast.
131 const llvm::SourceMgr &sourceMgr;
132};
133} // namespace
134
135OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
136 OwningOpRef<ModuleOp> mlirModule =
137 builder.create<ModuleOp>(genLoc(loc: module.getLoc()));
138 builder.setInsertionPointToStart(mlirModule->getBody());
139
140 // Generate code for each of the decls within the module.
141 for (const ast::Decl *decl : module.getChildren())
142 gen(node: decl);
143
144 return mlirModule;
145}
146
147Location CodeGen::genLoc(llvm::SMLoc loc) {
148 unsigned fileID = sourceMgr.FindBufferContainingLoc(Loc: loc);
149
150 // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
151 // use it here.
152 auto &bufferInfo = sourceMgr.getBufferInfo(i: fileID);
153 unsigned lineNo = bufferInfo.getLineNumber(Ptr: loc.getPointer());
154 unsigned column =
155 (loc.getPointer() - bufferInfo.getPointerForLineNumber(LineNo: lineNo)) + 1;
156 auto *buffer = sourceMgr.getMemoryBuffer(i: fileID);
157
158 return FileLineColLoc::get(builder.getContext(),
159 buffer->getBufferIdentifier(), lineNo, column);
160}
161
162Type CodeGen::genType(ast::Type type) {
163 return TypeSwitch<ast::Type, Type>(type)
164 .Case(caseFn: [&](ast::AttributeType astType) -> Type {
165 return builder.getType<pdl::AttributeType>();
166 })
167 .Case(caseFn: [&](ast::OperationType astType) -> Type {
168 return builder.getType<pdl::OperationType>();
169 })
170 .Case(caseFn: [&](ast::TypeType astType) -> Type {
171 return builder.getType<pdl::TypeType>();
172 })
173 .Case(caseFn: [&](ast::ValueType astType) -> Type {
174 return builder.getType<pdl::ValueType>();
175 })
176 .Case(caseFn: [&](ast::RangeType astType) -> Type {
177 return pdl::RangeType::get(genType(astType.getElementType()));
178 });
179}
180
181void CodeGen::gen(const ast::Node *node) {
182 TypeSwitch<const ast::Node *>(node)
183 .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
184 const ast::ReplaceStmt, const ast::RewriteStmt,
185 const ast::ReturnStmt, const ast::UserConstraintDecl,
186 const ast::UserRewriteDecl, const ast::PatternDecl>(
187 caseFn: [&](auto derivedNode) { this->genImpl(derivedNode); })
188 .Case(caseFn: [&](const ast::Expr *expr) { genExpr(expr); });
189}
190
191//===----------------------------------------------------------------------===//
192// CodeGen: Statements
193//===----------------------------------------------------------------------===//
194
195void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
196 VariableMapTy::ScopeTy varScope(variables);
197 for (const ast::Stmt *childStmt : stmt->getChildren())
198 gen(node: childStmt);
199}
200
201/// If the given builder is nested under a PDL PatternOp, build a rewrite
202/// operation and update the builder to nest under it. This is necessary for
203/// PDLL operation rewrite statements that are directly nested within a Pattern.
204static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
205 Location loc) {
206 if (isa<pdl::PatternOp>(builder.getInsertionBlock()->getParentOp())) {
207 pdl::RewriteOp rewrite =
208 builder.create<pdl::RewriteOp>(loc, rootExpr, /*name=*/StringAttr(),
209 /*externalArgs=*/ValueRange());
210 builder.createBlock(&rewrite.getBodyRegion());
211 }
212}
213
214void CodeGen::genImpl(const ast::EraseStmt *stmt) {
215 OpBuilder::InsertionGuard insertGuard(builder);
216 Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr());
217 Location loc = genLoc(loc: stmt->getLoc());
218
219 // Make sure we are nested in a RewriteOp.
220 OpBuilder::InsertionGuard guard(builder);
221 checkAndNestUnderRewriteOp(builder, rootExpr, loc);
222 builder.create<pdl::EraseOp>(loc, rootExpr);
223}
224
225void CodeGen::genImpl(const ast::LetStmt *stmt) { genVar(varDecl: stmt->getVarDecl()); }
226
227void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
228 OpBuilder::InsertionGuard insertGuard(builder);
229 Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr());
230 Location loc = genLoc(loc: stmt->getLoc());
231
232 // Make sure we are nested in a RewriteOp.
233 OpBuilder::InsertionGuard guard(builder);
234 checkAndNestUnderRewriteOp(builder, rootExpr, loc);
235
236 SmallVector<Value> replValues;
237 for (ast::Expr *replExpr : stmt->getReplExprs())
238 replValues.push_back(Elt: genSingleExpr(expr: replExpr));
239
240 // Check to see if the statement has a replacement operation, or a range of
241 // replacement values.
242 bool usesReplOperation =
243 replValues.size() == 1 &&
244 isa<pdl::OperationType>(replValues.front().getType());
245 builder.create<pdl::ReplaceOp>(
246 loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
247 usesReplOperation ? ValueRange() : ValueRange(replValues));
248}
249
250void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
251 OpBuilder::InsertionGuard insertGuard(builder);
252 Value rootExpr = genSingleExpr(expr: stmt->getRootOpExpr());
253
254 // Make sure we are nested in a RewriteOp.
255 OpBuilder::InsertionGuard guard(builder);
256 checkAndNestUnderRewriteOp(builder, rootExpr, loc: genLoc(loc: stmt->getLoc()));
257 gen(node: stmt->getRewriteBody());
258}
259
260void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
261 // ReturnStmt generation is handled by the respective constraint or rewrite
262 // parent node.
263}
264
265//===----------------------------------------------------------------------===//
266// CodeGen: Decls
267//===----------------------------------------------------------------------===//
268
269void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
270 // All PDLL constraints get inlined when called, and the main native
271 // constraint declarations doesn't require any MLIR to be generated, only uses
272 // of it do.
273}
274
275void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
276 // All PDLL rewrites get inlined when called, and the main native
277 // rewrite declarations doesn't require any MLIR to be generated, only uses
278 // of it do.
279}
280
281void CodeGen::genImpl(const ast::PatternDecl *decl) {
282 const ast::Name *name = decl->getName();
283
284 // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
285 // here.
286 pdl::PatternOp pattern = builder.create<pdl::PatternOp>(
287 genLoc(decl->getLoc()), decl->getBenefit(),
288 name ? std::optional<StringRef>(name->getName())
289 : std::optional<StringRef>());
290
291 OpBuilder::InsertionGuard savedInsertPoint(builder);
292 builder.setInsertionPointToStart(pattern.getBody());
293 gen(node: decl->getBody());
294}
295
296SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
297 auto it = variables.begin(Key: varDecl);
298 if (it != variables.end())
299 return *it;
300
301 // If the variable has an initial value, use that as the base value.
302 // Otherwise, generate a value using the constraint list.
303 SmallVector<Value> values;
304 if (const ast::Expr *initExpr = varDecl->getInitExpr())
305 values = genExpr(expr: initExpr);
306 else
307 values.push_back(Elt: genNonInitializerVar(varDecl, loc: genLoc(loc: varDecl->getLoc())));
308
309 // Apply the constraints of the values of the variable.
310 applyVarConstraints(varDecl, values);
311
312 variables.insert(Key: varDecl, Val: values);
313 return values;
314}
315
316Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
317 Location loc) {
318 // A functor used to generate expressions nested
319 auto getTypeConstraint = [&]() -> Value {
320 for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
321 Value typeValue =
322 TypeSwitch<const ast::Node *, Value>(constraint.constraint)
323 .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
324 ast::ValueRangeConstraintDecl>(
325 caseFn: [&, this](auto *cst) -> Value {
326 if (auto *typeConstraintExpr = cst->getTypeExpr())
327 return this->genSingleExpr(expr: typeConstraintExpr);
328 return Value();
329 })
330 .Default(defaultResult: Value());
331 if (typeValue)
332 return typeValue;
333 }
334 return Value();
335 };
336
337 // Generate a value based on the type of the variable.
338 ast::Type type = varDecl->getType();
339 Type mlirType = genType(type);
340 if (type.isa<ast::ValueType>())
341 return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
342 if (type.isa<ast::TypeType>())
343 return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
344 if (type.isa<ast::AttributeType>())
345 return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
346 if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
347 Value operands = builder.create<pdl::OperandsOp>(
348 loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
349 /*type=*/Value());
350 Value results = builder.create<pdl::TypesOp>(
351 loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
352 /*types=*/ArrayAttr());
353 return builder.create<pdl::OperationOp>(
354 loc, opType.getName(), operands, std::nullopt, ValueRange(), results);
355 }
356
357 if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
358 ast::Type eleTy = rangeTy.getElementType();
359 if (eleTy.isa<ast::ValueType>())
360 return builder.create<pdl::OperandsOp>(loc, mlirType,
361 getTypeConstraint());
362 if (eleTy.isa<ast::TypeType>())
363 return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
364 }
365
366 llvm_unreachable("invalid non-initialized variable type");
367}
368
369void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
370 ValueRange values) {
371 // Generate calls to any user constraints that were attached via the
372 // constraint list.
373 for (const ast::ConstraintRef &ref : varDecl->getConstraints())
374 if (const auto *userCst = dyn_cast<ast::UserConstraintDecl>(Val: ref.constraint))
375 genConstraintCall(decl: userCst, loc: genLoc(loc: ref.referenceLoc), inputs: values);
376}
377
378//===----------------------------------------------------------------------===//
379// CodeGen: Expressions
380//===----------------------------------------------------------------------===//
381
382Value CodeGen::genSingleExpr(const ast::Expr *expr) {
383 return TypeSwitch<const ast::Expr *, Value>(expr)
384 .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
385 const ast::OperationExpr, const ast::RangeExpr,
386 const ast::TypeExpr>(
387 caseFn: [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
388 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
389 caseFn: [&](auto derivedNode) {
390 SmallVector<Value> results = this->genExprImpl(derivedNode);
391 assert(results.size() == 1 && "expected single expression result");
392 return results[0];
393 });
394}
395
396SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
397 return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr)
398 .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
399 caseFn: [&](auto derivedNode) { return this->genExprImpl(derivedNode); })
400 .Default(defaultFn: [&](const ast::Expr *expr) -> SmallVector<Value> {
401 return {genSingleExpr(expr)};
402 });
403}
404
405Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
406 Attribute attr = parseAttribute(attrStr: expr->getValue(), context: builder.getContext());
407 assert(attr && "invalid MLIR attribute data");
408 return builder.create<pdl::AttributeOp>(genLoc(expr->getLoc()), attr);
409}
410
411SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
412 Location loc = genLoc(loc: expr->getLoc());
413 SmallVector<Value> arguments;
414 for (const ast::Expr *arg : expr->getArguments())
415 arguments.push_back(Elt: genSingleExpr(expr: arg));
416
417 // Resolve the callable expression of this call.
418 auto *callableExpr = dyn_cast<ast::DeclRefExpr>(Val: expr->getCallableExpr());
419 assert(callableExpr && "unhandled CallExpr callable");
420
421 // Generate the PDL based on the type of callable.
422 const ast::Decl *callable = callableExpr->getDecl();
423 if (const auto *decl = dyn_cast<ast::UserConstraintDecl>(Val: callable))
424 return genConstraintCall(decl, loc, inputs: arguments, isNegated: expr->getIsNegated());
425 if (const auto *decl = dyn_cast<ast::UserRewriteDecl>(Val: callable))
426 return genRewriteCall(decl, loc, inputs: arguments);
427 llvm_unreachable("unhandled CallExpr callable");
428}
429
430SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
431 if (const auto *varDecl = dyn_cast<ast::VariableDecl>(Val: expr->getDecl()))
432 return genVar(varDecl);
433 llvm_unreachable("unknown decl reference expression");
434}
435
436Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
437 Location loc = genLoc(loc: expr->getLoc());
438 StringRef name = expr->getMemberName();
439 SmallVector<Value> parentExprs = genExpr(expr: expr->getParentExpr());
440 ast::Type parentType = expr->getParentExpr()->getType();
441
442 // Handle operation based member access.
443 if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
444 if (isa<ast::AllResultsMemberAccessExpr>(Val: expr)) {
445 Type mlirType = genType(type: expr->getType());
446 if (isa<pdl::ValueType>(mlirType))
447 return builder.create<pdl::ResultOp>(loc, mlirType, parentExprs[0],
448 builder.getI32IntegerAttr(0));
449 return builder.create<pdl::ResultsOp>(loc, mlirType, parentExprs[0]);
450 }
451
452 const ods::Operation *odsOp = opType.getODSOperation();
453 if (!odsOp) {
454 assert(llvm::isDigit(name[0]) &&
455 "unregistered op only allows numeric indexing");
456 unsigned resultIndex;
457 name.getAsInteger(/*Radix=*/10, Result&: resultIndex);
458 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
459 return builder.create<pdl::ResultOp>(loc, genType(expr->getType()),
460 parentExprs[0], index);
461 }
462
463 // Find the result with the member name or by index.
464 ArrayRef<ods::OperandOrResult> results = odsOp->getResults();
465 unsigned resultIndex = results.size();
466 if (llvm::isDigit(C: name[0])) {
467 name.getAsInteger(/*Radix=*/10, Result&: resultIndex);
468 } else {
469 auto findFn = [&](const ods::OperandOrResult &result) {
470 return result.getName() == name;
471 };
472 resultIndex = llvm::find_if(Range&: results, P: findFn) - results.begin();
473 }
474 assert(resultIndex < results.size() && "invalid result index");
475
476 // Generate the result access.
477 IntegerAttr index = builder.getI32IntegerAttr(resultIndex);
478 return builder.create<pdl::ResultsOp>(loc, genType(expr->getType()),
479 parentExprs[0], index);
480 }
481
482 // Handle tuple based member access.
483 if (auto tupleType = parentType.dyn_cast<ast::TupleType>()) {
484 auto elementNames = tupleType.getElementNames();
485
486 // The index is either a numeric index, or a name.
487 unsigned index = 0;
488 if (llvm::isDigit(C: name[0]))
489 name.getAsInteger(/*Radix=*/10, Result&: index);
490 else
491 index = llvm::find(Range&: elementNames, Val: name) - elementNames.begin();
492
493 assert(index < parentExprs.size() && "invalid result index");
494 return parentExprs[index];
495 }
496
497 llvm_unreachable("unhandled member access expression");
498}
499
500Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
501 Location loc = genLoc(loc: expr->getLoc());
502 std::optional<StringRef> opName = expr->getName();
503
504 // Operands.
505 SmallVector<Value> operands;
506 for (const ast::Expr *operand : expr->getOperands())
507 operands.push_back(Elt: genSingleExpr(expr: operand));
508
509 // Attributes.
510 SmallVector<StringRef> attrNames;
511 SmallVector<Value> attrValues;
512 for (const ast::NamedAttributeDecl *attr : expr->getAttributes()) {
513 attrNames.push_back(Elt: attr->getName().getName());
514 attrValues.push_back(Elt: genSingleExpr(expr: attr->getValue()));
515 }
516
517 // Results.
518 SmallVector<Value> results;
519 for (const ast::Expr *result : expr->getResultTypes())
520 results.push_back(Elt: genSingleExpr(expr: result));
521
522 return builder.create<pdl::OperationOp>(loc, opName, operands, attrNames,
523 attrValues, results);
524}
525
526Value CodeGen::genExprImpl(const ast::RangeExpr *expr) {
527 SmallVector<Value> elements;
528 for (const ast::Expr *element : expr->getElements())
529 llvm::append_range(C&: elements, R: genExpr(expr: element));
530
531 return builder.create<pdl::RangeOp>(genLoc(expr->getLoc()),
532 genType(expr->getType()), elements);
533}
534
535SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
536 SmallVector<Value> elements;
537 for (const ast::Expr *element : expr->getElements())
538 elements.push_back(Elt: genSingleExpr(expr: element));
539 return elements;
540}
541
542Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
543 Type type = parseType(typeStr: expr->getValue(), context: builder.getContext());
544 assert(type && "invalid MLIR type data");
545 return builder.create<pdl::TypeOp>(genLoc(expr->getLoc()),
546 builder.getType<pdl::TypeType>(),
547 TypeAttr::get(type));
548}
549
550SmallVector<Value>
551CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
552 ValueRange inputs, bool isNegated) {
553 // Apply any constraints defined on the arguments to the input values.
554 for (auto it : llvm::zip(t: decl->getInputs(), u&: inputs))
555 applyVarConstraints(varDecl: std::get<0>(t&: it), values: std::get<1>(t&: it));
556
557 // Generate the constraint call.
558 SmallVector<Value> results =
559 genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(
560 decl, loc, inputs, isNegated);
561
562 // Apply any constraints defined on the results of the constraint.
563 for (auto it : llvm::zip(t: decl->getResults(), u&: results))
564 applyVarConstraints(varDecl: std::get<0>(t&: it), values: std::get<1>(t&: it));
565 return results;
566}
567
568SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
569 Location loc, ValueRange inputs) {
570 return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
571 inputs);
572}
573
574template <typename PDLOpT, typename T>
575SmallVector<Value>
576CodeGen::genConstraintOrRewriteCall(const T *decl, Location loc,
577 ValueRange inputs, bool isNegated) {
578 const ast::CompoundStmt *cstBody = decl->getBody();
579
580 // If the decl doesn't have a statement body, it is a native decl.
581 if (!cstBody) {
582 ast::Type declResultType = decl->getResultType();
583 SmallVector<Type> resultTypes;
584 if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
585 for (ast::Type type : tupleType.getElementTypes())
586 resultTypes.push_back(Elt: genType(type));
587 } else {
588 resultTypes.push_back(Elt: genType(type: declResultType));
589 }
590 PDLOpT pdlOp = builder.create<PDLOpT>(
591 loc, resultTypes, decl->getName().getName(), inputs);
592 if (isNegated && std::is_same_v<PDLOpT, pdl::ApplyNativeConstraintOp>)
593 cast<pdl::ApplyNativeConstraintOp>(pdlOp).setIsNegated(true);
594 return pdlOp->getResults();
595 }
596
597 // Otherwise, this is a PDLL decl.
598 VariableMapTy::ScopeTy varScope(variables);
599
600 // Map the inputs of the call to the decl arguments.
601 // Note: This is only valid because we do not support recursion, meaning
602 // we don't need to worry about conflicting mappings here.
603 for (auto it : llvm::zip(inputs, decl->getInputs()))
604 variables.insert(Key: std::get<1>(it), Val: {std::get<0>(it)});
605
606 // Visit the body of the call as normal.
607 gen(node: cstBody);
608
609 // If the decl has no results, there is nothing to do.
610 if (cstBody->getChildren().empty())
611 return SmallVector<Value>();
612 auto *returnStmt = dyn_cast<ast::ReturnStmt>(Val: cstBody->getChildren().back());
613 if (!returnStmt)
614 return SmallVector<Value>();
615
616 // Otherwise, grab the results from the return statement.
617 return genExpr(expr: returnStmt->getResultExpr());
618}
619
620//===----------------------------------------------------------------------===//
621// MLIRGen
622//===----------------------------------------------------------------------===//
623
624OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR(
625 MLIRContext *mlirContext, const ast::Context &context,
626 const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
627 CodeGen codegen(mlirContext, context, sourceMgr);
628 OwningOpRef<ModuleOp> mlirModule = codegen.generate(module);
629 if (failed(verify(*mlirModule)))
630 return nullptr;
631 return mlirModule;
632}
633

source code of mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp