//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/Transforms/Passes.h"
10#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
11
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/MemRef/IR/MemRef.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Transforms/DialectConversion.h"
16
17namespace mlir {
18namespace memref {
19#define GEN_PASS_DEF_EXPANDREALLOC
20#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
21} // namespace memref
22} // namespace mlir
23
24using namespace mlir;
25
26namespace {
27
28/// The `realloc` operation performs a conditional allocation and copy to
29/// increase the size of a buffer if necessary. This pattern converts the
30/// `realloc` operation into this sequence of simpler operations.
31
32/// Example of an expansion:
33/// ```mlir
34/// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
35/// ```
36/// is expanded to
37/// ```mlir
38/// %c0 = arith.constant 0 : index
39/// %dim = memref.dim %alloc, %c0 : memref<?xf32>
40/// %is_old_smaller = arith.cmpi ult, %dim, %arg1
41/// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
42/// %new_alloc = memref.alloc(%size) : memref<?xf32>
43/// %subview = memref.subview %new_alloc[0] [%dim] [1]
44/// memref.copy %alloc, %subview
45/// memref.dealloc %alloc
46/// scf.yield %alloc_0 : memref<?xf32>
47/// } else {
48/// %reinterpret_cast = memref.reinterpret_cast %alloc to
49/// offset: [0], sizes: [%size], strides: [1]
50/// scf.yield %reinterpret_cast : memref<?xf32>
51/// }
52/// ```
53struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
54 ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
55 : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
56
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
58 PatternRewriter &rewriter) const final {
59 Location loc = op.getLoc();
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
68
69 // Get the size of the original buffer.
70 int64_t inputSize =
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
74 Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
75 rewriter.getIndexAttr(0));
76 currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
77 .getResult();
78 }
79
80 // Get the requested size that the new buffer should have.
81 int64_t outputSize =
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
84 ? OpFoldResult{op.getDynamicResultSize()}
85 : rewriter.getIndexAttr(outputSize);
86
87 // Only allocate a new buffer and copy over the values in the old buffer if
88 // the old buffer is smaller than the requested size.
89 Value lhs = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: currSize);
90 Value rhs = getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: targetSize);
91 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
92 lhs, rhs);
93 auto ifOp = rewriter.create<scf::IfOp>(
94 loc, cond,
95 [&](OpBuilder &builder, Location loc) {
96 // Allocate the new buffer. If it is a dynamic memref we need to pass
97 // an additional operand for the size at runtime, otherwise the static
98 // size is encoded in the result type.
99 SmallVector<Value> dynamicSizeOperands;
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
102
103 Value newAlloc = builder.create<memref::AllocOp>(
104 loc, op.getResult().getType(), dynamicSizeOperands,
105 op.getAlignmentAttr());
106
107 // Take a subview of the new (bigger) buffer such that we can copy the
108 // old values over (the copy operation requires both operands to have
109 // the same shape).
110 Value subview = builder.create<memref::SubViewOp>(
111 loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
112 ArrayRef<OpFoldResult>{currSize},
113 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
114 builder.create<memref::CopyOp>(loc, op.getSource(), subview);
115
116 // Insert the deallocation of the old buffer only if requested
117 // (enabled by default).
118 if (emitDeallocs)
119 builder.create<memref::DeallocOp>(loc, op.getSource());
120
121 builder.create<scf::YieldOp>(loc, newAlloc);
122 },
123 [&](OpBuilder &builder, Location loc) {
124 // We need to reinterpret-cast here because either the input or output
125 // type might be static, which means we need to cast from static to
126 // dynamic or vice-versa. If both are static and the original buffer
127 // is already bigger than the requested size, the cast represents a
128 // subview operation.
129 Value casted = builder.create<memref::ReinterpretCastOp>(
130 loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
131 rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
132 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
133 builder.create<scf::YieldOp>(loc, casted);
134 });
135
136 rewriter.replaceOp(op, ifOp.getResult(0));
137 return success();
138 }
139
140private:
141 const bool emitDeallocs;
142};
143
144struct ExpandReallocPass
145 : public memref::impl::ExpandReallocBase<ExpandReallocPass> {
146 ExpandReallocPass(bool emitDeallocs)
147 : memref::impl::ExpandReallocBase<ExpandReallocPass>() {
148 this->emitDeallocs.setValue(emitDeallocs);
149 }
150 void runOnOperation() override {
151 MLIRContext &ctx = getContext();
152
153 RewritePatternSet patterns(&ctx);
154 memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue());
155 ConversionTarget target(ctx);
156
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
160 if (failed(applyPartialConversion(getOperation(), target,
161 std::move(patterns))))
162 signalPassFailure();
163 }
164};
165
166} // namespace
167
168void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
169 bool emitDeallocs) {
170 patterns.add<ExpandReallocOpPattern>(arg: patterns.getContext(), args&: emitDeallocs);
171}
172
173std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
174 return std::make_unique<ExpandReallocPass>(args&: emitDeallocs);
175}
176

source code of mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp