1 | //===- InlineHLFIRAssign.cpp - Inline hlfir.assign ops --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
// Transform hlfir.assign array operations into loop nests performing
// element-by-element assignments. Currently the inlining is always done for
// trivial data types; performance/code-size heuristics may be added in the
// future.
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/Analysis/AliasAnalysis.h" |
14 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
15 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
16 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
17 | #include "flang/Optimizer/HLFIR/Passes.h" |
18 | #include "flang/Optimizer/OpenMP/Passes.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | #include "mlir/Pass/Pass.h" |
21 | #include "mlir/Support/LLVM.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | |
24 | namespace hlfir { |
25 | #define GEN_PASS_DEF_INLINEHLFIRASSIGN |
26 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
27 | } // namespace hlfir |
28 | |
29 | #define DEBUG_TYPE "inline-hlfir-assign" |
30 | |
31 | namespace { |
32 | /// Expand hlfir.assign of array RHS to array LHS into a loop nest |
33 | /// of element-by-element assignments: |
34 | /// hlfir.assign %4 to %5 : !fir.ref<!fir.array<3x3xf32>>, |
35 | /// !fir.ref<!fir.array<3x3xf32>> |
36 | /// into: |
37 | /// fir.do_loop %arg1 = %c1 to %c3 step %c1 unordered { |
38 | /// fir.do_loop %arg2 = %c1 to %c3 step %c1 unordered { |
39 | /// %6 = hlfir.designate %4 (%arg2, %arg1) : |
40 | /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> |
41 | /// %7 = fir.load %6 : !fir.ref<f32> |
42 | /// %8 = hlfir.designate %5 (%arg2, %arg1) : |
43 | /// (!fir.ref<!fir.array<3x3xf32>>, index, index) -> !fir.ref<f32> |
44 | /// hlfir.assign %7 to %8 : f32, !fir.ref<f32> |
45 | /// } |
46 | /// } |
47 | /// |
48 | /// The transformation is correct only when LHS and RHS do not alias. |
49 | /// When RHS is an array expression, then there is no aliasing. |
50 | /// This transformation does not support runtime checking for |
51 | /// non-conforming LHS/RHS arrays' shapes currently. |
52 | class InlineHLFIRAssignConversion |
53 | : public mlir::OpRewritePattern<hlfir::AssignOp> { |
54 | public: |
55 | using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern; |
56 | |
57 | llvm::LogicalResult |
58 | matchAndRewrite(hlfir::AssignOp assign, |
59 | mlir::PatternRewriter &rewriter) const override { |
60 | if (assign.isAllocatableAssignment()) |
61 | return rewriter.notifyMatchFailure(assign, |
62 | "AssignOp may imply allocation" ); |
63 | |
64 | hlfir::Entity rhs{assign.getRhs()}; |
65 | |
66 | if (!rhs.isArray()) |
67 | return rewriter.notifyMatchFailure(assign, |
68 | "AssignOp's RHS is not an array" ); |
69 | |
70 | mlir::Type rhsEleTy = rhs.getFortranElementType(); |
71 | if (!fir::isa_trivial(rhsEleTy)) |
72 | return rewriter.notifyMatchFailure( |
73 | assign, "AssignOp's RHS data type is not trivial" ); |
74 | |
75 | hlfir::Entity lhs{assign.getLhs()}; |
76 | if (!lhs.isArray()) |
77 | return rewriter.notifyMatchFailure(assign, |
78 | "AssignOp's LHS is not an array" ); |
79 | |
80 | mlir::Type lhsEleTy = lhs.getFortranElementType(); |
81 | if (!fir::isa_trivial(lhsEleTy)) |
82 | return rewriter.notifyMatchFailure( |
83 | assign, "AssignOp's LHS data type is not trivial" ); |
84 | |
85 | if (lhsEleTy != rhsEleTy) |
86 | return rewriter.notifyMatchFailure(assign, |
87 | "RHS/LHS element types mismatch" ); |
88 | |
89 | if (!mlir::isa<hlfir::ExprType>(rhs.getType())) { |
90 | // If RHS is not an hlfir.expr, then we should prove that |
91 | // LHS and RHS do not alias. |
92 | // TODO: if they may alias, we can insert hlfir.as_expr for RHS, |
93 | // and proceed with the inlining. |
94 | fir::AliasAnalysis aliasAnalysis; |
95 | mlir::AliasResult aliasRes = aliasAnalysis.alias(lhs, rhs); |
96 | // TODO: use areIdenticalOrDisjointSlices() from |
97 | // OptimizedBufferization.cpp to check if we can still do the expansion. |
98 | if (!aliasRes.isNo()) { |
99 | LLVM_DEBUG(llvm::dbgs() << "InlineHLFIRAssign:\n" |
100 | << "\tLHS: " << lhs << "\n" |
101 | << "\tRHS: " << rhs << "\n" |
102 | << "\tALIAS: " << aliasRes << "\n" ); |
103 | return rewriter.notifyMatchFailure(assign, "RHS/LHS may alias" ); |
104 | } |
105 | } |
106 | |
107 | mlir::Location loc = assign->getLoc(); |
108 | fir::FirOpBuilder builder(rewriter, assign.getOperation()); |
109 | builder.setInsertionPoint(assign); |
110 | rhs = hlfir::derefPointersAndAllocatables(loc, builder, rhs); |
111 | lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); |
112 | mlir::Value lhsShape = hlfir::genShape(loc, builder, lhs); |
113 | llvm::SmallVector<mlir::Value> lhsExtents = |
114 | hlfir::getIndexExtents(loc, builder, lhsShape); |
115 | mlir::Value rhsShape = hlfir::genShape(loc, builder, rhs); |
116 | llvm::SmallVector<mlir::Value> rhsExtents = |
117 | hlfir::getIndexExtents(loc, builder, rhsShape); |
118 | llvm::SmallVector<mlir::Value> extents = |
119 | fir::factory::deduceOptimalExtents(lhsExtents, rhsExtents); |
120 | hlfir::LoopNest loopNest = |
121 | hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, |
122 | flangomp::shouldUseWorkshareLowering(assign)); |
123 | builder.setInsertionPointToStart(loopNest.body); |
124 | auto rhsArrayElement = |
125 | hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices); |
126 | rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement); |
127 | auto lhsArrayElement = |
128 | hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); |
129 | builder.create<hlfir::AssignOp>(loc, rhsArrayElement, lhsArrayElement); |
130 | rewriter.eraseOp(assign); |
131 | return mlir::success(); |
132 | } |
133 | }; |
134 | |
135 | class InlineHLFIRAssignPass |
136 | : public hlfir::impl::InlineHLFIRAssignBase<InlineHLFIRAssignPass> { |
137 | public: |
138 | void runOnOperation() override { |
139 | mlir::MLIRContext *context = &getContext(); |
140 | |
141 | mlir::GreedyRewriteConfig config; |
142 | // Prevent the pattern driver from merging blocks. |
143 | config.setRegionSimplificationLevel( |
144 | mlir::GreedySimplifyRegionLevel::Disabled); |
145 | |
146 | mlir::RewritePatternSet patterns(context); |
147 | patterns.insert<InlineHLFIRAssignConversion>(context); |
148 | |
149 | if (mlir::failed(mlir::applyPatternsGreedily( |
150 | getOperation(), std::move(patterns), config))) { |
151 | mlir::emitError(getOperation()->getLoc(), |
152 | "failure in hlfir.assign inlining" ); |
153 | signalPassFailure(); |
154 | } |
155 | } |
156 | }; |
157 | } // namespace |
158 | |