1 | //===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" |
10 | |
11 | #include "mlir/Analysis/DataLayoutAnalysis.h" |
12 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
13 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
17 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
18 | #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" |
19 | #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" |
20 | #include "mlir/Dialect/SCF/IR/SCF.h" |
21 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
22 | #include "mlir/Dialect/Transform/IR/TransformTypes.h" |
23 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
24 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
25 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
26 | #include "mlir/Interfaces/LoopLikeInterface.h" |
27 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
28 | #include "llvm/Support/Debug.h" |
29 | |
30 | using namespace mlir; |
31 | |
// Debug category for this file; `LLVM_DEBUG` output below is emitted under it.
#define DEBUG_TYPE "memref-transforms"
// Debug stream pre-seeded with the "[memref-transforms] " prefix.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
34 | |
35 | //===----------------------------------------------------------------------===// |
36 | // Apply...ConversionPatternsOp |
37 | //===----------------------------------------------------------------------===// |
38 | |
39 | std::unique_ptr<TypeConverter> |
40 | transform::MemrefToLLVMTypeConverterOp::getTypeConverter() { |
41 | LowerToLLVMOptions options(getContext()); |
42 | options.allocLowering = |
43 | (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc |
44 | : LowerToLLVMOptions::AllocLowering::Malloc); |
45 | options.useGenericFunctions = getUseGenericFunctions(); |
46 | |
47 | if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout) |
48 | options.overrideIndexBitwidth(getIndexBitwidth()); |
49 | |
50 | // TODO: the following two options don't really make sense for |
51 | // memref_to_llvm_type_converter specifically but we should have a single |
52 | // to_llvm_type_converter. |
53 | if (getDataLayout().has_value()) |
54 | options.dataLayout = llvm::DataLayout(getDataLayout().value()); |
55 | options.useBarePtrCallConv = getUseBarePtrCallConv(); |
56 | |
57 | return std::make_unique<LLVMTypeConverter>(getContext(), options); |
58 | } |
59 | |
60 | StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() { |
61 | return "LLVMTypeConverter" ; |
62 | } |
63 | |
64 | //===----------------------------------------------------------------------===// |
65 | // Apply...PatternsOp |
66 | //===----------------------------------------------------------------------===// |
67 | |
68 | namespace { |
69 | class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> { |
70 | public: |
71 | explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0) |
72 | : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()), |
73 | dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {} |
74 | |
75 | LogicalResult matchAndRewrite(memref::AllocOp op, |
76 | PatternRewriter &rewriter) const override { |
77 | return success(memref::allocToAlloca( |
78 | rewriter, alloc: op, filter: [this](memref::AllocOp alloc, memref::DeallocOp dealloc) { |
79 | MemRefType type = alloc.getMemref().getType(); |
80 | if (!type.hasStaticShape()) |
81 | return false; |
82 | |
83 | const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(operation: alloc); |
84 | int64_t elementSize = dataLayout.getTypeSize(t: type.getElementType()); |
85 | return maxSize == 0 || type.getNumElements() * elementSize < maxSize; |
86 | })); |
87 | } |
88 | |
89 | private: |
90 | DataLayoutAnalysis dataLayoutAnalysis; |
91 | int64_t maxSize; |
92 | }; |
93 | } // namespace |
94 | |
95 | void transform::ApplyAllocToAllocaOp::populatePatterns( |
96 | RewritePatternSet &patterns) {} |
97 | |
98 | void transform::ApplyAllocToAllocaOp::populatePatternsWithState( |
99 | RewritePatternSet &patterns, transform::TransformState &state) { |
100 | patterns.insert<AllocToAllocaPattern>( |
101 | state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0))); |
102 | } |
103 | |
104 | void transform::ApplyExpandOpsPatternsOp::populatePatterns( |
105 | RewritePatternSet &patterns) { |
106 | memref::populateExpandOpsPatterns(patterns); |
107 | } |
108 | |
109 | void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns( |
110 | RewritePatternSet &patterns) { |
111 | memref::populateExpandStridedMetadataPatterns(patterns); |
112 | } |
113 | |
114 | void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns( |
115 | RewritePatternSet &patterns) { |
116 | memref::populateExtractAddressComputationsPatterns(patterns); |
117 | } |
118 | |
119 | void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns( |
120 | RewritePatternSet &patterns) { |
121 | memref::populateFoldMemRefAliasOpPatterns(patterns); |
122 | } |
123 | |
124 | void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: |
125 | populatePatterns(RewritePatternSet &patterns) { |
126 | memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); |
127 | } |
128 | |
129 | //===----------------------------------------------------------------------===// |
130 | // AllocaToGlobalOp |
131 | //===----------------------------------------------------------------------===// |
132 | |
133 | DiagnosedSilenceableFailure |
134 | transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, |
135 | transform::TransformResults &results, |
136 | transform::TransformState &state) { |
137 | auto allocaOps = state.getPayloadOps(getAlloca()); |
138 | |
139 | SmallVector<memref::GlobalOp> globalOps; |
140 | SmallVector<memref::GetGlobalOp> getGlobalOps; |
141 | |
142 | // Transform `memref.alloca`s. |
143 | for (auto *op : allocaOps) { |
144 | auto alloca = cast<memref::AllocaOp>(op); |
145 | MLIRContext *ctx = rewriter.getContext(); |
146 | Location loc = alloca->getLoc(); |
147 | |
148 | memref::GlobalOp globalOp; |
149 | { |
150 | // Find nearest symbol table. |
151 | Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); |
152 | assert(symbolTableOp && "expected alloca payload to be in symbol table" ); |
153 | SymbolTable symbolTable(symbolTableOp); |
154 | |
155 | // Insert a `memref.global` into the symbol table. |
156 | Type resultType = alloca.getResult().getType(); |
157 | OpBuilder builder(rewriter.getContext()); |
158 | // TODO: Add a better builder for this. |
159 | globalOp = builder.create<memref::GlobalOp>( |
160 | loc, StringAttr::get(ctx, "alloca" ), StringAttr::get(ctx, "private" ), |
161 | TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); |
162 | symbolTable.insert(globalOp); |
163 | } |
164 | |
165 | // Replace the `memref.alloca` with a `memref.get_global` accessing the |
166 | // global symbol inserted above. |
167 | rewriter.setInsertionPoint(alloca); |
168 | auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>( |
169 | alloca, globalOp.getType(), globalOp.getName()); |
170 | |
171 | globalOps.push_back(globalOp); |
172 | getGlobalOps.push_back(getGlobalOp); |
173 | } |
174 | |
175 | // Assemble results. |
176 | results.set(cast<OpResult>(getGlobal()), globalOps); |
177 | results.set(cast<OpResult>(getGetGlobal()), getGlobalOps); |
178 | |
179 | return DiagnosedSilenceableFailure::success(); |
180 | } |
181 | |
182 | void transform::MemRefAllocaToGlobalOp::getEffects( |
183 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
184 | producesHandle(getGlobal(), effects); |
185 | producesHandle(getGetGlobal(), effects); |
186 | consumesHandle(getAlloca(), effects); |
187 | modifiesPayload(effects); |
188 | } |
189 | |
190 | //===----------------------------------------------------------------------===// |
191 | // MemRefMultiBufferOp |
192 | //===----------------------------------------------------------------------===// |
193 | |
194 | DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( |
195 | transform::TransformRewriter &rewriter, |
196 | transform::TransformResults &transformResults, |
197 | transform::TransformState &state) { |
198 | SmallVector<Operation *> results; |
199 | for (Operation *op : state.getPayloadOps(getTarget())) { |
200 | bool canApplyMultiBuffer = true; |
201 | auto target = cast<memref::AllocOp>(op); |
202 | LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n" ;); |
203 | // Skip allocations not used in a loop. |
204 | for (Operation *user : target->getUsers()) { |
205 | if (isa<memref::DeallocOp>(user)) |
206 | continue; |
207 | auto loop = user->getParentOfType<LoopLikeOpInterface>(); |
208 | if (!loop) { |
209 | LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n" ; |
210 | DBGS() << "----due to user: " << *user;); |
211 | canApplyMultiBuffer = false; |
212 | break; |
213 | } |
214 | } |
215 | if (!canApplyMultiBuffer) { |
216 | LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n" ;); |
217 | continue; |
218 | } |
219 | |
220 | auto newBuffer = |
221 | memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis()); |
222 | |
223 | if (failed(newBuffer)) { |
224 | LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n" ;); |
225 | return emitSilenceableFailure(target->getLoc()) |
226 | << "op failed to multibuffer" ; |
227 | } |
228 | |
229 | results.push_back(*newBuffer); |
230 | } |
231 | transformResults.set(cast<OpResult>(getResult()), results); |
232 | return DiagnosedSilenceableFailure::success(); |
233 | } |
234 | |
235 | //===----------------------------------------------------------------------===// |
236 | // MemRefEraseDeadAllocAndStoresOp |
237 | //===----------------------------------------------------------------------===// |
238 | |
239 | DiagnosedSilenceableFailure |
240 | transform::MemRefEraseDeadAllocAndStoresOp::applyToOne( |
241 | transform::TransformRewriter &rewriter, Operation *target, |
242 | transform::ApplyToEachResultList &results, |
243 | transform::TransformState &state) { |
244 | // Apply store to load forwarding and dead store elimination. |
245 | vector::transferOpflowOpt(rewriter, target); |
246 | memref::eraseDeadAllocAndStores(rewriter, target); |
247 | return DiagnosedSilenceableFailure::success(); |
248 | } |
249 | |
250 | void transform::MemRefEraseDeadAllocAndStoresOp::getEffects( |
251 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
252 | transform::onlyReadsHandle(getTarget(), effects); |
253 | transform::modifiesPayload(effects); |
254 | } |
255 | void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder, |
256 | OperationState &result, |
257 | Value target) { |
258 | result.addOperands(target); |
259 | } |
260 | |
261 | //===----------------------------------------------------------------------===// |
262 | // MemRefMakeLoopIndependentOp |
263 | //===----------------------------------------------------------------------===// |
264 | |
265 | DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( |
266 | transform::TransformRewriter &rewriter, Operation *target, |
267 | transform::ApplyToEachResultList &results, |
268 | transform::TransformState &state) { |
269 | // Gather IVs. |
270 | SmallVector<Value> ivs; |
271 | Operation *nextOp = target; |
272 | for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { |
273 | nextOp = nextOp->getParentOfType<scf::ForOp>(); |
274 | if (!nextOp) { |
275 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
276 | << "could not find " << i |
277 | << "-th enclosing loop" ; |
278 | diag.attachNote(target->getLoc()) << "target op" ; |
279 | return diag; |
280 | } |
281 | ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar()); |
282 | } |
283 | |
284 | // Rewrite IR. |
285 | FailureOr<Value> replacement = failure(); |
286 | if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) { |
287 | replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); |
288 | } else { |
289 | DiagnosedSilenceableFailure diag = emitSilenceableError() |
290 | << "unsupported target op" ; |
291 | diag.attachNote(target->getLoc()) << "target op" ; |
292 | return diag; |
293 | } |
294 | if (failed(replacement)) { |
295 | DiagnosedSilenceableFailure diag = |
296 | emitSilenceableError() << "could not make target op loop-independent" ; |
297 | diag.attachNote(target->getLoc()) << "target op" ; |
298 | return diag; |
299 | } |
300 | results.push_back(replacement->getDefiningOp()); |
301 | return DiagnosedSilenceableFailure::success(); |
302 | } |
303 | |
304 | //===----------------------------------------------------------------------===// |
305 | // Transform op registration |
306 | //===----------------------------------------------------------------------===// |
307 | |
308 | namespace { |
309 | class MemRefTransformDialectExtension |
310 | : public transform::TransformDialectExtension< |
311 | MemRefTransformDialectExtension> { |
312 | public: |
313 | using Base::Base; |
314 | |
315 | void init() { |
316 | declareGeneratedDialect<affine::AffineDialect>(); |
317 | declareGeneratedDialect<arith::ArithDialect>(); |
318 | declareGeneratedDialect<memref::MemRefDialect>(); |
319 | declareGeneratedDialect<nvgpu::NVGPUDialect>(); |
320 | declareGeneratedDialect<vector::VectorDialect>(); |
321 | |
322 | registerTransformOps< |
323 | #define GET_OP_LIST |
324 | #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" |
325 | >(); |
326 | } |
327 | }; |
328 | } // namespace |
329 | |
330 | #define GET_OP_CLASSES |
331 | #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" |
332 | |
333 | void mlir::memref::registerTransformDialectExtension( |
334 | DialectRegistry ®istry) { |
335 | registry.addExtensions<MemRefTransformDialectExtension>(); |
336 | } |
337 | |