1//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
10
11#include "mlir/Analysis/DataLayoutAnalysis.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Passes.h"
17#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Transform/IR/TransformDialect.h"
22#include "mlir/Dialect/Transform/IR/TransformTypes.h"
23#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
24#include "mlir/Dialect/Vector/IR/VectorOps.h"
25#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
26#include "mlir/Interfaces/LoopLikeInterface.h"
27#include "llvm/Support/Debug.h"
28
29using namespace mlir;
30
31#define DEBUG_TYPE "memref-transforms"
32#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
33
34//===----------------------------------------------------------------------===//
35// Apply...ConversionPatternsOp
36//===----------------------------------------------------------------------===//
37
38std::unique_ptr<TypeConverter>
39transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
40 LowerToLLVMOptions options(getContext());
41 options.allocLowering =
42 (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
43 : LowerToLLVMOptions::AllocLowering::Malloc);
44 options.useGenericFunctions = getUseGenericFunctions();
45
46 if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
47 options.overrideIndexBitwidth(bitwidth: getIndexBitwidth());
48
49 // TODO: the following two options don't really make sense for
50 // memref_to_llvm_type_converter specifically but we should have a single
51 // to_llvm_type_converter.
52 if (getDataLayout().has_value())
53 options.dataLayout = llvm::DataLayout(getDataLayout().value());
54 options.useBarePtrCallConv = getUseBarePtrCallConv();
55
56 return std::make_unique<LLVMTypeConverter>(args: getContext(), args&: options);
57}
58
59StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
60 return "LLVMTypeConverter";
61}
62
63//===----------------------------------------------------------------------===//
64// Apply...PatternsOp
65//===----------------------------------------------------------------------===//
66
67namespace {
68class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
69public:
70 explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
71 : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
72 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
73
74 LogicalResult matchAndRewrite(memref::AllocOp op,
75 PatternRewriter &rewriter) const override {
76 return success(IsSuccess: memref::allocToAlloca(
77 rewriter, alloc: op, filter: [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
78 MemRefType type = alloc.getMemref().getType();
79 if (!type.hasStaticShape())
80 return false;
81
82 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(operation: alloc);
83 int64_t elementSize = dataLayout.getTypeSize(t: type.getElementType());
84 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
85 }));
86 }
87
88private:
89 DataLayoutAnalysis dataLayoutAnalysis;
90 int64_t maxSize;
91};
92} // namespace
93
94void transform::ApplyAllocToAllocaOp::populatePatterns(
95 RewritePatternSet &patterns) {}
96
97void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
98 RewritePatternSet &patterns, transform::TransformState &state) {
99 patterns.insert<AllocToAllocaPattern>(
100 arg: state.getTopLevel(), args: static_cast<int64_t>(getSizeLimit().value_or(u: 0)));
101}
102
103void transform::ApplyExpandOpsPatternsOp::populatePatterns(
104 RewritePatternSet &patterns) {
105 memref::populateExpandOpsPatterns(patterns);
106}
107
108void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
109 RewritePatternSet &patterns) {
110 memref::populateExpandStridedMetadataPatterns(patterns);
111}
112
113void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
114 RewritePatternSet &patterns) {
115 memref::populateExtractAddressComputationsPatterns(patterns);
116}
117
118void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
119 RewritePatternSet &patterns) {
120 memref::populateFoldMemRefAliasOpPatterns(patterns);
121}
122
123void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
124 populatePatterns(RewritePatternSet &patterns) {
125 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
126}
127
128//===----------------------------------------------------------------------===//
129// AllocaToGlobalOp
130//===----------------------------------------------------------------------===//
131
132DiagnosedSilenceableFailure
133transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
134 transform::TransformResults &results,
135 transform::TransformState &state) {
136 auto allocaOps = state.getPayloadOps(value: getAlloca());
137
138 SmallVector<memref::GlobalOp> globalOps;
139 SmallVector<memref::GetGlobalOp> getGlobalOps;
140
141 // Transform `memref.alloca`s.
142 for (auto *op : allocaOps) {
143 auto alloca = cast<memref::AllocaOp>(Val: op);
144 MLIRContext *ctx = rewriter.getContext();
145 Location loc = alloca->getLoc();
146
147 memref::GlobalOp globalOp;
148 {
149 // Find nearest symbol table.
150 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from: op);
151 assert(symbolTableOp && "expected alloca payload to be in symbol table");
152 SymbolTable symbolTable(symbolTableOp);
153
154 // Insert a `memref.global` into the symbol table.
155 Type resultType = alloca.getResult().getType();
156 OpBuilder builder(rewriter.getContext());
157 // TODO: Add a better builder for this.
158 globalOp = builder.create<memref::GlobalOp>(
159 location: loc, args: StringAttr::get(context: ctx, bytes: "alloca"), args: StringAttr::get(context: ctx, bytes: "private"),
160 args: TypeAttr::get(type: resultType), args: Attribute{}, args: UnitAttr{}, args: IntegerAttr{});
161 symbolTable.insert(symbol: globalOp);
162 }
163
164 // Replace the `memref.alloca` with a `memref.get_global` accessing the
165 // global symbol inserted above.
166 rewriter.setInsertionPoint(alloca);
167 auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
168 op: alloca, args: globalOp.getType(), args: globalOp.getName());
169
170 globalOps.push_back(Elt: globalOp);
171 getGlobalOps.push_back(Elt: getGlobalOp);
172 }
173
174 // Assemble results.
175 results.set(value: cast<OpResult>(Val: getGlobal()), ops&: globalOps);
176 results.set(value: cast<OpResult>(Val: getGetGlobal()), ops&: getGlobalOps);
177
178 return DiagnosedSilenceableFailure::success();
179}
180
181void transform::MemRefAllocaToGlobalOp::getEffects(
182 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
183 producesHandle(handles: getOperation()->getOpResults(), effects);
184 consumesHandle(handles: getAllocaMutable(), effects);
185 modifiesPayload(effects);
186}
187
188//===----------------------------------------------------------------------===//
189// MemRefMultiBufferOp
190//===----------------------------------------------------------------------===//
191
192DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
193 transform::TransformRewriter &rewriter,
194 transform::TransformResults &transformResults,
195 transform::TransformState &state) {
196 SmallVector<Operation *> results;
197 for (Operation *op : state.getPayloadOps(value: getTarget())) {
198 bool canApplyMultiBuffer = true;
199 auto target = cast<memref::AllocOp>(Val: op);
200 LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
201 // Skip allocations not used in a loop.
202 for (Operation *user : target->getUsers()) {
203 if (isa<memref::DeallocOp>(Val: user))
204 continue;
205 auto loop = user->getParentOfType<LoopLikeOpInterface>();
206 if (!loop) {
207 LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
208 DBGS() << "----due to user: " << *user;);
209 canApplyMultiBuffer = false;
210 break;
211 }
212 }
213 if (!canApplyMultiBuffer) {
214 LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
215 continue;
216 }
217
218 auto newBuffer =
219 memref::multiBuffer(rewriter, allocOp: target, multiplier: getFactor(), skipOverrideAnalysis: getSkipAnalysis());
220
221 if (failed(Result: newBuffer)) {
222 LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
223 return emitSilenceableFailure(loc: target->getLoc())
224 << "op failed to multibuffer";
225 }
226
227 results.push_back(Elt: *newBuffer);
228 }
229 transformResults.set(value: cast<OpResult>(Val: getResult()), ops&: results);
230 return DiagnosedSilenceableFailure::success();
231}
232
233//===----------------------------------------------------------------------===//
234// MemRefEraseDeadAllocAndStoresOp
235//===----------------------------------------------------------------------===//
236
237DiagnosedSilenceableFailure
238transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
239 transform::TransformRewriter &rewriter, Operation *target,
240 transform::ApplyToEachResultList &results,
241 transform::TransformState &state) {
242 // Apply store to load forwarding and dead store elimination.
243 vector::transferOpflowOpt(rewriter, rootOp: target);
244 memref::eraseDeadAllocAndStores(rewriter, parentOp: target);
245 return DiagnosedSilenceableFailure::success();
246}
247
248void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
249 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
250 transform::onlyReadsHandle(handles: getTargetMutable(), effects);
251 transform::modifiesPayload(effects);
252}
253void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
254 OperationState &result,
255 Value target) {
256 result.addOperands(newOperands: target);
257}
258
259//===----------------------------------------------------------------------===//
260// MemRefMakeLoopIndependentOp
261//===----------------------------------------------------------------------===//
262
263DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
264 transform::TransformRewriter &rewriter, Operation *target,
265 transform::ApplyToEachResultList &results,
266 transform::TransformState &state) {
267 // Gather IVs.
268 SmallVector<Value> ivs;
269 Operation *nextOp = target;
270 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
271 nextOp = nextOp->getParentOfType<scf::ForOp>();
272 if (!nextOp) {
273 DiagnosedSilenceableFailure diag = emitSilenceableError()
274 << "could not find " << i
275 << "-th enclosing loop";
276 diag.attachNote(loc: target->getLoc()) << "target op";
277 return diag;
278 }
279 ivs.push_back(Elt: cast<scf::ForOp>(Val: nextOp).getInductionVar());
280 }
281
282 // Rewrite IR.
283 FailureOr<Value> replacement = failure();
284 if (auto allocaOp = dyn_cast<memref::AllocaOp>(Val: target)) {
285 replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, independencies: ivs);
286 } else {
287 DiagnosedSilenceableFailure diag = emitSilenceableError()
288 << "unsupported target op";
289 diag.attachNote(loc: target->getLoc()) << "target op";
290 return diag;
291 }
292 if (failed(Result: replacement)) {
293 DiagnosedSilenceableFailure diag =
294 emitSilenceableError() << "could not make target op loop-independent";
295 diag.attachNote(loc: target->getLoc()) << "target op";
296 return diag;
297 }
298 results.push_back(op: replacement->getDefiningOp());
299 return DiagnosedSilenceableFailure::success();
300}
301
302//===----------------------------------------------------------------------===//
303// Transform op registration
304//===----------------------------------------------------------------------===//
305
namespace {
/// Transform dialect extension that registers every MemRef transform op listed
/// in the generated `GET_OP_LIST` and declares the dialects those ops may
/// generate.
class MemRefTransformDialectExtension
    : public transform::TransformDialectExtension<
          MemRefTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)

  // Inherit the base extension constructors.
  using Base::Base;

  /// Declares generated dialects and registers the transform ops.
  void init() {
    // Dialects passed to `declareGeneratedDialect`, i.e. dialects whose ops the
    // transforms in this extension may introduce into payload IR.
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<memref::MemRefDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<vector::VectorDialect>();

    // The template argument list is the ODS-generated list of op classes.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
        >();
  }
};
} // namespace

// ODS-generated definitions of the MemRef transform op classes.
#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"

/// Adds the MemRef transform dialect extension to `registry`.
void mlir::memref::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<MemRefTransformDialectExtension>();
}
337

source code of mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp