1//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the MemRef dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
14#include "mlir/Dialect/Affine/IR/AffineOps.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/Vector/IR/VectorOps.h"
18
19namespace mlir {
20namespace memref {
21
22bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
23 if (!type.hasStaticShape())
24 return false;
25
26 SmallVector<int64_t> strides;
27 int64_t offset;
28 if (failed(getStridesAndOffset(type, strides, offset)))
29 return false;
30
31 // MemRef is contiguous if outer dimensions are size-1 and inner
32 // dimensions have unit strides.
33 int64_t runningStride = 1;
34 int64_t curDim = strides.size() - 1;
35 // Finds all inner dimensions with unit strides.
36 while (curDim >= 0 && strides[curDim] == runningStride) {
37 runningStride *= type.getDimSize(curDim);
38 --curDim;
39 }
40
41 // Check if other dimensions are size-1.
42 while (curDim >= 0 && type.getDimSize(curDim) == 1) {
43 --curDim;
44 }
45
46 // All dims are unit-strided or size-1.
47 return curDim < 0;
48}
49
50std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
51 OpBuilder &builder, Location loc, int srcBits, int dstBits,
52 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
53 ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
54 unsigned sourceRank = sizes.size();
55 assert(sizes.size() == strides.size() &&
56 "expected as many sizes as strides for a memref");
57 SmallVector<OpFoldResult> indicesVec = llvm::to_vector(Range&: indices);
58 if (indices.empty())
59 indicesVec.resize(sourceRank, builder.getIndexAttr(0));
60 assert(indicesVec.size() == strides.size() &&
61 "expected as many indices as rank of memref");
62
63 // Create the affine symbols and values for linearization.
64 SmallVector<AffineExpr> symbols(2 * sourceRank);
65 bindSymbolsList(ctx: builder.getContext(), exprs: MutableArrayRef{symbols});
66 AffineExpr addMulMap = builder.getAffineConstantExpr(constant: 0);
67 AffineExpr mulMap = builder.getAffineConstantExpr(constant: 1);
68
69 SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
70 SmallVector<OpFoldResult> sizeValues(sourceRank);
71
72 for (unsigned i = 0; i < sourceRank; ++i) {
73 unsigned offsetIdx = 2 * i;
74 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
75 offsetValues[offsetIdx] = indicesVec[i];
76 offsetValues[offsetIdx + 1] = strides[i];
77
78 mulMap = mulMap * symbols[i];
79 }
80
81 // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
82 // srcBits).
83 int64_t scaler = dstBits / srcBits;
84 addMulMap = addMulMap.floorDiv(v: scaler);
85 mulMap = mulMap.floorDiv(v: scaler);
86
87 OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
88 b&: builder, loc, expr: addMulMap, operands: offsetValues);
89 OpFoldResult linearizedSize =
90 affine::makeComposedFoldedAffineApply(b&: builder, loc, expr: mulMap, operands: sizes);
91
92 // Adjust baseOffset by the scale factor (dstBits / srcBits).
93 AffineExpr s0;
94 bindSymbols(ctx: builder.getContext(), exprs&: s0);
95 OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
96 b&: builder, loc, expr: s0.floorDiv(v: scaler), operands: {offset});
97
98 return {{.linearizedOffset: adjustBaseOffset, .linearizedSize: linearizedSize}, linearizedIndices};
99}
100
101LinearizedMemRefInfo
102getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
103 int dstBits, OpFoldResult offset,
104 ArrayRef<OpFoldResult> sizes) {
105 SmallVector<OpFoldResult> strides(sizes.size());
106 if (!sizes.empty()) {
107 strides.back() = builder.getIndexAttr(1);
108 AffineExpr s0, s1;
109 bindSymbols(ctx: builder.getContext(), exprs&: s0, exprs&: s1);
110 for (int index = sizes.size() - 1; index > 0; --index) {
111 strides[index - 1] = affine::makeComposedFoldedAffineApply(
112 b&: builder, loc, expr: s0 * s1,
113 operands: ArrayRef<OpFoldResult>{strides[index], sizes[index]});
114 }
115 }
116
117 LinearizedMemRefInfo linearizedMemRefInfo;
118 std::tie(args&: linearizedMemRefInfo, args: std::ignore) =
119 getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
120 sizes, strides);
121 return linearizedMemRefInfo;
122}
123
124/// Returns true if all the uses of op are not read/load.
125/// There can be SubviewOp users as long as all its users are also
126/// StoreOp/transfer_write. If return true it also fills out the uses, if it
127/// returns false uses is unchanged.
128static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
129 std::vector<Operation *> opUses;
130 for (OpOperand &use : op->getUses()) {
131 Operation *useOp = use.getOwner();
132 if (isa<memref::DeallocOp>(useOp) ||
133 (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
134 !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
135 (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
136 opUses.push_back(x: useOp);
137 continue;
138 }
139 return false;
140 }
141 uses.insert(position: uses.end(), first: opUses.begin(), last: opUses.end());
142 return true;
143}
144
145void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
146 std::vector<Operation *> opToErase;
147 parentOp->walk(callback: [&](memref::AllocOp op) {
148 std::vector<Operation *> candidates;
149 if (resultIsNotRead(op, candidates)) {
150 opToErase.insert(position: opToErase.end(), first: candidates.begin(), last: candidates.end());
151 opToErase.push_back(op.getOperation());
152 }
153 });
154 for (Operation *op : opToErase)
155 rewriter.eraseOp(op);
156}
157
158} // namespace memref
159} // namespace mlir
160

source code of mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp