//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
10#include "mlir/Dialect/EmitC/IR/EmitC.h"
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/PatternMatch.h"
13
14namespace mlir {
15namespace emitc {
16
17ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
18 assert(isa<emitc::CExpressionInterface>(op) && "Expected a C expression");
19
20 // Create an expression yielding the value returned by op.
21 assert(op->getNumResults() == 1 && "Expected exactly one result");
22 Value result = op->getResult(idx: 0);
23 Type resultType = result.getType();
24 Location loc = op->getLoc();
25
26 builder.setInsertionPointAfter(op);
27 auto expressionOp = builder.create<emitc::ExpressionOp>(location: loc, args&: resultType);
28
29 // Replace all op's uses with the new expression's result.
30 result.replaceAllUsesWith(newValue: expressionOp.getResult());
31
32 // Create an op to yield op's value.
33 Region &region = expressionOp.getRegion();
34 Block &block = region.emplaceBlock();
35 builder.setInsertionPointToEnd(&block);
36 auto yieldOp = builder.create<emitc::YieldOp>(location: loc, args&: result);
37
38 // Move op into the new expression.
39 op->moveBefore(existingOp: yieldOp);
40
41 return expressionOp;
42}
43
44} // namespace emitc
45} // namespace mlir
46
47using namespace mlir;
48using namespace mlir::emitc;
49
50namespace {
51
52struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
53 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
54 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
55 PatternRewriter &rewriter) const override {
56 bool anythingFolded = false;
57 for (Operation &op : llvm::make_early_inc_range(
58 Range: expressionOp.getBody()->without_terminator())) {
59 // Don't fold expressions whose result value has its address taken.
60 auto applyOp = dyn_cast<emitc::ApplyOp>(Val&: op);
61 if (applyOp && applyOp.getApplicableOperator() == "&")
62 continue;
63
64 for (Value operand : op.getOperands()) {
65 auto usedExpression =
66 dyn_cast_if_present<ExpressionOp>(Val: operand.getDefiningOp());
67
68 if (!usedExpression)
69 continue;
70
71 // Don't fold expressions with multiple users: assume any
72 // re-materialization was done separately.
73 if (!usedExpression.getResult().hasOneUse())
74 continue;
75
76 // Don't fold expressions with side effects.
77 if (usedExpression.hasSideEffects())
78 continue;
79
80 // Fold the used expression into this expression by cloning all
81 // instructions in the used expression just before the operation using
82 // its value.
83 rewriter.setInsertionPoint(&op);
84 IRMapping mapper;
85 for (Operation &opToClone :
86 usedExpression.getBody()->without_terminator()) {
87 Operation *clone = rewriter.clone(op&: opToClone, mapper);
88 mapper.map(from: &opToClone, to: clone);
89 }
90
91 Operation *expressionRoot = usedExpression.getRootOp();
92 Operation *clonedExpressionRootOp = mapper.lookup(from: expressionRoot);
93 assert(clonedExpressionRootOp &&
94 "Expected cloned expression root to be in mapper");
95 assert(clonedExpressionRootOp->getNumResults() == 1 &&
96 "Expected cloned root to have a single result");
97
98 rewriter.replaceOp(op: usedExpression, newOp: clonedExpressionRootOp);
99 anythingFolded = true;
100 }
101 }
102 return anythingFolded ? success() : failure();
103 }
104};
105
106} // namespace
107
108void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
109 patterns.add<FoldExpressionOp>(arg: patterns.getContext());
110}
111

source code of mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp