1//===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Bufferization dialect to MemRef
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/SCF/IR/SCF.h"
22#include "mlir/IR/BuiltinTypes.h"
23#include "mlir/Transforms/DialectConversion.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREFPASS
27#include "mlir/Conversion/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31
32namespace {
33/// The CloneOpConversion transforms all bufferization clone operations into
34/// memref alloc and memref copy operations. In the dynamic-shape case, it also
35/// emits additional dim and constant operations to determine the shape. This
36/// conversion does not resolve memory leaks if it is used alone.
37struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
38 using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
39
40 LogicalResult
41 matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter) const override {
43 Location loc = op->getLoc();
44
45 Type type = op.getType();
46 Value alloc;
47
48 if (auto unrankedType = dyn_cast<UnrankedMemRefType>(Val&: type)) {
49 // Constants
50 Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0);
51 Value one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
52
53 // Dynamically evaluate the size and shape of the unranked memref
54 Value rank = rewriter.create<memref::RankOp>(location: loc, args: op.getInput());
55 MemRefType allocType =
56 MemRefType::get(shape: {ShapedType::kDynamic}, elementType: rewriter.getIndexType());
57 Value shape = rewriter.create<memref::AllocaOp>(location: loc, args&: allocType, args&: rank);
58
59 // Create a loop to query dimension sizes, store them as a shape, and
60 // compute the total size of the memref
61 auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
62 ValueRange args) {
63 auto acc = args.front();
64 auto dim = rewriter.create<memref::DimOp>(location: loc, args: op.getInput(), args&: i);
65
66 rewriter.create<memref::StoreOp>(location: loc, args&: dim, args&: shape, args&: i);
67 acc = rewriter.create<arith::MulIOp>(location: loc, args&: acc, args&: dim);
68
69 rewriter.create<scf::YieldOp>(location: loc, args&: acc);
70 };
71 auto size = rewriter
72 .create<scf::ForOp>(location: loc, args&: zero, args&: rank, args&: one, args: ValueRange(one),
73 args&: loopBody)
74 .getResult(i: 0);
75
76 MemRefType memrefType = MemRefType::get(shape: {ShapedType::kDynamic},
77 elementType: unrankedType.getElementType());
78
79 // Allocate new memref with 1D dynamic shape, then reshape into the
80 // shape of the original unranked memref
81 alloc = rewriter.create<memref::AllocOp>(location: loc, args&: memrefType, args&: size);
82 alloc =
83 rewriter.create<memref::ReshapeOp>(location: loc, args&: unrankedType, args&: alloc, args&: shape);
84 } else {
85 MemRefType memrefType = cast<MemRefType>(Val&: type);
86 MemRefLayoutAttrInterface layout;
87 auto allocType =
88 MemRefType::get(shape: memrefType.getShape(), elementType: memrefType.getElementType(),
89 layout, memorySpace: memrefType.getMemorySpace());
90 // Since this implementation always allocates, certain result types of
91 // the clone op cannot be lowered.
92 if (!memref::CastOp::areCastCompatible(inputs: {allocType}, outputs: {memrefType}))
93 return failure();
94
95 // Transform a clone operation into alloc + copy operation and pay
96 // attention to the shape dimensions.
97 SmallVector<Value, 4> dynamicOperands;
98 for (int i = 0; i < memrefType.getRank(); ++i) {
99 if (!memrefType.isDynamicDim(idx: i))
100 continue;
101 Value dim = rewriter.createOrFold<memref::DimOp>(location: loc, args: op.getInput(), args&: i);
102 dynamicOperands.push_back(Elt: dim);
103 }
104
105 // Allocate a memref with identity layout.
106 alloc = rewriter.create<memref::AllocOp>(location: loc, args&: allocType, args&: dynamicOperands);
107 // Cast the allocation to the specified type if needed.
108 if (memrefType != allocType)
109 alloc =
110 rewriter.create<memref::CastOp>(location: op->getLoc(), args&: memrefType, args&: alloc);
111 }
112
113 rewriter.create<memref::CopyOp>(location: loc, args: op.getInput(), args&: alloc);
114 rewriter.replaceOp(op, newValues: alloc);
115 return success();
116 }
117};
118
119} // namespace
120
121namespace {
122struct BufferizationToMemRefPass
123 : public impl::ConvertBufferizationToMemRefPassBase<
124 BufferizationToMemRefPass> {
125 BufferizationToMemRefPass() = default;
126
127 void runOnOperation() override {
128 if (!isa<ModuleOp, FunctionOpInterface>(Val: getOperation())) {
129 emitError(loc: getOperation()->getLoc(),
130 message: "root operation must be a builtin.module or a function");
131 signalPassFailure();
132 return;
133 }
134
135 bufferization::DeallocHelperMap deallocHelperFuncMap;
136 if (auto module = dyn_cast<ModuleOp>(Val: getOperation())) {
137 OpBuilder builder = OpBuilder::atBlockBegin(block: module.getBody());
138
139 // Build dealloc helper function if there are deallocs.
140 getOperation()->walk(callback: [&](bufferization::DeallocOp deallocOp) {
141 Operation *symtableOp =
142 deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
143 if (deallocOp.getMemrefs().size() > 1 &&
144 !deallocHelperFuncMap.contains(Val: symtableOp)) {
145 SymbolTable symbolTable(symtableOp);
146 func::FuncOp helperFuncOp =
147 bufferization::buildDeallocationLibraryFunction(
148 builder, loc: getOperation()->getLoc(), symbolTable);
149 deallocHelperFuncMap[symtableOp] = helperFuncOp;
150 }
151 });
152 }
153
154 RewritePatternSet patterns(&getContext());
155 patterns.add<CloneOpConversion>(arg: patterns.getContext());
156 bufferization::populateBufferizationDeallocLoweringPattern(
157 patterns, deallocHelperFuncMap);
158
159 ConversionTarget target(getContext());
160 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161 scf::SCFDialect, func::FuncDialect>();
162 target.addIllegalDialect<bufferization::BufferizationDialect>();
163
164 if (failed(Result: applyPartialConversion(op: getOperation(), target,
165 patterns: std::move(patterns))))
166 signalPassFailure();
167 }
168};
169} // namespace
170

source code of mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp