1//===-- AffineDemotion.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This transformation is a prototype that demotes affine dialect operations
// back to FIR loop operations after optimization.
// It is used after the AffinePromotion pass.
// It is not part of the production pipeline and would need more work before
// it could be used in production.
14// More information can be found in this presentation:
15// https://slides.com/rajanwalia/deck
16//
17//===----------------------------------------------------------------------===//
18
19#include "flang/Optimizer/Dialect/FIRDialect.h"
20#include "flang/Optimizer/Dialect/FIROps.h"
21#include "flang/Optimizer/Dialect/FIRType.h"
22#include "flang/Optimizer/Transforms/Passes.h"
23#include "mlir/Dialect/Affine/IR/AffineOps.h"
24#include "mlir/Dialect/Affine/Utils.h"
25#include "mlir/Dialect/Func/IR/FuncOps.h"
26#include "mlir/Dialect/MemRef/IR/MemRef.h"
27#include "mlir/Dialect/SCF/IR/SCF.h"
28#include "mlir/IR/BuiltinAttributes.h"
29#include "mlir/IR/IntegerSet.h"
30#include "mlir/IR/Visitors.h"
31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "llvm/ADT/DenseMap.h"
34#include "llvm/Support/CommandLine.h"
35#include "llvm/Support/Debug.h"
36
37namespace fir {
38#define GEN_PASS_DEF_AFFINEDIALECTDEMOTION
39#include "flang/Optimizer/Transforms/Passes.h.inc"
40} // namespace fir
41
42#define DEBUG_TYPE "flang-affine-demotion"
43
44using namespace fir;
45using namespace mlir;
46
47namespace {
48
49class AffineLoadConversion
50 : public OpConversionPattern<mlir::affine::AffineLoadOp> {
51public:
52 using OpConversionPattern<mlir::affine::AffineLoadOp>::OpConversionPattern;
53
54 LogicalResult
55 matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor,
56 ConversionPatternRewriter &rewriter) const override {
57 SmallVector<Value> indices(adaptor.getIndices());
58 auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(),
59 op.getAffineMap(), indices);
60 if (!maybeExpandedMap)
61 return failure();
62
63 auto coorOp = rewriter.create<fir::CoordinateOp>(
64 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
65 adaptor.getMemref(), *maybeExpandedMap);
66
67 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
68 return success();
69 }
70};
71
72class AffineStoreConversion
73 : public OpConversionPattern<mlir::affine::AffineStoreOp> {
74public:
75 using OpConversionPattern<mlir::affine::AffineStoreOp>::OpConversionPattern;
76
77 LogicalResult
78 matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const override {
80 SmallVector<Value> indices(op.getIndices());
81 auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(),
82 op.getAffineMap(), indices);
83 if (!maybeExpandedMap)
84 return failure();
85
86 auto coorOp = rewriter.create<fir::CoordinateOp>(
87 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
88 adaptor.getMemref(), *maybeExpandedMap);
89 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(),
90 coorOp.getResult());
91 return success();
92 }
93};
94
95class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
96public:
97 using OpRewritePattern::OpRewritePattern;
98 mlir::LogicalResult
99 matchAndRewrite(fir::ConvertOp op,
100 mlir::PatternRewriter &rewriter) const override {
101 if (op.getRes().getType().isa<mlir::MemRefType>()) {
102 // due to index calculation moving to affine maps we still need to
103 // add converts for sequence types this has a side effect of losing
104 // some information about arrays with known dimensions by creating:
105 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
106 // !fir.ref<!fir.array<?xi32>>
107 if (auto refTy = op.getValue().getType().dyn_cast<fir::ReferenceType>())
108 if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) {
109 fir::SequenceType::Shape flatShape = {
110 fir::SequenceType::getUnknownExtent()};
111 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
112 auto flatTy = fir::ReferenceType::get(flatArrTy);
113 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy,
114 op.getValue());
115 return success();
116 }
117 rewriter.startOpModification(op->getParentOp());
118 op.getResult().replaceAllUsesWith(op.getValue());
119 rewriter.finalizeOpModification(op->getParentOp());
120 rewriter.eraseOp(op);
121 }
122 return success();
123 }
124};
125
126mlir::Type convertMemRef(mlir::MemRefType type) {
127 return fir::SequenceType::get(
128 SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()),
129 type.getElementType());
130}
131
132class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
133public:
134 using OpRewritePattern::OpRewritePattern;
135 mlir::LogicalResult
136 matchAndRewrite(memref::AllocOp op,
137 mlir::PatternRewriter &rewriter) const override {
138 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
139 op.getMemref());
140 return success();
141 }
142};
143
144class AffineDialectDemotion
145 : public fir::impl::AffineDialectDemotionBase<AffineDialectDemotion> {
146public:
147 void runOnOperation() override {
148 auto *context = &getContext();
149 auto function = getOperation();
150 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
151 function.print(llvm::dbgs()););
152
153 mlir::RewritePatternSet patterns(context);
154 patterns.insert<ConvertConversion>(context);
155 patterns.insert<AffineLoadConversion>(context);
156 patterns.insert<AffineStoreConversion>(context);
157 patterns.insert<StdAllocConversion>(context);
158 mlir::ConversionTarget target(*context);
159 target.addIllegalOp<memref::AllocOp>();
160 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
161 if (op.getRes().getType().isa<mlir::MemRefType>())
162 return false;
163 return true;
164 });
165 target
166 .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
167 mlir::arith::ArithDialect, mlir::func::FuncDialect>();
168
169 if (mlir::failed(mlir::applyPartialConversion(function, target,
170 std::move(patterns)))) {
171 mlir::emitError(mlir::UnknownLoc::get(context),
172 "error in converting affine dialect\n");
173 signalPassFailure();
174 }
175 }
176};
177
178} // namespace
179
180std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
181 return std::make_unique<AffineDialectDemotion>();
182}
183

source code of flang/lib/Optimizer/Transforms/AffineDemotion.cpp