1 | //===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // Normally transformational intrinsics are lowered to calls to runtime |
9 | // functions. However, some cases of the intrinsics are faster when inlined |
10 | // into the calling function. |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
14 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
15 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
16 | #include "flang/Optimizer/HLFIR/HLFIRDialect.h" |
17 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
18 | #include "flang/Optimizer/HLFIR/Passes.h" |
19 | #include "mlir/Dialect/Arith/IR/Arith.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
21 | #include "mlir/IR/BuiltinDialect.h" |
22 | #include "mlir/IR/Location.h" |
23 | #include "mlir/Pass/Pass.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | |
26 | namespace hlfir { |
27 | #define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS |
28 | #include "flang/Optimizer/HLFIR/Passes.h.inc" |
29 | } // namespace hlfir |
30 | |
31 | namespace { |
32 | |
33 | class TransposeAsElementalConversion |
34 | : public mlir::OpRewritePattern<hlfir::TransposeOp> { |
35 | public: |
36 | using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern; |
37 | |
38 | mlir::LogicalResult |
39 | matchAndRewrite(hlfir::TransposeOp transpose, |
40 | mlir::PatternRewriter &rewriter) const override { |
41 | mlir::Location loc = transpose.getLoc(); |
42 | fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; |
43 | hlfir::ExprType expr = transpose.getType(); |
44 | mlir::Type elementType = expr.getElementType(); |
45 | hlfir::Entity array = hlfir::Entity{transpose.getArray()}; |
46 | mlir::Value resultShape = genResultShape(loc, builder, array); |
47 | llvm::SmallVector<mlir::Value, 1> typeParams; |
48 | hlfir::genLengthParameters(loc, builder, array, typeParams); |
49 | |
50 | auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder, |
51 | mlir::ValueRange inputIndices) -> hlfir::Entity { |
52 | assert(inputIndices.size() == 2 && "checked in TransposeOp::validate" ); |
53 | const std::initializer_list<mlir::Value> initList = {inputIndices[1], |
54 | inputIndices[0]}; |
55 | mlir::ValueRange transposedIndices(initList); |
56 | hlfir::Entity element = |
57 | hlfir::getElementAt(loc, builder, array, transposedIndices); |
58 | hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element); |
59 | return val; |
60 | }; |
61 | hlfir::ElementalOp elementalOp = hlfir::genElementalOp( |
62 | loc, builder, elementType, resultShape, typeParams, genKernel, |
63 | /*isUnordered=*/true, /*polymorphicMold=*/nullptr, |
64 | transpose.getResult().getType()); |
65 | |
66 | // it wouldn't be safe to replace block arguments with a different |
67 | // hlfir.expr type. Types can differ due to differing amounts of shape |
68 | // information |
69 | assert(elementalOp.getResult().getType() == |
70 | transpose.getResult().getType()); |
71 | |
72 | rewriter.replaceOp(transpose, elementalOp); |
73 | return mlir::success(); |
74 | } |
75 | |
76 | private: |
77 | static mlir::Value genResultShape(mlir::Location loc, |
78 | fir::FirOpBuilder &builder, |
79 | hlfir::Entity array) { |
80 | mlir::Value inShape = hlfir::genShape(loc, builder, array); |
81 | llvm::SmallVector<mlir::Value> inExtents = |
82 | hlfir::getExplicitExtentsFromShape(inShape, builder); |
83 | if (inShape.getUses().empty()) |
84 | inShape.getDefiningOp()->erase(); |
85 | |
86 | // transpose indices |
87 | assert(inExtents.size() == 2 && "checked in TransposeOp::validate" ); |
88 | return builder.create<fir::ShapeOp>( |
89 | loc, mlir::ValueRange{inExtents[1], inExtents[0]}); |
90 | } |
91 | }; |
92 | |
93 | class SimplifyHLFIRIntrinsics |
94 | : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> { |
95 | public: |
96 | void runOnOperation() override { |
97 | mlir::func::FuncOp func = this->getOperation(); |
98 | mlir::MLIRContext *context = &getContext(); |
99 | mlir::RewritePatternSet patterns(context); |
100 | patterns.insert<TransposeAsElementalConversion>(context); |
101 | mlir::ConversionTarget target(*context); |
102 | // don't transform transpose of polymorphic arrays (not currently supported |
103 | // by hlfir.elemental) |
104 | target.addDynamicallyLegalOp<hlfir::TransposeOp>( |
105 | [](hlfir::TransposeOp transpose) { |
106 | return transpose.getType().cast<hlfir::ExprType>().isPolymorphic(); |
107 | }); |
108 | target.markUnknownOpDynamicallyLegal( |
109 | [](mlir::Operation *) { return true; }); |
110 | if (mlir::failed( |
111 | mlir::applyFullConversion(func, target, std::move(patterns)))) { |
112 | mlir::emitError(func->getLoc(), |
113 | "failure in HLFIR intrinsic simplification" ); |
114 | signalPassFailure(); |
115 | } |
116 | } |
117 | }; |
118 | } // namespace |
119 | |
120 | std::unique_ptr<mlir::Pass> hlfir::createSimplifyHLFIRIntrinsicsPass() { |
121 | return std::make_unique<SimplifyHLFIRIntrinsics>(); |
122 | } |
123 | |