//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace memref {

bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(type.getStridesAndOffset(strides, offset)))
    return false;

  // The memref is contiguous if the outer dimensions are size-1 and the inner
  // dimensions have strides that match a dense row-major layout.
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  // Find all inner dimensions whose stride equals the running product of the
  // dimension sizes to their right.
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }

  // Check that the remaining outer dimensions are size-1.
  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
    --curDim;
  }

  // All dims are contiguously strided or size-1.
  return curDim < 0;
}
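
// Illustrative example: for an identity-layout `memref<2x3x4xf32>` the
// strides are [12, 4, 1], the scan above consumes every dimension, and the
// function returns true. For `memref<2x3xf32, strided<[4, 1]>>` the outer
// stride 4 does not equal 3 * 1, so it returns false. A size-1 outer
// dimension may carry any stride, e.g. `memref<1x4xf32, strided<[100, 1]>>`
// is still contiguous.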

std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
  if (indices.empty())
    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");

  // Create the affine symbols and values for linearization.
  SmallVector<AffineExpr> symbols(2 * sourceRank);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  AffineExpr addMulMap = builder.getAffineConstantExpr(0);

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];
  }
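
  // At this point `addMulMap` is the standard row-major linearization
  // expression sum_i(indices[i] * strides[i]); `offsetValues` interleaves each
  // dimension's index and stride to match the bound symbol pairs.
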
  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
  int64_t scaler = dstBits / srcBits;
  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap.floorDiv(scaler), offsetValues);

  size_t symbolIndex = 0;
  SmallVector<OpFoldResult> values;
  SmallVector<AffineExpr> productExpressions;
  for (unsigned i = 0; i < sourceRank; ++i) {
    AffineExpr strideExpr = symbols[symbolIndex++];
    values.push_back(strides[i]);
    AffineExpr sizeExpr = symbols[symbolIndex++];
    values.push_back(sizes[i]);

    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
  }
  AffineMap maxMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
      builder.getContext());
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
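
  // The linearized size is the maximum over all dimensions of
  // size[i] * stride[i], scaled down to dstBits units: the dimension with the
  // largest stride, times its extent, bounds the extent of the view.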

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  OpFoldResult intraVectorOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap % scaler, offsetValues);

  return {{adjustBaseOffset, linearizedSize, intraVectorOffset},
          linearizedIndices};
}
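
// Illustrative example: emulating i4 on an i8 buffer (srcBits = 4,
// dstBits = 8, scaler = 2) for a memref with sizes [2, 8], strides [8, 1],
// and indices [1, 3] yields linearizedIndices = (1 * 8 + 3 * 1) / 2 = 5,
// intraDataOffset = (1 * 8 + 3 * 1) % 2 = 1, and
// linearizedSize = max(8 * 2, 1 * 8) / 2 = 8 i8 elements.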

LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(sizes.size());
  if (!sizes.empty()) {
    strides.back() = builder.getIndexAttr(1);
    AffineExpr s0, s1;
    bindSymbols(builder.getContext(), s0, s1);
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1,
          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
    }
  }

  LinearizedMemRefInfo linearizedMemRefInfo;
  std::tie(linearizedMemRefInfo, std::ignore) =
      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
                                       sizes, strides);
  return linearizedMemRefInfo;
}
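
// Illustrative example: for sizes [2, 8] with srcBits = 4 and dstBits = 8,
// this overload computes the contiguous row-major strides [8, 1] and returns
// a linearized size of (2 * 8) / 2 = 8 i8 elements.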

/// Returns true if none of the uses of `op` read from its result.
/// SubViewOp users are allowed as long as all of their users are also
/// store-like (StoreOp/transfer_write). If this returns true it also fills
/// out `uses`; if it returns false, `uses` is left unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
  std::vector<Operation *> opUses;
  for (OpOperand &use : op->getUses()) {
    Operation *useOp = use.getOwner();
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
        (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
      opUses.push_back(useOp);
      continue;
    }
    return false;
  }
  llvm::append_range(uses, opUses);
  return true;
}
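
// For instance, an allocation whose only users are vector.transfer_write ops
// (no results, no regions, no read effect) qualifies, while a single
// memref.load or vector.transfer_read user makes this return false.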

void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
  std::vector<Operation *> opToErase;
  parentOp->walk([&](memref::AllocOp op) {
    std::vector<Operation *> candidates;
    if (resultIsNotRead(op, candidates)) {
      llvm::append_range(opToErase, candidates);
      opToErase.push_back(op.getOperation());
    }
  });
  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
}
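
// Illustrative example: in the following IR the allocation is only written to
// and then deallocated, so all three ops are erased:
//
//   %alloc = memref.alloc() : memref<16xf32>
//   memref.store %v, %alloc[%i] : memref<16xf32>
//   memref.dealloc %alloc : memref<16xf32>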

static SmallVector<OpFoldResult>
computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
                                ArrayRef<OpFoldResult> sizes,
                                OpFoldResult unit) {
  SmallVector<OpFoldResult> strides(sizes.size(), unit);
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);

  for (int64_t r = strides.size() - 1; r > 0; --r) {
    strides[r - 1] = affine::makeComposedFoldedAffineApply(
        builder, loc, s0 * s1, {strides[r], sizes[r]});
  }
  return strides;
}

SmallVector<OpFoldResult>
computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
                            ArrayRef<OpFoldResult> sizes) {
  OpFoldResult unit = builder.getIndexAttr(1);
  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
}
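
// Illustrative example: for sizes [2, 3, 4] this emits affine.apply ops (or
// folds to index constants) producing the row-major strides [12, 4, 1].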

MemrefValue skipFullyAliasingOperations(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
        subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
      // A `memref.subview` with an all-zero offset and all unit strides still
      // points to the same memory.
      source = cast<MemrefValue>(subViewOp.getSource());
    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
      // A `memref.cast` still points to the same memory.
      source = castOp.getSource();
    } else {
      return source;
    }
  }
  return source;
}

MemrefValue skipViewLikeOps(MemrefValue source) {
  while (auto op = source.getDefiningOp()) {
    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
      source = cast<MemrefValue>(viewLike.getViewSource());
      continue;
    }
    return source;
  }
  return source;
}
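
// Illustrative example: given
//   %view = memref.subview %src[2, 0] [4, 4] [1, 1] ...
// skipViewLikeOps(%view) walks back to %src, while
// skipFullyAliasingOperations(%view) stops at %view because the nonzero
// offset means the subview may not alias all of %src.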

} // namespace memref
} // namespace mlir
