1//===- SimplifyHLFIRIntrinsics.cpp - Simplify HLFIR Intrinsics ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Normally transformational intrinsics are lowered to calls to runtime
9// functions. However, some cases of the intrinsics are faster when inlined
10// into the calling function.
11//===----------------------------------------------------------------------===//
12
13#include "flang/Optimizer/Builder/FIRBuilder.h"
14#include "flang/Optimizer/Builder/HLFIRTools.h"
15#include "flang/Optimizer/Dialect/FIRDialect.h"
16#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
17#include "flang/Optimizer/HLFIR/HLFIROps.h"
18#include "flang/Optimizer/HLFIR/Passes.h"
19#include "mlir/Dialect/Arith/IR/Arith.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/IR/BuiltinDialect.h"
22#include "mlir/IR/Location.h"
23#include "mlir/Pass/Pass.h"
24#include "mlir/Transforms/DialectConversion.h"
25
26namespace hlfir {
27#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
28#include "flang/Optimizer/HLFIR/Passes.h.inc"
29} // namespace hlfir
30
31namespace {
32
33class TransposeAsElementalConversion
34 : public mlir::OpRewritePattern<hlfir::TransposeOp> {
35public:
36 using mlir::OpRewritePattern<hlfir::TransposeOp>::OpRewritePattern;
37
38 mlir::LogicalResult
39 matchAndRewrite(hlfir::TransposeOp transpose,
40 mlir::PatternRewriter &rewriter) const override {
41 mlir::Location loc = transpose.getLoc();
42 fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
43 hlfir::ExprType expr = transpose.getType();
44 mlir::Type elementType = expr.getElementType();
45 hlfir::Entity array = hlfir::Entity{transpose.getArray()};
46 mlir::Value resultShape = genResultShape(loc, builder, array);
47 llvm::SmallVector<mlir::Value, 1> typeParams;
48 hlfir::genLengthParameters(loc, builder, array, typeParams);
49
50 auto genKernel = [&array](mlir::Location loc, fir::FirOpBuilder &builder,
51 mlir::ValueRange inputIndices) -> hlfir::Entity {
52 assert(inputIndices.size() == 2 && "checked in TransposeOp::validate");
53 const std::initializer_list<mlir::Value> initList = {inputIndices[1],
54 inputIndices[0]};
55 mlir::ValueRange transposedIndices(initList);
56 hlfir::Entity element =
57 hlfir::getElementAt(loc, builder, array, transposedIndices);
58 hlfir::Entity val = hlfir::loadTrivialScalar(loc, builder, element);
59 return val;
60 };
61 hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
62 loc, builder, elementType, resultShape, typeParams, genKernel,
63 /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
64 transpose.getResult().getType());
65
66 // it wouldn't be safe to replace block arguments with a different
67 // hlfir.expr type. Types can differ due to differing amounts of shape
68 // information
69 assert(elementalOp.getResult().getType() ==
70 transpose.getResult().getType());
71
72 rewriter.replaceOp(transpose, elementalOp);
73 return mlir::success();
74 }
75
76private:
77 static mlir::Value genResultShape(mlir::Location loc,
78 fir::FirOpBuilder &builder,
79 hlfir::Entity array) {
80 mlir::Value inShape = hlfir::genShape(loc, builder, array);
81 llvm::SmallVector<mlir::Value> inExtents =
82 hlfir::getExplicitExtentsFromShape(inShape, builder);
83 if (inShape.getUses().empty())
84 inShape.getDefiningOp()->erase();
85
86 // transpose indices
87 assert(inExtents.size() == 2 && "checked in TransposeOp::validate");
88 return builder.create<fir::ShapeOp>(
89 loc, mlir::ValueRange{inExtents[1], inExtents[0]});
90 }
91};
92
93class SimplifyHLFIRIntrinsics
94 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
95public:
96 void runOnOperation() override {
97 mlir::func::FuncOp func = this->getOperation();
98 mlir::MLIRContext *context = &getContext();
99 mlir::RewritePatternSet patterns(context);
100 patterns.insert<TransposeAsElementalConversion>(context);
101 mlir::ConversionTarget target(*context);
102 // don't transform transpose of polymorphic arrays (not currently supported
103 // by hlfir.elemental)
104 target.addDynamicallyLegalOp<hlfir::TransposeOp>(
105 [](hlfir::TransposeOp transpose) {
106 return transpose.getType().cast<hlfir::ExprType>().isPolymorphic();
107 });
108 target.markUnknownOpDynamicallyLegal(
109 [](mlir::Operation *) { return true; });
110 if (mlir::failed(
111 mlir::applyFullConversion(func, target, std::move(patterns)))) {
112 mlir::emitError(func->getLoc(),
113 "failure in HLFIR intrinsic simplification");
114 signalPassFailure();
115 }
116 }
117};
118} // namespace
119
120std::unique_ptr<mlir::Pass> hlfir::createSimplifyHLFIRIntrinsicsPass() {
121 return std::make_unique<SimplifyHLFIRIntrinsics>();
122}
123

source code of flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp