1//===- RaiseMemrefDialect.cpp - raise memref.store and load to affine ops -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements functionality to convert memref load and store ops to
10// the corresponding affine ops, inferring the affine map as needed.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/Analysis/Utils.h"
15#include "mlir/Dialect/Affine/Passes.h"
16#include "mlir/Dialect/Affine/Transforms/Transforms.h"
17#include "mlir/Dialect/Affine/Utils.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/IR/AffineExpr.h"
21#include "mlir/IR/Matchers.h"
22#include "mlir/IR/Operation.h"
23#include "mlir/Pass/Pass.h"
24#include "llvm/Support/Casting.h"
25#include "llvm/Support/Debug.h"
26
27namespace mlir {
28namespace affine {
29#define GEN_PASS_DEF_RAISEMEMREFDIALECT
30#include "mlir/Dialect/Affine/Passes.h.inc"
31} // namespace affine
32} // namespace mlir
33
34#define DEBUG_TYPE "raise-memref-to-affine"
35
36using namespace mlir;
37using namespace mlir::affine;
38
39namespace {
40
41/// Find the index of the given value in the `dims` list,
42/// and append it if it was not already in the list. The
43/// dims list is a list of symbols or dimensions of the
44/// affine map. Within the results of an affine map, they
45/// are identified by their index, which is why we need
46/// this function.
47static std::optional<size_t>
48findInListOrAdd(Value value, llvm::SmallVectorImpl<Value> &dims,
49 function_ref<bool(Value)> isValidElement) {
50
51 Value *loopIV = llvm::find(dims, value);
52 if (loopIV != dims.end()) {
53 // We found an IV that already has an index, return that index.
54 return {std::distance(dims.begin(), loopIV)};
55 }
56 if (isValidElement(value)) {
57 // This is a valid element for the dim/symbol list, push this as a
58 // parameter.
59 size_t idx = dims.size();
60 dims.push_back(Elt: value);
61 return idx;
62 }
63 return std::nullopt;
64}
65
66/// Convert a value to an affine expr if possible. Adds dims and symbols
67/// if needed.
68static AffineExpr toAffineExpr(Value value,
69 llvm::SmallVectorImpl<Value> &affineDims,
70 llvm::SmallVectorImpl<Value> &affineSymbols) {
71 using namespace matchers;
72 IntegerAttr::ValueType cst;
73 if (matchPattern(value, m_ConstantInt(&cst))) {
74 return getAffineConstantExpr(cst.getSExtValue(), value.getContext());
75 }
76
77 Operation *definingOp = value.getDefiningOp();
78 if (llvm::isa_and_nonnull<arith::AddIOp>(definingOp) ||
79 llvm::isa_and_nonnull<arith::MulIOp>(definingOp)) {
80 // TODO: replace recursion with explicit stack.
81 // For the moment this can be tolerated as we only recurse on
82 // arith.addi and arith.muli, so there cannot be any infinite
83 // recursion. The depth of these expressions should be in most
84 // cases very manageable, as affine expressions should be as
85 // simple as `a + b * c`.
86 AffineExpr lhsE =
87 toAffineExpr(value: definingOp->getOperand(idx: 0), affineDims, affineSymbols);
88 AffineExpr rhsE =
89 toAffineExpr(value: definingOp->getOperand(idx: 1), affineDims, affineSymbols);
90
91 if (lhsE && rhsE) {
92 AffineExprKind kind;
93 if (isa<arith::AddIOp>(definingOp)) {
94 kind = mlir::AffineExprKind::Add;
95 } else {
96 kind = mlir::AffineExprKind::Mul;
97
98 if (!lhsE.isSymbolicOrConstant() && !rhsE.isSymbolicOrConstant()) {
99 // This is not an affine expression, give up.
100 return {};
101 }
102 }
103 return getAffineBinaryOpExpr(kind, lhs: lhsE, rhs: rhsE);
104 }
105 return {};
106 }
107
108 if (auto dimIx = findInListOrAdd(value, affineSymbols, [](Value v) {
109 return affine::isValidSymbol(v);
110 })) {
111 return getAffineSymbolExpr(*dimIx, value.getContext());
112 }
113
114 if (auto dimIx = findInListOrAdd(
115 value, affineDims, [](Value v) { return affine::isValidDim(v); })) {
116
117 return getAffineDimExpr(*dimIx, value.getContext());
118 }
119
120 return {};
121}
122
123static LogicalResult
124computeAffineMapAndArgs(MLIRContext *ctx, ValueRange indices, AffineMap &map,
125 llvm::SmallVectorImpl<Value> &mapArgs) {
126 SmallVector<AffineExpr> results;
127 SmallVector<Value> symbols;
128 SmallVector<Value> dims;
129
130 for (Value indexExpr : indices) {
131 AffineExpr res = toAffineExpr(indexExpr, dims, symbols);
132 if (!res) {
133 return failure();
134 }
135 results.push_back(res);
136 }
137
138 map = AffineMap::get(dims.size(), symbols.size(), results, ctx);
139
140 dims.append(symbols);
141 mapArgs.swap(RHS&: dims);
142 return success();
143}
144
145struct RaiseMemrefDialect
146 : public affine::impl::RaiseMemrefDialectBase<RaiseMemrefDialect> {
147
148 void runOnOperation() override {
149 auto *ctx = &getContext();
150 Operation *op = getOperation();
151 IRRewriter rewriter(ctx);
152 AffineMap map;
153 SmallVector<Value> mapArgs;
154 op->walk([&](Operation *op) {
155 rewriter.setInsertionPoint(op);
156 if (auto store = llvm::dyn_cast_or_null<memref::StoreOp>(op)) {
157
158 if (succeeded(computeAffineMapAndArgs(ctx, store.getIndices(), map,
159 mapArgs))) {
160 rewriter.replaceOpWithNewOp<AffineStoreOp>(
161 op, store.getValueToStore(), store.getMemRef(), map, mapArgs);
162 return;
163 }
164
165 LLVM_DEBUG(llvm::dbgs()
166 << "[affine] Cannot raise memref op: " << op << "\n");
167
168 } else if (auto load = llvm::dyn_cast_or_null<memref::LoadOp>(op)) {
169 if (succeeded(computeAffineMapAndArgs(ctx, load.getIndices(), map,
170 mapArgs))) {
171 rewriter.replaceOpWithNewOp<AffineLoadOp>(op, load.getMemRef(), map,
172 mapArgs);
173 return;
174 }
175 LLVM_DEBUG(llvm::dbgs()
176 << "[affine] Cannot raise memref op: " << op << "\n");
177 }
178 });
179 }
180};
181
182} // namespace
183
184std::unique_ptr<OperationPass<func::FuncOp>>
185mlir::affine::createRaiseMemrefToAffine() {
186 return std::make_unique<RaiseMemrefDialect>();
187}
188

// Source: mlir/lib/Dialect/Affine/Transforms/RaiseMemrefDialect.cpp