1//===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Chained elemental operations like a + b + c can inline the first elemental
9// at the hlfir.apply in the body of the second one (as described in
10// docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering
11// so that it happens after the HLFIR intrinsic simplification pass.
12//===----------------------------------------------------------------------===//
13
14#include "flang/Optimizer/Builder/FIRBuilder.h"
15#include "flang/Optimizer/Builder/HLFIRTools.h"
16#include "flang/Optimizer/Dialect/Support/FIRContext.h"
17#include "flang/Optimizer/HLFIR/HLFIROps.h"
18#include "flang/Optimizer/HLFIR/Passes.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Support/LLVM.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include <iterator>
28
29namespace hlfir {
30#define GEN_PASS_DEF_INLINEELEMENTALS
31#include "flang/Optimizer/HLFIR/Passes.h.inc"
32} // namespace hlfir
33
34/// If the elemental has only two uses and those two are an apply operation and
35/// a destory operation, return those two, otherwise return {}
36static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
37getTwoUses(hlfir::ElementalOp elemental) {
38 mlir::Operation::user_range users = elemental->getUsers();
39 // don't inline anything with more than one use (plus hfir.destroy)
40 if (std::distance(users.begin(), users.end()) != 2) {
41 return std::nullopt;
42 }
43
44 // If the ElementalOp must produce a temporary (e.g. for
45 // finalization purposes), then we cannot inline it.
46 if (hlfir::elementalOpMustProduceTemp(elemental))
47 return std::nullopt;
48
49 hlfir::ApplyOp apply;
50 hlfir::DestroyOp destroy;
51 for (mlir::Operation *user : users)
52 mlir::TypeSwitch<mlir::Operation *, void>(user)
53 .Case([&](hlfir::ApplyOp op) { apply = op; })
54 .Case([&](hlfir::DestroyOp op) { destroy = op; });
55
56 if (!apply || !destroy)
57 return std::nullopt;
58
59 // we can't inline if the return type of the yield doesn't match the return
60 // type of the apply
61 auto yield = mlir::dyn_cast_or_null<hlfir::YieldElementOp>(
62 elemental.getRegion().back().back());
63 assert(yield && "hlfir.elemental should always end with a yield");
64 if (apply.getResult().getType() != yield.getElementValue().getType())
65 return std::nullopt;
66
67 return std::pair{apply, destroy};
68}
69
70namespace {
71class InlineElementalConversion
72 : public mlir::OpRewritePattern<hlfir::ElementalOp> {
73public:
74 using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern;
75
76 mlir::LogicalResult
77 matchAndRewrite(hlfir::ElementalOp elemental,
78 mlir::PatternRewriter &rewriter) const override {
79 std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>> maybeTuple =
80 getTwoUses(elemental);
81 if (!maybeTuple)
82 return rewriter.notifyMatchFailure(
83 elemental, "hlfir.elemental does not have two uses");
84
85 if (elemental.isOrdered()) {
86 // We can only inline the ordered elemental into a loop-like
87 // construct that processes the indices in-order and does not
88 // have the side effects itself. Adhere to conservative behavior
89 // for the time being.
90 return rewriter.notifyMatchFailure(elemental,
91 "hlfir.elemental is ordered");
92 }
93 auto [apply, destroy] = *maybeTuple;
94
95 assert(elemental.getRegion().hasOneBlock() &&
96 "expect elemental region to have one block");
97
98 fir::FirOpBuilder builder{rewriter, elemental.getOperation()};
99 builder.setInsertionPointAfter(apply);
100 hlfir::YieldElementOp yield = hlfir::inlineElementalOp(
101 elemental.getLoc(), builder, elemental, apply.getIndices());
102
103 // remove the old elemental and all of the bookkeeping
104 rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue());
105 rewriter.eraseOp(yield);
106 rewriter.eraseOp(apply);
107 rewriter.eraseOp(destroy);
108 rewriter.eraseOp(elemental);
109
110 return mlir::success();
111 }
112};
113
114class InlineElementalsPass
115 : public hlfir::impl::InlineElementalsBase<InlineElementalsPass> {
116public:
117 void runOnOperation() override {
118 mlir::func::FuncOp func = getOperation();
119 mlir::MLIRContext *context = &getContext();
120
121 mlir::GreedyRewriteConfig config;
122 // Prevent the pattern driver from merging blocks.
123 config.enableRegionSimplification = false;
124
125 mlir::RewritePatternSet patterns(context);
126 patterns.insert<InlineElementalConversion>(context);
127
128 if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
129 func, std::move(patterns), config))) {
130 mlir::emitError(func->getLoc(), "failure in HLFIR elemental inlining");
131 signalPassFailure();
132 }
133 }
134};
135} // namespace
136
137std::unique_ptr<mlir::Pass> hlfir::createInlineElementalsPass() {
138 return std::make_unique<InlineElementalsPass>();
139}
140

source code of flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp