1//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
10#include "mlir/Dialect/EmitC/IR/EmitC.h"
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/PatternMatch.h"
13#include "llvm/Support/Debug.h"
14
15namespace mlir {
16namespace emitc {
17
18ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
19 assert(op->hasTrait<OpTrait::emitc::CExpression>() &&
20 "Expected a C expression");
21
22 // Create an expression yielding the value returned by op.
23 assert(op->getNumResults() == 1 && "Expected exactly one result");
24 Value result = op->getResult(idx: 0);
25 Type resultType = result.getType();
26 Location loc = op->getLoc();
27
28 builder.setInsertionPointAfter(op);
29 auto expressionOp = builder.create<emitc::ExpressionOp>(loc, resultType);
30
31 // Replace all op's uses with the new expression's result.
32 result.replaceAllUsesWith(newValue: expressionOp.getResult());
33
34 // Create an op to yield op's value.
35 Region &region = expressionOp.getRegion();
36 Block &block = region.emplaceBlock();
37 builder.setInsertionPointToEnd(&block);
38 auto yieldOp = builder.create<emitc::YieldOp>(loc, result);
39
40 // Move op into the new expression.
41 op->moveBefore(yieldOp);
42
43 return expressionOp;
44}
45
46} // namespace emitc
47} // namespace mlir
48
49using namespace mlir;
50using namespace mlir::emitc;
51
52namespace {
53
54struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
55 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
56 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
57 PatternRewriter &rewriter) const override {
58 bool anythingFolded = false;
59 for (Operation &op : llvm::make_early_inc_range(
60 expressionOp.getBody()->without_terminator())) {
61 // Don't fold expressions whose result value has its address taken.
62 auto applyOp = dyn_cast<emitc::ApplyOp>(op);
63 if (applyOp && applyOp.getApplicableOperator() == "&")
64 continue;
65
66 for (Value operand : op.getOperands()) {
67 auto usedExpression =
68 dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
69
70 if (!usedExpression)
71 continue;
72
73 // Don't fold expressions with multiple users: assume any
74 // re-materialization was done separately.
75 if (!usedExpression.getResult().hasOneUse())
76 continue;
77
78 // Don't fold expressions with side effects.
79 if (usedExpression.hasSideEffects())
80 continue;
81
82 // Fold the used expression into this expression by cloning all
83 // instructions in the used expression just before the operation using
84 // its value.
85 rewriter.setInsertionPoint(&op);
86 IRMapping mapper;
87 for (Operation &opToClone :
88 usedExpression.getBody()->without_terminator()) {
89 Operation *clone = rewriter.clone(opToClone, mapper);
90 mapper.map(&opToClone, clone);
91 }
92
93 Operation *expressionRoot = usedExpression.getRootOp();
94 Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
95 assert(clonedExpressionRootOp &&
96 "Expected cloned expression root to be in mapper");
97 assert(clonedExpressionRootOp->getNumResults() == 1 &&
98 "Expected cloned root to have a single result");
99
100 rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
101 anythingFolded = true;
102 }
103 }
104 return anythingFolded ? success() : failure();
105 }
106};
107
108} // namespace
109
110void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
111 patterns.add<FoldExpressionOp>(arg: patterns.getContext());
112}
113

source code of mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp