//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                           ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};

struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
  AlignedAllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
  AllocaOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
                                converter) {
    setRequiresNumElements();
  }

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (e.g., in conjunction with malloc).
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value size,
                                          Operation *op) const override {

    // With alloca, which is used for stack allocations, one gets a pointer to
    // the element type right away.
    auto allocaOp = cast<memref::AllocaOp>(op);
    auto elementType =
        typeConverter->convertType(allocaOp.getType().getElementType());
    unsigned addrSpace =
        *getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr =
        rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
                                        allocaOp.getAlignment().value_or(0));

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

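/// `memref.alloca_scope` is lowered by inlining the scope body between an
/// `llvm.intr.stacksave` / `llvm.intr.stackrestore` pair, so that allocas
/// made inside the scope are released when the scope exits. A minimal sketch
/// of the emitted structure (illustrative only):
///
///   %sp = llvm.intr.stacksave : !llvm.ptr
///   ...                            // inlined body
///   llvm.intr.stackrestore %sp : !llvm.ptr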
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body. Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref,
                                     /*indices=*/{}, rewriter);

    // Emit llvm.assume(memref & (alignment - 1) == 0).
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref instances should get de-duplicated into the same
    // pointer SSA value.
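    //
    // For an alignment of, say, 16, the emitted IR looks roughly like this
    // (an illustrative sketch; SSA names and the index width are
    // target-dependent):
    //
    //   %p    = llvm.ptrtoint %ptr : !llvm.ptr to i64
    //   %mask = llvm.mlir.constant(15 : i64) : i64
    //   %and  = llvm.and %p, %mask : i64
    //   %zero = llvm.mlir.constant(0 : i64) : i64
    //   %ok   = llvm.icmp "eq" %and, %zero : i64
    //   llvm.intr.assume %ok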
    MemRefDescriptor memRefDescriptor(memref);
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
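//
// For example (an illustrative sketch; the exact descriptor struct depends on
// the memref type):
//
//   memref.dealloc %buf : memref<?xf32>
//
// becomes roughly:
//
//   %ptr = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()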
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
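//
// For example (illustrative):
//
//   %d1 = memref.dim %m, %c1 : memref<4x?xf32>
//
// loads the dynamic size from the descriptor's size array, whereas
// `memref.dim %m, %c0` on the same memref lowers to a constant 4.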
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an "
                        "integer address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %new = %pair[0]              |
///  |   |   %ok = %pair[1]               |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock =
        rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.getAtomicBody().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1-D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
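  // For example (illustrative): memref<4x8xf32> converts to
  // !llvm.array<4 x array<8 x f32>>, while a rank-0 memref<f32> converts to
  // plain f32.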
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM global variable.
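///
/// For example (an illustrative sketch; exact attribute printing may differ):
///
///   memref.global "private" constant @cst : memref<2xf32> =
///       dense<[1.0, 2.0]>
///
/// becomes roughly:
///
///   llvm.mlir.global private constant @cst(dense<[1.0, 2.0]> :
///       tensor<2xf32>) : !llvm.array<2 x f32>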
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor. This builds on
/// `AllocLikeOpLowering` so the MemRef descriptor construction can be reused.
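///
/// For example (illustrative):
///
///   %0 = memref.get_global @gv : memref<2xf32>
///
/// becomes, roughly, an `llvm.mlir.addressof @gv` followed by a GEP to the
/// first array element, packed into a descriptor whose "allocated" pointer is
/// the 0xdeadbeef sentinel set below.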
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
                                0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
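//
// For example (an illustrative sketch for a 1-D identity-layout memref;
// offset arithmetic is omitted):
//
//   %v = memref.load %m[%i] : memref<?xf32>
//
// becomes roughly:
//
//   %base = llvm.extractvalue %desc[1]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   %gep  = llvm.getelementptr %base[%i] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v    = llvm.load %gep : !llvm.ptr -> f32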
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

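/// `memref.rank` is lowered to a constant for ranked memrefs, e.g. rank 2 for
/// `memref.rank %m : memref<4x?xf32>` (illustrative); for unranked memrefs
/// the rank is read out of the unranked descriptor.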
struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked to unranked cast is disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, the behavior is
      // undefined.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
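///
/// For example (illustrative): copying between two identity-layout
/// memref<?xf32> values produces an `llvm.intr.memcpy` whose byte count is
/// the element count times the element size, while copies involving
/// non-contiguous layouts instead call the `memrefCopy` runtime function on
/// stack-promoted unranked descriptors.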
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which are a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

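/// Lowering for `memref.memory_space_cast`. For ranked memrefs, the allocated
/// and aligned pointers in the descriptor are rewritten with
/// `llvm.addrspacecast`; for unranked memrefs, a new descriptor is allocated
/// on the stack and the pointer and index fields are copied over. For example
/// (illustrative):
///
///   %1 = memref.memory_space_cast %0 : memref<?xf32, 1> to memref<?xf32>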
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 *
          ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In unranked case, the fields are extracted from the underlying
/// ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

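/// Lowering for `memref.reinterpret_cast`: a fresh ranked descriptor is
/// populated from the source pointers and the offset/size/stride operands or
/// attributes. For example (illustrative):
///
///   %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [%n],
///       strides: [1] : memref<*xf32> to memref<?xf32>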
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // If the stride for this dimension is static, materialize it as a
          // constant.
1166 | stride = |
1167 | createIndexAttrConstant(rewriter, loc, indexType, strides[i]); |
1168 | } else if (!stride) { |
1169 | // `stride` is null only in the first iteration of the loop. However, |
1170 | // since the target memref has an identity layout, we can safely set |
1171 | // the innermost stride to 1. |
1172 | stride = createIndexAttrConstant(rewriter, loc, indexType, 1); |
1173 | } |
1174 | |
1175 | Value dimSize; |
1176 | // If the size of this dimension is dynamic, then load it at runtime |
1177 | // from the shape operand. |
1178 | if (!targetMemRefType.isDynamicDim(i)) { |
1179 | dimSize = createIndexAttrConstant(rewriter, loc, indexType, |
1180 | targetMemRefType.getDimSize(i)); |
1181 | } else { |
1182 | Value shapeOp = reshapeOp.getShape(); |
1183 | Value index = createIndexAttrConstant(rewriter, loc, indexType, i); |
1184 | dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index); |
1185 | Type indexType = getIndexType(); |
1186 | if (dimSize.getType() != indexType) |
1187 | dimSize = typeConverter->materializeTargetConversion( |
1188 | rewriter, loc, indexType, dimSize); |
1189 | assert(dimSize && "Invalid memref element type" ); |
1190 | } |
1191 | |
1192 | desc.setSize(rewriter, loc, i, dimSize); |
1193 | desc.setStride(rewriter, loc, i, stride); |
1194 | |
1195 | // Prepare the stride value for the next dimension. |
1196 | stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize); |
1197 | } |
1198 | |
1199 | *descriptor = desc; |
1200 | return success(); |
1201 | } |
1202 | |
1203 | // The shape is a rank-1 tensor with unknown length. |
1204 | Location loc = reshapeOp.getLoc(); |
1205 | MemRefDescriptor shapeDesc(adaptor.getShape()); |
1206 | Value resultRank = shapeDesc.size(builder&: rewriter, loc, pos: 0); |
1207 | |
1208 | // Extract address space and element type. |
1209 | auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType()); |
1210 | unsigned addressSpace = |
1211 | *getTypeConverter()->getMemRefAddressSpace(targetType); |
1212 | |
1213 | // Create the unranked memref descriptor that holds the ranked one. The |
1214 | // inner descriptor is allocated on stack. |
1215 | auto targetDesc = UnrankedMemRefDescriptor::undef( |
1216 | rewriter, loc, typeConverter->convertType(targetType)); |
1217 | targetDesc.setRank(rewriter, loc, resultRank); |
1218 | SmallVector<Value, 4> sizes; |
1219 | UnrankedMemRefDescriptor::computeSizes(builder&: rewriter, loc, typeConverter: *getTypeConverter(), |
1220 | values: targetDesc, addressSpaces: addressSpace, sizes); |
1221 | Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( |
1222 | loc, getVoidPtrType(), IntegerType::get(getContext(), 8), |
1223 | sizes.front()); |
1224 | targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); |
1225 | |
1226 | // Extract pointers and offset from the source memref. |
1227 | Value allocatedPtr, alignedPtr, offset; |
1228 | extractPointersAndOffset(loc, rewriter, *getTypeConverter(), |
1229 | reshapeOp.getSource(), adaptor.getSource(), |
1230 | &allocatedPtr, &alignedPtr, &offset); |
1231 | |
1232 | // Set pointers and offset. |
1233 | auto elementPtrType = |
1234 | LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace); |
1235 | |
    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);
1244 | |
1245 | // Use the offset pointer as base for further addressing. Copy over the new |
1246 | // shape and compute strides. For this, we create a loop from rank-1 to 0. |
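    // A sketch of the CFG built below (illustrative; block names assumed,
    // not emitted verbatim):
    //   init:  br cond(%rank - 1, 1)
    //   cond(%i, %stride):  cond_br (%i >= 0), body, remainder
    //   body:  sizes[%i] = shape[%i]; strides[%i] = %stride;
    //          br cond(%i - 1, %stride * sizes[%i])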
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
1252 | Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); |
1253 | Value resultRankMinusOne = |
1254 | rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); |
1255 | |
1256 | Block *initBlock = rewriter.getInsertionBlock(); |
1257 | Type indexType = getTypeConverter()->getIndexType(); |
1258 | Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); |
1259 | |
1260 | Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, |
1261 | {indexType, indexType}, {loc, loc}); |
1262 | |
1263 | // Move the remaining initBlock ops to condBlock. |
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());
1266 | |
1267 | rewriter.setInsertionPointToEnd(initBlock); |
1268 | rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), |
1269 | condBlock); |
1270 | rewriter.setInsertionPointToStart(condBlock); |
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);
1273 | |
1274 | Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0); |
1275 | Value pred = rewriter.create<LLVM::ICmpOp>( |
1276 | loc, IntegerType::get(rewriter.getContext(), 1), |
1277 | LLVM::ICmpPredicate::sge, indexArg, zeroIndex); |
1278 | |
    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
1281 | rewriter.setInsertionPointToStart(bodyBlock); |
1282 | |
1283 | // Copy size from shape to descriptor. |
1284 | auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
1285 | Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( |
1286 | loc, llvmIndexPtrType, |
1287 | typeConverter->convertType(shapeMemRefType.getElementType()), |
1288 | shapeOperandPtr, indexArg); |
1289 | Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep); |
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);
1292 | |
1293 | // Write stride value and compute next one. |
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
1296 | Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); |
1297 | |
1298 | // Decrement loop counter and branch back. |
1299 | Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); |
1300 | rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), |
1301 | condBlock); |
1302 | |
    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());
1305 | |
1306 | // Hook up the cond exit to the remainder. |
1307 | rewriter.setInsertionPointToEnd(condBlock); |
1308 | rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt, |
1309 | remainder, std::nullopt); |
1310 | |
1311 | // Reset position to beginning of new remainder block. |
1312 | rewriter.setInsertionPointToStart(remainder); |
1313 | |
1314 | *descriptor = targetDesc; |
1315 | return success(); |
1316 | } |
1317 | }; |
1318 | |
/// ReassociatingReshapeOp must be expanded before we reach this stage.
1320 | /// Report that information. |
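/// For example (illustrative; types assumed), an op such as
///   %r = memref.expand_shape %m [[0, 1]] : memref<8xf32> into memref<2x4xf32>
/// is expected to have been rewritten into strided-metadata form by an
/// earlier pass (e.g. expand-strided-metadata), so this pattern only
/// reports a match failure.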
1321 | template <typename ReshapeOp> |
1322 | class ReassociatingReshapeOpConversion |
1323 | : public ConvertOpToLLVMPattern<ReshapeOp> { |
1324 | public: |
1325 | using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern; |
1326 | using ReshapeOpAdaptor = typename ReshapeOp::Adaptor; |
1327 | |
1328 | LogicalResult |
1329 | matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor, |
1330 | ConversionPatternRewriter &rewriter) const override { |
1331 | return rewriter.notifyMatchFailure( |
1332 | reshapeOp, |
1333 | "reassociation operations should have been expanded beforehand" ); |
1334 | } |
1335 | }; |
1336 | |
1337 | /// Subviews must be expanded before we reach this stage. |
1338 | /// Report that information. |
1339 | struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> { |
1340 | using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern; |
1341 | |
1342 | LogicalResult |
1343 | matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor, |
1344 | ConversionPatternRewriter &rewriter) const override { |
1345 | return rewriter.notifyMatchFailure( |
1346 | subViewOp, "subview operations should have been expanded beforehand" ); |
1347 | } |
1348 | }; |
1349 | |
/// Conversion pattern that transforms a transpose op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, and the
///    sizes and strides permuted according to the transpose map.
/// The transpose op is replaced by the descriptor.
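/// For example (illustrative; result layout approximated):
///   %1 = memref.transpose %0 (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
/// keeps the pointers and offset of %0 but swaps the size/stride pairs of
/// dimensions 0 and 1 in the result descriptor.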
1357 | class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> { |
1358 | public: |
1359 | using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern; |
1360 | |
1361 | LogicalResult |
1362 | matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor, |
1363 | ConversionPatternRewriter &rewriter) const override { |
1364 | auto loc = transposeOp.getLoc(); |
1365 | MemRefDescriptor viewMemRef(adaptor.getIn()); |
1366 | |
1367 | // No permutation, early exit. |
1368 | if (transposeOp.getPermutation().isIdentity()) |
1369 | return rewriter.replaceOp(transposeOp, {viewMemRef}), success(); |
1370 | |
1371 | auto targetMemRef = MemRefDescriptor::undef( |
1372 | rewriter, loc, |
1373 | typeConverter->convertType(transposeOp.getIn().getType())); |
1374 | |
1375 | // Copy the base and aligned pointers from the old descriptor to the new |
1376 | // one. |
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));
1381 | |
    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));
1384 | |
1385 | // Iterate over the dimensions and apply size/stride permutation: |
1386 | // When enumerating the results of the permutation map, the enumeration |
1387 | // index is the index into the target dimensions and the DimExpr points to |
1388 | // the dimension of the source memref. |
1389 | for (const auto &en : |
1390 | llvm::enumerate(transposeOp.getPermutation().getResults())) { |
1391 | int targetPos = en.index(); |
1392 | int sourcePos = cast<AffineDimExpr>(en.value()).getPosition(); |
1393 | targetMemRef.setSize(rewriter, loc, targetPos, |
1394 | viewMemRef.size(rewriter, loc, sourcePos)); |
1395 | targetMemRef.setStride(rewriter, loc, targetPos, |
1396 | viewMemRef.stride(rewriter, loc, sourcePos)); |
1397 | } |
1398 | |
1399 | rewriter.replaceOp(transposeOp, {targetMemRef}); |
1400 | return success(); |
1401 | } |
1402 | }; |
1403 | |
1404 | /// Conversion pattern that transforms an op into: |
1405 | /// 1. An `llvm.mlir.undef` operation to create a memref descriptor |
1406 | /// 2. Updates to the descriptor to introduce the data ptr, offset, size |
1407 | /// and stride. |
1408 | /// The view op is replaced by the descriptor. |
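/// For example (illustrative; operand names assumed):
///   %v = memref.view %buf[%byte_shift][%size0]
///       : memref<2048xi8> to memref<?x16xf32>
/// yields a descriptor whose aligned pointer is advanced by %byte_shift
/// bytes and whose sizes and strides are rebuilt for the ?x16 shape.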
1409 | struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> { |
1410 | using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern; |
1411 | |
1412 | // Build and return the value for the idx^th shape dimension, either by |
1413 | // returning the constant shape dimension or counting the proper dynamic size. |
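  // For example (illustrative): with shape [2, ?, 4, ?] and idx == 3, one
  // dynamic dimension precedes idx, so the result is dynamicSizes[1].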
1414 | Value getSize(ConversionPatternRewriter &rewriter, Location loc, |
1415 | ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx, |
1416 | Type indexType) const { |
1417 | assert(idx < shape.size()); |
1418 | if (!ShapedType::isDynamic(shape[idx])) |
1419 | return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]); |
1420 | // Count the number of dynamic dims in range [0, idx] |
1421 | unsigned nDynamic = |
1422 | llvm::count_if(shape.take_front(idx), ShapedType::isDynamic); |
1423 | return dynamicSizes[nDynamic]; |
1424 | } |
1425 | |
1426 | // Build and return the idx^th stride, either by returning the constant stride |
1427 | // or by computing the dynamic stride from the current `runningStride` and |
1428 | // `nextSize`. The caller should keep a running stride and update it with the |
1429 | // result returned by this function. |
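  // For example (illustrative): for strides [?, ?, 1] visited from idx 2 down
  // to 0, the first call returns the constant 1, and each later call returns
  // runningStride * nextSize, building up the row-major stride chain.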
1430 | Value getStride(ConversionPatternRewriter &rewriter, Location loc, |
1431 | ArrayRef<int64_t> strides, Value nextSize, |
1432 | Value runningStride, unsigned idx, Type indexType) const { |
1433 | assert(idx < strides.size()); |
1434 | if (!ShapedType::isDynamic(strides[idx])) |
1435 | return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); |
1436 | if (nextSize) |
1437 | return runningStride |
1438 | ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) |
1439 | : nextSize; |
1440 | assert(!runningStride); |
1441 | return createIndexAttrConstant(rewriter, loc, indexType, 1); |
1442 | } |
1443 | |
1444 | LogicalResult |
1445 | matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor, |
1446 | ConversionPatternRewriter &rewriter) const override { |
1447 | auto loc = viewOp.getLoc(); |
1448 | |
1449 | auto viewMemRefType = viewOp.getType(); |
1450 | auto targetElementTy = |
1451 | typeConverter->convertType(viewMemRefType.getElementType()); |
1452 | auto targetDescTy = typeConverter->convertType(viewMemRefType); |
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();
1458 | |
1459 | int64_t offset; |
1460 | SmallVector<int64_t, 4> strides; |
1461 | auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); |
1462 | if (failed(successStrides)) |
1463 | return viewOp.emitWarning("cannot cast to non-strided shape" ), failure(); |
    assert(offset == 0 && "expected offset to be 0");
1465 | |
1466 | // Target memref must be contiguous in memory (innermost stride is 1), or |
1467 | // empty (special case when at least one of the memref dimensions is 0). |
1468 | if (!strides.empty() && (strides.back() != 1 && strides.back() != 0)) |
1469 | return viewOp.emitWarning("cannot cast to non-contiguous shape" ), |
1470 | failure(); |
1471 | |
1472 | // Create the descriptor. |
1473 | MemRefDescriptor sourceMemRef(adaptor.getSource()); |
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
1475 | |
1476 | // Field 1: Copy the allocated pointer, used for malloc/free. |
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
1478 | auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType()); |
1479 | targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr); |
1480 | |
1481 | // Field 2: Copy the actual aligned pointer to payload. |
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
1483 | alignedPtr = rewriter.create<LLVM::GEPOp>( |
1484 | loc, alignedPtr.getType(), |
1485 | typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr, |
1486 | adaptor.getByteShift()); |
1487 | |
1488 | targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr); |
1489 | |
1490 | Type indexType = getIndexType(); |
1491 | // Field 3: The offset in the resulting type must be 0. This is |
1492 | // because of the type change: an offset on srcType* may not be |
1493 | // expressible as an offset on dstType*. |
1494 | targetMemRef.setOffset( |
1495 | rewriter, loc, |
1496 | createIndexAttrConstant(rewriter, loc, indexType, offset)); |
1497 | |
1498 | // Early exit for 0-D corner case. |
1499 | if (viewMemRefType.getRank() == 0) |
1500 | return rewriter.replaceOp(viewOp, {targetMemRef}), success(); |
1501 | |
1502 | // Fields 4 and 5: Update sizes and strides. |
1503 | Value stride = nullptr, nextSize = nullptr; |
1504 | for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { |
1505 | // Update size. |
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
1508 | targetMemRef.setSize(rewriter, loc, i, size); |
1509 | // Update stride. |
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
1512 | targetMemRef.setStride(rewriter, loc, i, stride); |
1513 | nextSize = size; |
1514 | } |
1515 | |
1516 | rewriter.replaceOp(viewOp, {targetMemRef}); |
1517 | return success(); |
1518 | } |
1519 | }; |
1520 | |
1521 | //===----------------------------------------------------------------------===// |
1522 | // AtomicRMWOpLowering |
1523 | //===----------------------------------------------------------------------===// |
1524 | |
/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or to fall back to llvm.cmpxchg.
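/// For example (illustrative), `addf` maps directly to the
/// `llvm.atomicrmw fadd` binop, whereas a kind with no LLVM atomicrmw
/// equivalent (such as `mulf`) returns std::nullopt and is handled by the
/// generic cmpxchg-loop lowering instead.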
1527 | static std::optional<LLVM::AtomicBinOp> |
1528 | matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { |
1529 | switch (atomicOp.getKind()) { |
1530 | case arith::AtomicRMWKind::addf: |
1531 | return LLVM::AtomicBinOp::fadd; |
1532 | case arith::AtomicRMWKind::addi: |
1533 | return LLVM::AtomicBinOp::add; |
1534 | case arith::AtomicRMWKind::assign: |
1535 | return LLVM::AtomicBinOp::xchg; |
1536 | case arith::AtomicRMWKind::maximumf: |
1537 | return LLVM::AtomicBinOp::fmax; |
1538 | case arith::AtomicRMWKind::maxs: |
1539 | return LLVM::AtomicBinOp::max; |
1540 | case arith::AtomicRMWKind::maxu: |
1541 | return LLVM::AtomicBinOp::umax; |
1542 | case arith::AtomicRMWKind::minimumf: |
1543 | return LLVM::AtomicBinOp::fmin; |
1544 | case arith::AtomicRMWKind::mins: |
1545 | return LLVM::AtomicBinOp::min; |
1546 | case arith::AtomicRMWKind::minu: |
1547 | return LLVM::AtomicBinOp::umin; |
1548 | case arith::AtomicRMWKind::ori: |
1549 | return LLVM::AtomicBinOp::_or; |
1550 | case arith::AtomicRMWKind::andi: |
1551 | return LLVM::AtomicBinOp::_and; |
1552 | default: |
1553 | return std::nullopt; |
1554 | } |
1555 | llvm_unreachable("Invalid AtomicRMWKind" ); |
1556 | } |
1557 | |
1558 | struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> { |
1559 | using Base::Base; |
1560 | |
1561 | LogicalResult |
1562 | matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, |
1563 | ConversionPatternRewriter &rewriter) const override { |
1564 | auto maybeKind = matchSimpleAtomicOp(atomicOp); |
1565 | if (!maybeKind) |
1566 | return failure(); |
1567 | auto memRefType = atomicOp.getMemRefType(); |
1568 | SmallVector<int64_t> strides; |
1569 | int64_t offset; |
1570 | if (failed(getStridesAndOffset(memRefType, strides, offset))) |
1571 | return failure(); |
1572 | auto dataPtr = |
1573 | getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), |
1574 | adaptor.getIndices(), rewriter); |
1575 | rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>( |
1576 | atomicOp, *maybeKind, dataPtr, adaptor.getValue(), |
1577 | LLVM::AtomicOrdering::acq_rel); |
1578 | return success(); |
1579 | } |
1580 | }; |
1581 | |
1582 | /// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index. |
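/// For example (illustrative):
///   %i = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
/// lowers to an llvm.ptrtoint of the descriptor's aligned pointer.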
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1585 | public: |
1586 | using ConvertOpToLLVMPattern< |
1587 | memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern; |
1588 | |
1589 | LogicalResult |
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1591 | OpAdaptor adaptor, |
1592 | ConversionPatternRewriter &rewriter) const override { |
1593 | MemRefDescriptor desc(adaptor.getSource()); |
1594 | rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>( |
1595 | extractOp, getTypeConverter()->getIndexType(), |
1596 | desc.alignedPtr(rewriter, extractOp->getLoc())); |
1597 | return success(); |
1598 | } |
1599 | }; |
1600 | |
1601 | /// Materialize the MemRef descriptor represented by the results of |
1602 | /// ExtractStridedMetadataOp. |
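/// For example (illustrative):
///   %base, %offset, %sizes:2, %strides:2 =
///       memref.extract_strided_metadata %m : memref<?x?xf32>
///       -> memref<f32>, index, index, index, index, index
/// is replaced by the corresponding fields of the descriptor of %m.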
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1605 | public: |
1606 | using ConvertOpToLLVMPattern< |
1607 | memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern; |
1608 | |
1609 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
1615 | return failure(); |
1616 | |
1617 | // Create the descriptor. |
1618 | MemRefDescriptor sourceMemRef(adaptor.getSource()); |
1619 | Location loc = extractStridedMetadataOp.getLoc(); |
1620 | Value source = extractStridedMetadataOp.getSource(); |
1621 | |
1622 | auto sourceMemRefType = cast<MemRefType>(source.getType()); |
1623 | int64_t rank = sourceMemRefType.getRank(); |
1624 | SmallVector<Value> results; |
1625 | results.reserve(2 + rank * 2); |
1626 | |
1627 | // Base buffer. |
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
1630 | MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape( |
1631 | rewriter, loc, *getTypeConverter(), |
1632 | cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()), |
1633 | baseBuffer, alignedBuffer); |
1634 | results.push_back((Value)dstMemRef); |
1635 | |
1636 | // Offset. |
    results.push_back(sourceMemRef.offset(rewriter, loc));
1638 | |
1639 | // Sizes. |
1640 | for (unsigned i = 0; i < rank; ++i) |
      results.push_back(sourceMemRef.size(rewriter, loc, i));
1642 | // Strides. |
1643 | for (unsigned i = 0; i < rank; ++i) |
      results.push_back(sourceMemRef.stride(rewriter, loc, i));
1645 | |
1646 | rewriter.replaceOp(extractStridedMetadataOp, results); |
1647 | return success(); |
1648 | } |
1649 | }; |
1650 | |
1651 | } // namespace |
1652 | |
1653 | void mlir::populateFinalizeMemRefToLLVMConversionPatterns( |
1654 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
1655 | // clang-format off |
1656 | patterns.add< |
1657 | AllocaOpLowering, |
1658 | AllocaScopeOpLowering, |
1659 | AtomicRMWOpLowering, |
1660 | AssumeAlignmentOpLowering, |
1661 | ConvertExtractAlignedPointerAsIndex, |
1662 | DimOpLowering, |
1663 | ExtractStridedMetadataOpLowering, |
1664 | GenericAtomicRMWOpLowering, |
1665 | GlobalMemrefOpLowering, |
1666 | GetGlobalMemrefOpLowering, |
1667 | LoadOpLowering, |
1668 | MemRefCastOpLowering, |
1669 | MemRefCopyOpLowering, |
1670 | MemorySpaceCastOpLowering, |
1671 | MemRefReinterpretCastOpLowering, |
1672 | MemRefReshapeOpLowering, |
1673 | PrefetchOpLowering, |
1674 | RankOpLowering, |
1675 | ReassociatingReshapeOpConversion<memref::ExpandShapeOp>, |
1676 | ReassociatingReshapeOpConversion<memref::CollapseShapeOp>, |
1677 | StoreOpLowering, |
1678 | SubViewOpLowering, |
1679 | TransposeOpLowering, |
1680 | ViewOpLowering>(converter); |
1681 | // clang-format on |
1682 | auto allocLowering = converter.getOptions().allocLowering; |
1683 | if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc) |
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
1685 | else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc) |
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
1687 | } |
1688 | |
1689 | namespace { |
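/// Pass wrapper around the patterns registered above. Illustrative usage
/// (option names assumed from the pass definition):
///   mlir-opt --finalize-memref-to-llvm="use-aligned-alloc=1" input.mlir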
1690 | struct FinalizeMemRefToLLVMConversionPass |
1691 | : public impl::FinalizeMemRefToLLVMConversionPassBase< |
1692 | FinalizeMemRefToLLVMConversionPass> { |
1693 | using FinalizeMemRefToLLVMConversionPassBase:: |
1694 | FinalizeMemRefToLLVMConversionPassBase; |
1695 | |
1696 | void runOnOperation() override { |
1697 | Operation *op = getOperation(); |
1698 | const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>(); |
1699 | LowerToLLVMOptions options(&getContext(), |
1700 | dataLayoutAnalysis.getAtOrAbove(op)); |
1701 | options.allocLowering = |
1702 | (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc |
1703 | : LowerToLLVMOptions::AllocLowering::Malloc); |
1704 | |
1705 | options.useGenericFunctions = useGenericFunctions; |
1706 | |
1707 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
1708 | options.overrideIndexBitwidth(indexBitwidth); |
1709 | |
1710 | LLVMTypeConverter typeConverter(&getContext(), options, |
1711 | &dataLayoutAnalysis); |
1712 | RewritePatternSet patterns(&getContext()); |
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1714 | LLVMConversionTarget target(getContext()); |
1715 | target.addLegalOp<func::FuncOp>(); |
1716 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) |
1717 | signalPassFailure(); |
1718 | } |
1719 | }; |
1720 | |
1721 | /// Implement the interface to convert MemRef to LLVM. |
1722 | struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
1723 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
1724 | void loadDependentDialects(MLIRContext *context) const final { |
1725 | context->loadDialect<LLVM::LLVMDialect>(); |
1726 | } |
1727 | |
1728 | /// Hook for derived dialect interface to provide conversion patterns |
1729 | /// and mark dialect legal for the conversion target. |
1730 | void populateConvertToLLVMConversionPatterns( |
1731 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
1732 | RewritePatternSet &patterns) const final { |
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
1734 | } |
1735 | }; |
1736 | |
1737 | } // namespace |
1738 | |
1739 | void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry ®istry) { |
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
1741 | dialect->addInterfaces<MemRefToLLVMDialectInterface>(); |
1742 | }); |
1743 | } |
1744 | |