1//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
10
11#include "mlir/Analysis/DataLayoutAnalysis.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/MemRef/Transforms/Passes.h"
17#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/Transform/IR/TransformDialect.h"
22#include "mlir/Dialect/Transform/IR/TransformTypes.h"
23#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
24#include "mlir/Dialect/Vector/IR/VectorOps.h"
25#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
26#include "mlir/Interfaces/LoopLikeInterface.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/Support/Debug.h"
29
30using namespace mlir;
31
32#define DEBUG_TYPE "memref-transforms"
33#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
34
35//===----------------------------------------------------------------------===//
36// Apply...ConversionPatternsOp
37//===----------------------------------------------------------------------===//
38
39std::unique_ptr<TypeConverter>
40transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
41 LowerToLLVMOptions options(getContext());
42 options.allocLowering =
43 (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
44 : LowerToLLVMOptions::AllocLowering::Malloc);
45 options.useGenericFunctions = getUseGenericFunctions();
46
47 if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
48 options.overrideIndexBitwidth(getIndexBitwidth());
49
50 // TODO: the following two options don't really make sense for
51 // memref_to_llvm_type_converter specifically but we should have a single
52 // to_llvm_type_converter.
53 if (getDataLayout().has_value())
54 options.dataLayout = llvm::DataLayout(getDataLayout().value());
55 options.useBarePtrCallConv = getUseBarePtrCallConv();
56
57 return std::make_unique<LLVMTypeConverter>(getContext(), options);
58}
59
60StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61 return "LLVMTypeConverter";
62}
63
64//===----------------------------------------------------------------------===//
65// Apply...PatternsOp
66//===----------------------------------------------------------------------===//
67
68namespace {
69class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
70public:
71 explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
72 : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
73 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74
75 LogicalResult matchAndRewrite(memref::AllocOp op,
76 PatternRewriter &rewriter) const override {
77 return success(memref::allocToAlloca(
78 rewriter, alloc: op, filter: [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79 MemRefType type = alloc.getMemref().getType();
80 if (!type.hasStaticShape())
81 return false;
82
83 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(operation: alloc);
84 int64_t elementSize = dataLayout.getTypeSize(t: type.getElementType());
85 return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
86 }));
87 }
88
89private:
90 DataLayoutAnalysis dataLayoutAnalysis;
91 int64_t maxSize;
92};
93} // namespace
94
95void transform::ApplyAllocToAllocaOp::populatePatterns(
96 RewritePatternSet &patterns) {}
97
98void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99 RewritePatternSet &patterns, transform::TransformState &state) {
100 patterns.insert<AllocToAllocaPattern>(
101 state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
102}
103
104void transform::ApplyExpandOpsPatternsOp::populatePatterns(
105 RewritePatternSet &patterns) {
106 memref::populateExpandOpsPatterns(patterns);
107}
108
109void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
110 RewritePatternSet &patterns) {
111 memref::populateExpandStridedMetadataPatterns(patterns);
112}
113
114void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
115 RewritePatternSet &patterns) {
116 memref::populateExtractAddressComputationsPatterns(patterns);
117}
118
119void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
120 RewritePatternSet &patterns) {
121 memref::populateFoldMemRefAliasOpPatterns(patterns);
122}
123
124void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
125 populatePatterns(RewritePatternSet &patterns) {
126 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
127}
128
129//===----------------------------------------------------------------------===//
130// AllocaToGlobalOp
131//===----------------------------------------------------------------------===//
132
133DiagnosedSilenceableFailure
134transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
135 transform::TransformResults &results,
136 transform::TransformState &state) {
137 auto allocaOps = state.getPayloadOps(getAlloca());
138
139 SmallVector<memref::GlobalOp> globalOps;
140 SmallVector<memref::GetGlobalOp> getGlobalOps;
141
142 // Transform `memref.alloca`s.
143 for (auto *op : allocaOps) {
144 auto alloca = cast<memref::AllocaOp>(op);
145 MLIRContext *ctx = rewriter.getContext();
146 Location loc = alloca->getLoc();
147
148 memref::GlobalOp globalOp;
149 {
150 // Find nearest symbol table.
151 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
152 assert(symbolTableOp && "expected alloca payload to be in symbol table");
153 SymbolTable symbolTable(symbolTableOp);
154
155 // Insert a `memref.global` into the symbol table.
156 Type resultType = alloca.getResult().getType();
157 OpBuilder builder(rewriter.getContext());
158 // TODO: Add a better builder for this.
159 globalOp = builder.create<memref::GlobalOp>(
160 loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
161 TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
162 symbolTable.insert(globalOp);
163 }
164
165 // Replace the `memref.alloca` with a `memref.get_global` accessing the
166 // global symbol inserted above.
167 rewriter.setInsertionPoint(alloca);
168 auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
169 alloca, globalOp.getType(), globalOp.getName());
170
171 globalOps.push_back(globalOp);
172 getGlobalOps.push_back(getGlobalOp);
173 }
174
175 // Assemble results.
176 results.set(cast<OpResult>(getGlobal()), globalOps);
177 results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
178
179 return DiagnosedSilenceableFailure::success();
180}
181
182void transform::MemRefAllocaToGlobalOp::getEffects(
183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184 producesHandle(getOperation()->getOpResults(), effects);
185 consumesHandle(getAllocaMutable(), effects);
186 modifiesPayload(effects);
187}
188
189//===----------------------------------------------------------------------===//
190// MemRefMultiBufferOp
191//===----------------------------------------------------------------------===//
192
193DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
194 transform::TransformRewriter &rewriter,
195 transform::TransformResults &transformResults,
196 transform::TransformState &state) {
197 SmallVector<Operation *> results;
198 for (Operation *op : state.getPayloadOps(getTarget())) {
199 bool canApplyMultiBuffer = true;
200 auto target = cast<memref::AllocOp>(op);
201 LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
202 // Skip allocations not used in a loop.
203 for (Operation *user : target->getUsers()) {
204 if (isa<memref::DeallocOp>(user))
205 continue;
206 auto loop = user->getParentOfType<LoopLikeOpInterface>();
207 if (!loop) {
208 LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
209 DBGS() << "----due to user: " << *user;);
210 canApplyMultiBuffer = false;
211 break;
212 }
213 }
214 if (!canApplyMultiBuffer) {
215 LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
216 continue;
217 }
218
219 auto newBuffer =
220 memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
221
222 if (failed(newBuffer)) {
223 LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
224 return emitSilenceableFailure(target->getLoc())
225 << "op failed to multibuffer";
226 }
227
228 results.push_back(*newBuffer);
229 }
230 transformResults.set(cast<OpResult>(getResult()), results);
231 return DiagnosedSilenceableFailure::success();
232}
233
234//===----------------------------------------------------------------------===//
235// MemRefEraseDeadAllocAndStoresOp
236//===----------------------------------------------------------------------===//
237
238DiagnosedSilenceableFailure
239transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
240 transform::TransformRewriter &rewriter, Operation *target,
241 transform::ApplyToEachResultList &results,
242 transform::TransformState &state) {
243 // Apply store to load forwarding and dead store elimination.
244 vector::transferOpflowOpt(rewriter, target);
245 memref::eraseDeadAllocAndStores(rewriter, target);
246 return DiagnosedSilenceableFailure::success();
247}
248
249void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
250 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
251 transform::onlyReadsHandle(getTargetMutable(), effects);
252 transform::modifiesPayload(effects);
253}
254void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
255 OperationState &result,
256 Value target) {
257 result.addOperands(target);
258}
259
260//===----------------------------------------------------------------------===//
261// MemRefMakeLoopIndependentOp
262//===----------------------------------------------------------------------===//
263
264DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
265 transform::TransformRewriter &rewriter, Operation *target,
266 transform::ApplyToEachResultList &results,
267 transform::TransformState &state) {
268 // Gather IVs.
269 SmallVector<Value> ivs;
270 Operation *nextOp = target;
271 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
272 nextOp = nextOp->getParentOfType<scf::ForOp>();
273 if (!nextOp) {
274 DiagnosedSilenceableFailure diag = emitSilenceableError()
275 << "could not find " << i
276 << "-th enclosing loop";
277 diag.attachNote(target->getLoc()) << "target op";
278 return diag;
279 }
280 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
281 }
282
283 // Rewrite IR.
284 FailureOr<Value> replacement = failure();
285 if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
286 replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
287 } else {
288 DiagnosedSilenceableFailure diag = emitSilenceableError()
289 << "unsupported target op";
290 diag.attachNote(target->getLoc()) << "target op";
291 return diag;
292 }
293 if (failed(replacement)) {
294 DiagnosedSilenceableFailure diag =
295 emitSilenceableError() << "could not make target op loop-independent";
296 diag.attachNote(target->getLoc()) << "target op";
297 return diag;
298 }
299 results.push_back(replacement->getDefiningOp());
300 return DiagnosedSilenceableFailure::success();
301}
302
303//===----------------------------------------------------------------------===//
304// Transform op registration
305//===----------------------------------------------------------------------===//
306
307namespace {
308class MemRefTransformDialectExtension
309 : public transform::TransformDialectExtension<
310 MemRefTransformDialectExtension> {
311public:
312 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)
313
314 using Base::Base;
315
316 void init() {
317 declareGeneratedDialect<affine::AffineDialect>();
318 declareGeneratedDialect<arith::ArithDialect>();
319 declareGeneratedDialect<memref::MemRefDialect>();
320 declareGeneratedDialect<nvgpu::NVGPUDialect>();
321 declareGeneratedDialect<vector::VectorDialect>();
322
323 registerTransformOps<
324#define GET_OP_LIST
325#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
326 >();
327 }
328};
329} // namespace
330
331#define GET_OP_CLASSES
332#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
333
334void mlir::memref::registerTransformDialectExtension(
335 DialectRegistry &registry) {
336 registry.addExtensions<MemRefTransformDialectExtension>();
337}
338

// Source: mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp