1//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
10
11#include "mlir/Analysis/DataLayoutAnalysis.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Passes.h"
17#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Transform/IR/TransformDialect.h"
22#include "mlir/Dialect/Transform/IR/TransformTypes.h"
23#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
24#include "mlir/Dialect/Vector/IR/VectorOps.h"
25#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
26#include "mlir/Interfaces/LoopLikeInterface.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/Support/Debug.h"
29
30using namespace mlir;
31
32#define DEBUG_TYPE "memref-transforms"
33#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
34
35//===----------------------------------------------------------------------===//
36// Apply...ConversionPatternsOp
37//===----------------------------------------------------------------------===//
38
39std::unique_ptr<TypeConverter>
40transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
41 LowerToLLVMOptions options(getContext());
42 options.allocLowering =
43 (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
44 : LowerToLLVMOptions::AllocLowering::Malloc);
45 options.useGenericFunctions = getUseGenericFunctions();
46
47 if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
48 options.overrideIndexBitwidth(getIndexBitwidth());
49
50 // TODO: the following two options don't really make sense for
51 // memref_to_llvm_type_converter specifically but we should have a single
52 // to_llvm_type_converter.
53 if (getDataLayout().has_value())
54 options.dataLayout = llvm::DataLayout(getDataLayout().value());
55 options.useBarePtrCallConv = getUseBarePtrCallConv();
56
57 return std::make_unique<LLVMTypeConverter>(getContext(), options);
58}
59
60StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61 return "LLVMTypeConverter";
62}
63
64//===----------------------------------------------------------------------===//
65// Apply...PatternsOp
66//===----------------------------------------------------------------------===//
67
68namespace {
69class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
70public:
71 explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
72 : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
73 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74
75 LogicalResult matchAndRewrite(memref::AllocOp op,
76 PatternRewriter &rewriter) const override {
77 return success(memref::allocToAlloca(
78 rewriter, alloc: op, filter: [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79 MemRefType type = alloc.getMemref().getType();
80 if (!type.hasStaticShape())
81 return false;
82
83 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(operation: alloc);
84 int64_t elementSize = dataLayout.getTypeSize(t: type.getElementType());
85 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
86 }));
87 }
88
89private:
90 DataLayoutAnalysis dataLayoutAnalysis;
91 int64_t maxSize;
92};
93} // namespace
94
95void transform::ApplyAllocToAllocaOp::populatePatterns(
96 RewritePatternSet &patterns) {}
97
98void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99 RewritePatternSet &patterns, transform::TransformState &state) {
100 patterns.insert<AllocToAllocaPattern>(
101 state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
102}
103
104void transform::ApplyExpandOpsPatternsOp::populatePatterns(
105 RewritePatternSet &patterns) {
106 memref::populateExpandOpsPatterns(patterns);
107}
108
109void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
110 RewritePatternSet &patterns) {
111 memref::populateExpandStridedMetadataPatterns(patterns);
112}
113
114void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
115 RewritePatternSet &patterns) {
116 memref::populateExtractAddressComputationsPatterns(patterns);
117}
118
119void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
120 RewritePatternSet &patterns) {
121 memref::populateFoldMemRefAliasOpPatterns(patterns);
122}
123
124void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
125 populatePatterns(RewritePatternSet &patterns) {
126 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
127}
128
129//===----------------------------------------------------------------------===//
130// AllocaToGlobalOp
131//===----------------------------------------------------------------------===//
132
133DiagnosedSilenceableFailure
134transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
135 transform::TransformResults &results,
136 transform::TransformState &state) {
137 auto allocaOps = state.getPayloadOps(getAlloca());
138
139 SmallVector<memref::GlobalOp> globalOps;
140 SmallVector<memref::GetGlobalOp> getGlobalOps;
141
142 // Transform `memref.alloca`s.
143 for (auto *op : allocaOps) {
144 auto alloca = cast<memref::AllocaOp>(op);
145 MLIRContext *ctx = rewriter.getContext();
146 Location loc = alloca->getLoc();
147
148 memref::GlobalOp globalOp;
149 {
150 // Find nearest symbol table.
151 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
152 assert(symbolTableOp && "expected alloca payload to be in symbol table");
153 SymbolTable symbolTable(symbolTableOp);
154
155 // Insert a `memref.global` into the symbol table.
156 Type resultType = alloca.getResult().getType();
157 OpBuilder builder(rewriter.getContext());
158 // TODO: Add a better builder for this.
159 globalOp = builder.create<memref::GlobalOp>(
160 loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
161 TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
162 symbolTable.insert(globalOp);
163 }
164
165 // Replace the `memref.alloca` with a `memref.get_global` accessing the
166 // global symbol inserted above.
167 rewriter.setInsertionPoint(alloca);
168 auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
169 alloca, globalOp.getType(), globalOp.getName());
170
171 globalOps.push_back(globalOp);
172 getGlobalOps.push_back(getGlobalOp);
173 }
174
175 // Assemble results.
176 results.set(cast<OpResult>(getGlobal()), globalOps);
177 results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
178
179 return DiagnosedSilenceableFailure::success();
180}
181
182void transform::MemRefAllocaToGlobalOp::getEffects(
183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184 producesHandle(getGlobal(), effects);
185 producesHandle(getGetGlobal(), effects);
186 consumesHandle(getAlloca(), effects);
187 modifiesPayload(effects);
188}
189
190//===----------------------------------------------------------------------===//
191// MemRefMultiBufferOp
192//===----------------------------------------------------------------------===//
193
194DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
195 transform::TransformRewriter &rewriter,
196 transform::TransformResults &transformResults,
197 transform::TransformState &state) {
198 SmallVector<Operation *> results;
199 for (Operation *op : state.getPayloadOps(getTarget())) {
200 bool canApplyMultiBuffer = true;
201 auto target = cast<memref::AllocOp>(op);
202 LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
203 // Skip allocations not used in a loop.
204 for (Operation *user : target->getUsers()) {
205 if (isa<memref::DeallocOp>(user))
206 continue;
207 auto loop = user->getParentOfType<LoopLikeOpInterface>();
208 if (!loop) {
209 LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
210 DBGS() << "----due to user: " << *user;);
211 canApplyMultiBuffer = false;
212 break;
213 }
214 }
215 if (!canApplyMultiBuffer) {
216 LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
217 continue;
218 }
219
220 auto newBuffer =
221 memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
222
223 if (failed(newBuffer)) {
224 LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
225 return emitSilenceableFailure(target->getLoc())
226 << "op failed to multibuffer";
227 }
228
229 results.push_back(*newBuffer);
230 }
231 transformResults.set(cast<OpResult>(getResult()), results);
232 return DiagnosedSilenceableFailure::success();
233}
234
235//===----------------------------------------------------------------------===//
236// MemRefEraseDeadAllocAndStoresOp
237//===----------------------------------------------------------------------===//
238
239DiagnosedSilenceableFailure
240transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
241 transform::TransformRewriter &rewriter, Operation *target,
242 transform::ApplyToEachResultList &results,
243 transform::TransformState &state) {
244 // Apply store to load forwarding and dead store elimination.
245 vector::transferOpflowOpt(rewriter, target);
246 memref::eraseDeadAllocAndStores(rewriter, target);
247 return DiagnosedSilenceableFailure::success();
248}
249
250void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
251 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
252 transform::onlyReadsHandle(getTarget(), effects);
253 transform::modifiesPayload(effects);
254}
255void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
256 OperationState &result,
257 Value target) {
258 result.addOperands(target);
259}
260
261//===----------------------------------------------------------------------===//
262// MemRefMakeLoopIndependentOp
263//===----------------------------------------------------------------------===//
264
265DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
266 transform::TransformRewriter &rewriter, Operation *target,
267 transform::ApplyToEachResultList &results,
268 transform::TransformState &state) {
269 // Gather IVs.
270 SmallVector<Value> ivs;
271 Operation *nextOp = target;
272 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
273 nextOp = nextOp->getParentOfType<scf::ForOp>();
274 if (!nextOp) {
275 DiagnosedSilenceableFailure diag = emitSilenceableError()
276 << "could not find " << i
277 << "-th enclosing loop";
278 diag.attachNote(target->getLoc()) << "target op";
279 return diag;
280 }
281 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
282 }
283
284 // Rewrite IR.
285 FailureOr<Value> replacement = failure();
286 if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
287 replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
288 } else {
289 DiagnosedSilenceableFailure diag = emitSilenceableError()
290 << "unsupported target op";
291 diag.attachNote(target->getLoc()) << "target op";
292 return diag;
293 }
294 if (failed(replacement)) {
295 DiagnosedSilenceableFailure diag =
296 emitSilenceableError() << "could not make target op loop-independent";
297 diag.attachNote(target->getLoc()) << "target op";
298 return diag;
299 }
300 results.push_back(replacement->getDefiningOp());
301 return DiagnosedSilenceableFailure::success();
302}
303
304//===----------------------------------------------------------------------===//
305// Transform op registration
306//===----------------------------------------------------------------------===//
307
308namespace {
309class MemRefTransformDialectExtension
310 : public transform::TransformDialectExtension<
311 MemRefTransformDialectExtension> {
312public:
313 using Base::Base;
314
315 void init() {
316 declareGeneratedDialect<affine::AffineDialect>();
317 declareGeneratedDialect<arith::ArithDialect>();
318 declareGeneratedDialect<memref::MemRefDialect>();
319 declareGeneratedDialect<nvgpu::NVGPUDialect>();
320 declareGeneratedDialect<vector::VectorDialect>();
321
322 registerTransformOps<
323#define GET_OP_LIST
324#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
325 >();
326 }
327};
328} // namespace
329
330#define GET_OP_CLASSES
331#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
332
333void mlir::memref::registerTransformDialectExtension(
334 DialectRegistry &registry) {
335 registry.addExtensions<MemRefTransformDialectExtension>();
336}
337

// Source: mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp