//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace memref {

22bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
23 if (!type.hasStaticShape())
24 return false;
25
26 SmallVector<int64_t> strides;
27 int64_t offset;
28 if (failed(Result: type.getStridesAndOffset(strides, offset)))
29 return false;
30
31 // MemRef is contiguous if outer dimensions are size-1 and inner
32 // dimensions have unit strides.
33 int64_t runningStride = 1;
34 int64_t curDim = strides.size() - 1;
35 // Finds all inner dimensions with unit strides.
36 while (curDim >= 0 && strides[curDim] == runningStride) {
37 runningStride *= type.getDimSize(idx: curDim);
38 --curDim;
39 }
40
41 // Check if other dimensions are size-1.
42 while (curDim >= 0 && type.getDimSize(idx: curDim) == 1) {
43 --curDim;
44 }
45
46 // All dims are unit-strided or size-1.
47 return curDim < 0;
48}
49
50std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
51 OpBuilder &builder, Location loc, int srcBits, int dstBits,
52 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
53 ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
54 unsigned sourceRank = sizes.size();
55 assert(sizes.size() == strides.size() &&
56 "expected as many sizes as strides for a memref");
57 SmallVector<OpFoldResult> indicesVec = llvm::to_vector(Range&: indices);
58 if (indices.empty())
59 indicesVec.resize(N: sourceRank, NV: builder.getIndexAttr(value: 0));
60 assert(indicesVec.size() == strides.size() &&
61 "expected as many indices as rank of memref");
62
63 // Create the affine symbols and values for linearization.
64 SmallVector<AffineExpr> symbols(2 * sourceRank);
65 bindSymbolsList(ctx: builder.getContext(), exprs: MutableArrayRef{symbols});
66 AffineExpr addMulMap = builder.getAffineConstantExpr(constant: 0);
67
68 SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
69
70 for (unsigned i = 0; i < sourceRank; ++i) {
71 unsigned offsetIdx = 2 * i;
72 addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
73 offsetValues[offsetIdx] = indicesVec[i];
74 offsetValues[offsetIdx + 1] = strides[i];
75 }
76 // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
77 int64_t scaler = dstBits / srcBits;
78 OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
79 b&: builder, loc, expr: addMulMap.floorDiv(v: scaler), operands: offsetValues);
80
81 size_t symbolIndex = 0;
82 SmallVector<OpFoldResult> values;
83 SmallVector<AffineExpr> productExpressions;
84 for (unsigned i = 0; i < sourceRank; ++i) {
85 AffineExpr strideExpr = symbols[symbolIndex++];
86 values.push_back(Elt: strides[i]);
87 AffineExpr sizeExpr = symbols[symbolIndex++];
88 values.push_back(Elt: sizes[i]);
89
90 productExpressions.push_back(Elt: (strideExpr * sizeExpr).floorDiv(v: scaler));
91 }
92 AffineMap maxMap = AffineMap::get(
93 /*dimCount=*/0, /*symbolCount=*/symbolIndex, results: productExpressions,
94 context: builder.getContext());
95 OpFoldResult linearizedSize =
96 affine::makeComposedFoldedAffineMax(b&: builder, loc, map: maxMap, operands: values);
97
98 // Adjust baseOffset by the scale factor (dstBits / srcBits).
99 AffineExpr s0;
100 bindSymbols(ctx: builder.getContext(), exprs&: s0);
101 OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
102 b&: builder, loc, expr: s0.floorDiv(v: scaler), operands: {offset});
103
104 OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
105 b&: builder, loc, expr: addMulMap % scaler, operands: offsetValues);
106
107 return {{.linearizedOffset: adjustBaseOffset, .linearizedSize: linearizedSize, .intraDataOffset: intraVectorOffset},
108 linearizedIndices};
109}
110
111LinearizedMemRefInfo
112getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
113 int dstBits, OpFoldResult offset,
114 ArrayRef<OpFoldResult> sizes) {
115 SmallVector<OpFoldResult> strides(sizes.size());
116 if (!sizes.empty()) {
117 strides.back() = builder.getIndexAttr(value: 1);
118 AffineExpr s0, s1;
119 bindSymbols(ctx: builder.getContext(), exprs&: s0, exprs&: s1);
120 for (int index = sizes.size() - 1; index > 0; --index) {
121 strides[index - 1] = affine::makeComposedFoldedAffineApply(
122 b&: builder, loc, expr: s0 * s1,
123 operands: ArrayRef<OpFoldResult>{strides[index], sizes[index]});
124 }
125 }
126
127 LinearizedMemRefInfo linearizedMemRefInfo;
128 std::tie(args&: linearizedMemRefInfo, args: std::ignore) =
129 getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
130 sizes, strides);
131 return linearizedMemRefInfo;
132}
133
134/// Returns true if all the uses of op are not read/load.
135/// There can be SubviewOp users as long as all its users are also
136/// StoreOp/transfer_write. If return true it also fills out the uses, if it
137/// returns false uses is unchanged.
138static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
139 std::vector<Operation *> opUses;
140 for (OpOperand &use : op->getUses()) {
141 Operation *useOp = use.getOwner();
142 if (isa<memref::DeallocOp>(Val: useOp) ||
143 (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
144 !mlir::hasEffect<MemoryEffects::Read>(op: useOp)) ||
145 (isa<memref::SubViewOp>(Val: useOp) && resultIsNotRead(op: useOp, uses&: opUses))) {
146 opUses.push_back(x: useOp);
147 continue;
148 }
149 return false;
150 }
151 llvm::append_range(C&: uses, R&: opUses);
152 return true;
153}
154
155void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
156 std::vector<Operation *> opToErase;
157 parentOp->walk(callback: [&](Operation *op) {
158 std::vector<Operation *> candidates;
159 if (isa<memref::AllocOp, memref::AllocaOp>(Val: op) &&
160 resultIsNotRead(op, uses&: candidates)) {
161 llvm::append_range(C&: opToErase, R&: candidates);
162 opToErase.push_back(x: op);
163 }
164 });
165
166 for (Operation *op : opToErase)
167 rewriter.eraseOp(op);
168}
169
170static SmallVector<OpFoldResult>
171computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
172 ArrayRef<OpFoldResult> sizes,
173 OpFoldResult unit) {
174 SmallVector<OpFoldResult> strides(sizes.size(), unit);
175 AffineExpr s0, s1;
176 bindSymbols(ctx: builder.getContext(), exprs&: s0, exprs&: s1);
177
178 for (int64_t r = strides.size() - 1; r > 0; --r) {
179 strides[r - 1] = affine::makeComposedFoldedAffineApply(
180 b&: builder, loc, expr: s0 * s1, operands: {strides[r], sizes[r]});
181 }
182 return strides;
183}
184
185SmallVector<OpFoldResult>
186computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
187 ArrayRef<OpFoldResult> sizes) {
188 OpFoldResult unit = builder.getIndexAttr(value: 1);
189 return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
190}
191
192MemrefValue skipFullyAliasingOperations(MemrefValue source) {
193 while (auto op = source.getDefiningOp()) {
194 if (auto subViewOp = dyn_cast<memref::SubViewOp>(Val: op);
195 subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
196 // A `memref.subview` with an all zero offset, and all unit strides, still
197 // points to the same memory.
198 source = cast<MemrefValue>(Val: subViewOp.getSource());
199 } else if (auto castOp = dyn_cast<memref::CastOp>(Val: op)) {
200 // A `memref.cast` still points to the same memory.
201 source = castOp.getSource();
202 } else {
203 return source;
204 }
205 }
206 return source;
207}
208
209MemrefValue skipViewLikeOps(MemrefValue source) {
210 while (auto op = source.getDefiningOp()) {
211 if (auto viewLike = dyn_cast<ViewLikeOpInterface>(Val: op)) {
212 source = cast<MemrefValue>(Val: viewLike.getViewSource());
213 continue;
214 }
215 return source;
216 }
217 return source;
218}
219
} // namespace memref
} // namespace mlir