//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace memref {
bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(type, strides, offset)))
    return false;

  // MemRef is contiguous if outer dimensions are size-1 and inner
  // dimensions have unit strides.
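  // For example, memref<2x3x4xi8> with the identity layout (strides
  // [12, 4, 1]) is contiguous, and so is
  // memref<1x2x3xi8, strided<[99, 3, 1]>>: the stride of a size-1 outer
  // dimension never contributes to the address computation.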
  int64_t runningStride = 1;
  int64_t curDim = strides.size() - 1;
  // Finds all inner dimensions with unit strides.
  while (curDim >= 0 && strides[curDim] == runningStride) {
    runningStride *= type.getDimSize(curDim);
    --curDim;
  }

  // Check if other dimensions are size-1.
  while (curDim >= 0 && type.getDimSize(curDim) == 1) {
    --curDim;
  }

  // All dims are unit-strided or size-1.
  return curDim < 0;
}

std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
    OpBuilder &builder, Location loc, int srcBits, int dstBits,
    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
  unsigned sourceRank = sizes.size();
  assert(sizes.size() == strides.size() &&
         "expected as many sizes as strides for a memref");
  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
  if (indices.empty())
    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
  assert(indicesVec.size() == strides.size() &&
         "expected as many indices as rank of memref");

  // Create the affine symbols and values for linearization.
  SmallVector<AffineExpr> symbols(2 * sourceRank);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  AffineExpr addMulMap = builder.getAffineConstantExpr(0);
  AffineExpr mulMap = builder.getAffineConstantExpr(1);

  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
  SmallVector<OpFoldResult> sizeValues(sourceRank);

  for (unsigned i = 0; i < sourceRank; ++i) {
    unsigned offsetIdx = 2 * i;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indicesVec[i];
    offsetValues[offsetIdx + 1] = strides[i];

    mulMap = mulMap * symbols[i];
  }
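  // At this point addMulMap evaluates to the dot product of the indices and
  // strides (sum of index_i * stride_i), and mulMap evaluates to the product
  // of all sizes, i.e. the total number of source elements.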

  // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
  // srcBits).
  int64_t scaler = dstBits / srcBits;
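  // For example, with srcBits = 4 and dstBits = 8 the scale factor is 2: two
  // 4-bit elements pack into each 8-bit storage unit, so the linearized index
  // and size below are divided by 2.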
  addMulMap = addMulMap.floorDiv(scaler);
  mulMap = mulMap.floorDiv(scaler);

  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, addMulMap, offsetValues);
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {offset});

  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
}

LinearizedMemRefInfo
getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
                                 int dstBits, OpFoldResult offset,
                                 ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(sizes.size());
  if (!sizes.empty()) {
    strides.back() = builder.getIndexAttr(1);
    AffineExpr s0, s1;
    bindSymbols(builder.getContext(), s0, s1);
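    // Strides form the running suffix product of the sizes (row-major
    // layout): strides[i] = sizes[i + 1] * ... * sizes[rank - 1].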
    for (int index = sizes.size() - 1; index > 0; --index) {
      strides[index - 1] = affine::makeComposedFoldedAffineApply(
          builder, loc, s0 * s1,
          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
    }
  }

  LinearizedMemRefInfo linearizedMemRefInfo;
  std::tie(linearizedMemRefInfo, std::ignore) =
      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
                                       sizes, strides);
  return linearizedMemRefInfo;
}

/// Returns true if none of the uses of `op` are reads/loads.
/// There can be SubViewOp users as long as all their users are also
/// StoreOp/transfer_write. If it returns true, it also fills out `uses`; if it
/// returns false, `uses` is left unchanged.
static bool resultIsNotRead(Operation *op, std::vector<Operation *> &uses) {
  std::vector<Operation *> opUses;
  for (OpOperand &use : op->getUses()) {
    Operation *useOp = use.getOwner();
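    // A use is harmless if it is a dealloc, an op with no results or regions
    // that does not read from memory (e.g. a store or transfer_write), or a
    // subview whose own uses are, recursively, all harmless.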
    if (isa<memref::DeallocOp>(useOp) ||
        (useOp->getNumResults() == 0 && useOp->getNumRegions() == 0 &&
         !mlir::hasEffect<MemoryEffects::Read>(useOp)) ||
        (isa<memref::SubViewOp>(useOp) && resultIsNotRead(useOp, opUses))) {
      opUses.push_back(useOp);
      continue;
    }
    return false;
  }
  uses.insert(uses.end(), opUses.begin(), opUses.end());
  return true;
}

void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
  std::vector<Operation *> opToErase;
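  // Collect each dead allocation together with all of its (transitively
  // write-only) users so they can be erased in one pass below.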
  parentOp->walk([&](memref::AllocOp op) {
    std::vector<Operation *> candidates;
    if (resultIsNotRead(op, candidates)) {
      opToErase.insert(opToErase.end(), candidates.begin(), candidates.end());
      opToErase.push_back(op.getOperation());
    }
  });
  for (Operation *op : opToErase)
    rewriter.eraseOp(op);
}

} // namespace memref
} // namespace mlir