1//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/MathExtras.h"
#include <iterator>
#include <optional>
29
30#define DEBUG_TYPE "memref-to-llvm"
31#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
32
33namespace mlir {
34#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
35#include "mlir/Conversion/Passes.h.inc"
36} // namespace mlir
37
38using namespace mlir;
39
40static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
41 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
42
43namespace {
44
45static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
46 return ShapedType::isStatic(dValue: strideOrOffset);
47}
48
49static FailureOr<LLVM::LLVMFuncOp>
50getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter, ModuleOp module,
51 SymbolTableCollection *symbolTables) {
52 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
53
54 if (useGenericFn)
55 return LLVM::lookupOrCreateGenericFreeFn(b, moduleOp: module, symbolTables);
56
57 return LLVM::lookupOrCreateFreeFn(b, moduleOp: module, symbolTables);
58}
59
60static FailureOr<LLVM::LLVMFuncOp>
61getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
62 Operation *module, Type indexType,
63 SymbolTableCollection *symbolTables) {
64 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
65 if (useGenericFn)
66 return LLVM::lookupOrCreateGenericAllocFn(b, moduleOp: module, indexType,
67 symbolTables);
68
69 return LLVM::lookupOrCreateMallocFn(b, moduleOp: module, indexType, symbolTables);
70}
71
72static FailureOr<LLVM::LLVMFuncOp>
73getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
74 Operation *module, Type indexType,
75 SymbolTableCollection *symbolTables) {
76 bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
77
78 if (useGenericFn)
79 return LLVM::lookupOrCreateGenericAlignedAllocFn(b, moduleOp: module, indexType,
80 symbolTables);
81
82 return LLVM::lookupOrCreateAlignedAllocFn(b, moduleOp: module, indexType, symbolTables);
83}
84
85/// Computes the aligned value for 'input' as follows:
86/// bumped = input + alignement - 1
87/// aligned = bumped - bumped % alignment
88static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
89 Value input, Value alignment) {
90 Value one = rewriter.create<LLVM::ConstantOp>(location: loc, args: alignment.getType(),
91 args: rewriter.getIndexAttr(value: 1));
92 Value bump = rewriter.create<LLVM::SubOp>(location: loc, args&: alignment, args&: one);
93 Value bumped = rewriter.create<LLVM::AddOp>(location: loc, args&: input, args&: bump);
94 Value mod = rewriter.create<LLVM::URemOp>(location: loc, args&: bumped, args&: alignment);
95 return rewriter.create<LLVM::SubOp>(location: loc, args&: bumped, args&: mod);
96}
97
98/// Computes the byte size for the MemRef element type.
99static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
100 MemRefType memRefType, Operation *op,
101 const DataLayout *defaultLayout) {
102 const DataLayout *layout = defaultLayout;
103 if (const DataLayoutAnalysis *analysis =
104 typeConverter->getDataLayoutAnalysis()) {
105 layout = &analysis->getAbove(operation: op);
106 }
107 Type elementType = memRefType.getElementType();
108 if (auto memRefElementType = dyn_cast<MemRefType>(Val&: elementType))
109 return typeConverter->getMemRefDescriptorSize(type: memRefElementType, layout: *layout);
110 if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(Val&: elementType))
111 return typeConverter->getUnrankedMemRefDescriptorSize(type: memRefElementType,
112 layout: *layout);
113 return layout->getTypeSize(t: elementType);
114}
115
116static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
117 Location loc, Value allocatedPtr,
118 MemRefType memRefType, Type elementPtrType,
119 const LLVMTypeConverter &typeConverter) {
120 auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(Val: allocatedPtr.getType());
121 FailureOr<unsigned> maybeMemrefAddrSpace =
122 typeConverter.getMemRefAddressSpace(type: memRefType);
123 assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
124 unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
125 if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
126 allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
127 location: loc, args: LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: memrefAddrSpace),
128 args&: allocatedPtr);
129 return allocatedPtr;
130}
131
/// Lowers `memref.alloc` to a call to `malloc` (or the generic allocation
/// function, depending on the lowering options). When an alignment is needed,
/// the allocation is over-sized and the aligned pointer is computed manually.
class AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  /// Optional collection of cached symbol tables used when looking up or
  /// inserting the allocation function; may be null.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AllocOpLowering(const LLVMTypeConverter &typeConverter,
                           SymbolTableCollection *symbolTables = nullptr,
                           PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(type: memRefType))
      return rewriter.notifyMatchFailure(arg&: op, msg: "incompatible memref type");

    // Get or insert alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getNotalignedAllocFn(b&: rewriter, typeConverter: getTypeConverter(),
                             module: op->getParentWithTrait<OpTrait::SymbolTable>(),
                             indexType: getIndexType(), symbolTables);
    if (failed(Result: allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, dynamicSizes: adaptor.getOperands(),
                                   rewriter, sizes, strides, size&: sizeBytes, sizeInBytes: true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment: over-allocate so a
      // suitably aligned pointer always exists within the buffer.
      sizeBytes = rewriter.create<LLVM::AddOp>(location: loc, args&: sizeBytes, args&: alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(type: memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        rewriter.create<LLVM::CallOp>(location: loc, args&: allocFuncOp.value(), args&: sizeBytes);

    // Cast the raw allocation result into the memref's address space if the
    // two differ.
    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, allocatedPtr: results.getResult(), memRefType,
                            elementPtrType, typeConverter: *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer: round the allocated address up to the
      // next multiple of the alignment using integer arithmetic.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(location: loc, args: getIndexType(), args&: allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, input: allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(location: loc, args&: elementPtrType, args&: alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, newValues: {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op. Returns the
  /// explicit `alignment` attribute when present; for non-scalar element
  /// types with no explicit alignment, falls back to the element type's size
  /// in bytes; otherwise returns a null Value (no alignment adjustment).
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc's` behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, type: memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};
223
/// Lowers `memref.alloc` to a call to `aligned_alloc` (or its generic
/// counterpart), which guarantees the requested alignment directly, so the
/// allocated and aligned pointers in the descriptor coincide.
class AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  /// Optional collection of cached symbol tables used when looking up or
  /// inserting the allocation function; may be null.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit AlignedAllocOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::AllocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(type: memRefType))
      return rewriter.notifyMatchFailure(arg&: op, msg: "incompatible memref type");

    // Get or insert alloc function into module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp =
        getAlignedAllocFn(b&: rewriter, typeConverter: getTypeConverter(),
                          module: op->getParentWithTrait<OpTrait::SymbolTable>(),
                          indexType: getIndexType(), symbolTables);
    if (failed(Result: allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, dynamicSizes: adaptor.getOperands(),
                                   rewriter, sizes, strides, size&: sizeBytes, sizeInBytes: !false);

    int64_t alignment = alignedAllocationGetAlignment(op, defaultLayout: &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(builder&: rewriter, loc, resultType: getIndexType(), value: alignment);

    // Function aligned_alloc requires size to be a multiple of alignment; we
    // pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(type: memRefType, factor: alignment, op, defaultLayout: &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, input: sizeBytes, alignment: allocAlignment);

    Type elementPtrType = this->getElementPtrType(type: memRefType);
    auto results = rewriter.create<LLVM::CallOp>(
        location: loc, args&: allocFuncOp.value(), args: ValueRange({allocAlignment, sizeBytes}));

    // Cast the result pointer into the memref's address space if the two
    // differ. The pointer is already aligned by construction, so it serves as
    // both allocated and aligned pointer below.
    Value ptr =
        castAllocFuncResult(rewriter, loc, allocatedPtr: results.getResult(), memRefType,
                            elementPtrType, typeConverter: *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr: ptr, alignedPtr: ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, newValues: {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// Aligned_alloc requires the allocation size to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the allocation size has to be a
    // power of two, we will bump to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        typeConverter: getTypeConverter(), memRefType: op.getType(), op, defaultLayout);
    return std::max(a: kMinAlignedAllocAlignment,
                    b: llvm::PowerOf2Ceil(A: eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(typeConverter: getTypeConverter(), memRefType: type, op, defaultLayout);
    // Dynamic dimensions are skipped: the product of the static extents alone
    // must already establish the multiple-of property.
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(idx: i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(idx: i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};
327
/// Lowers `memref.alloca` to `llvm.alloca`. The stack allocation carries the
/// alignment itself, so no post-allocation pointer adjustment is needed.
struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
  /// is set to null for stack allocations. `accessAlignment` is set if
  /// alignment is needed post allocation (for eg. in conjunction with malloc).
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(type: memRefType))
      return rewriter.notifyMatchFailure(arg&: op, msg: "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    // `!true` requests the element count rather than the size in bytes, since
    // llvm.alloca takes a number of elements of the given element type.
    this->getMemRefDescriptorSizes(loc, memRefType, dynamicSizes: adaptor.getOperands(),
                                   rewriter, sizes, strides, size, sizeInBytes: !true);

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto elementType =
        typeConverter->convertType(t: op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type: op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: addrSpace);

    // Emit the alloca with the op's requested alignment (0 when unspecified).
    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        location: loc, args&: elementPtrType, args&: elementType, args&: size, args: op.getAlignment().value_or(u: 0));

    // Create the MemRef descriptor; the alloca result serves as both the
    // allocated and the aligned pointer.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr: allocatedElementPtr, alignedPtr: allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, newValues: {memRefDescriptor});
    return success();
  }
};
376
/// Lowers `memref.alloca_scope` by inlining its region between a
/// `llvm.intr.stacksave` / `llvm.intr.stackrestore` pair, so allocas made
/// inside the scope are released when control leaves it.
struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      // No results: control can continue directly in the split-off block.
      continueBlock = remainingOpsBlock;
    } else {
      // Results are threaded through as arguments of a dedicated continuation
      // block that falls through into the remaining ops.
      continueBlock = rewriter.createBlock(
          insertBefore: remainingOpsBlock, argTypes: allocaScopeOp.getResultTypes(),
          locs: SmallVector<Location>(allocaScopeOp->getNumResults(),
                                 allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(location: loc, args: ValueRange(), args&: remainingOpsBlock);
    }

    // Inline body region. Record its entry/exit blocks first, since the
    // region is emptied by inlineRegionBefore.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(region&: allocaScopeOp.getBodyRegion(), before: continueBlock);

    // Save stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(location: loc, args: getPtrType());
    rewriter.create<LLVM::BrOp>(location: loc, args: ValueRange(), args&: beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the body.
    // Stack restore before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(Val: afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        op: returnOp, args: returnOp.getResults(), args&: continueBlock);

    // Insert stack restore before jumping out the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(location: loc, args&: stackSaveOp);

    // Replace the op with values return from the body region.
    rewriter.replaceOp(op: allocaScopeOp, newValues: continueBlock->getArguments());

    return success();
  }
};
431
432struct AssumeAlignmentOpLowering
433 : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
434 using ConvertOpToLLVMPattern<
435 memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
436 explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
437 : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}
438
439 LogicalResult
440 matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
441 ConversionPatternRewriter &rewriter) const override {
442 Value memref = adaptor.getMemref();
443 unsigned alignment = op.getAlignment();
444 auto loc = op.getLoc();
445
446 auto srcMemRefType = cast<MemRefType>(Val: op.getMemref().getType());
447 Value ptr = getStridedElementPtr(rewriter, loc, type: srcMemRefType, memRefDesc: memref,
448 /*indices=*/{});
449
450 // Emit llvm.assume(true) ["align"(memref, alignment)].
451 // This is more direct than ptrtoint-based checks, is explicitly supported,
452 // and works with non-integral address spaces.
453 Value trueCond =
454 rewriter.create<LLVM::ConstantOp>(location: loc, args: rewriter.getBoolAttr(value: true));
455 Value alignmentConst =
456 createIndexAttrConstant(builder&: rewriter, loc, resultType: getIndexType(), value: alignment);
457 rewriter.create<LLVM::AssumeOp>(location: loc, args&: trueCond, args: LLVM::AssumeAlignTag(), args&: ptr,
458 args&: alignmentConst);
459 rewriter.replaceOp(op, newValues: memref);
460 return success();
461 }
462};
463
// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
class DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  /// Optional collection of cached symbol tables used when looking up or
  /// inserting the `free` function; may be null.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit DeallocOpLowering(const LLVMTypeConverter &typeConverter,
                             SymbolTableCollection *symbolTables = nullptr,
                             PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc =
        getFreeFn(b&: rewriter, typeConverter: getTypeConverter(), module: op->getParentOfType<ModuleOp>(),
                  symbolTables);
    if (failed(Result: freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(Val: op.getMemref().getType())) {
      // Unranked case: the allocated pointer must be loaded from the
      // underlying ranked descriptor that the unranked descriptor points to.
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          context: rewriter.getContext(), addressSpace: unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          builder&: rewriter, loc: op.getLoc(),
          memRefDescPtr: UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(builder&: rewriter, loc: op.getLoc()),
          elemPtrType: elementPtrTy);
    } else {
      // Ranked case: read the allocated pointer field from the descriptor.
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(builder&: rewriter, loc: op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, args&: freeFunc.value(),
                                              args&: allocatedPtr);
    return success();
  }
};
505
// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(Val: operandType)) {
      // Unranked source: the size has to be read out of the underlying ranked
      // descriptor at runtime; this can fail on unconvertible memory spaces.
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor: adaptor.getOperands(), rewriter);
      if (failed(Result: extractedSize))
        return failure();
      rewriter.replaceOp(op: dimOp, newValues: {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(Val: operandType)) {
      rewriter.replaceOp(
          op: dimOp, newValues: {extractSizeOfRankedMemRef(operandType, dimOp,
                                              adaptor: adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  /// Loads the requested dimension size from the ranked descriptor that the
  /// unranked memref points to, addressing it via GEPs into the descriptor's
  /// sizes array.
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(Val&: operandType);
    auto scalarMemRefType =
        MemRefType::get(shape: {}, elementType: unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type: unrankedMemRefType);
    if (failed(Result: maybeAddressSpace)) {
      dimOp.emitOpError(message: "memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(builder&: rewriter, loc);

    Type elementType = typeConverter->convertType(t: scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        location: loc, args&: indexPtrTy, args&: elementType, args&: underlyingRankedDesc,
        args: ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPop with
    // `dimOp.index() + 1` index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        location: loc, args: createIndexAttrConstant(builder&: rewriter, loc, resultType: getIndexType(), value: 1),
        args: adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        location: loc, args&: indexPtrTy, args: getTypeConverter()->getIndexType(), args&: offsetPtr,
        args&: idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(location: loc, args: getTypeConverter()->getIndexType(), args&: sizePtr)
        .getResult();
  }

  /// Returns the dimension index as a compile-time constant when it can be
  /// determined, either from the op itself or from a defining LLVM constant.
  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(Val: constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  /// Emits the size of the requested dimension of a ranked memref: a constant
  /// for static in-range dimensions, otherwise a read from the descriptor.
  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if index is constant.
    MemRefType memRefType = cast<MemRefType>(Val&: operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(idx: i)) {
          // extract dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(builder&: rewriter, loc, pos: i);
        }
        // Use constant for static size.
        int64_t dimSize = memRefType.getDimSize(idx: i);
        return createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: dimSize);
      }
    }
    // Non-constant (or out-of-range constant) index: read the size from the
    // descriptor at runtime.
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(builder&: rewriter, loc, pos: index, rank);
  }
};
616
/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  // Convenience alias so derived patterns can inherit constructors with
  // `using Base::Base;`.
  using Base = LoadStoreOpLowering<Derived>;
};
626
627/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
628/// retried until it succeeds in atomically storing a new value into memory.
629///
630/// +---------------------------------+
631/// | <code before the AtomicRMWOp> |
632/// | <compute initial %loaded> |
633/// | cf.br loop(%loaded) |
634/// +---------------------------------+
635/// |
636/// -------| |
637/// | v v
638/// | +--------------------------------+
639/// | | loop(%loaded): |
640/// | | <body contents> |
641/// | | %pair = cmpxchg |
642/// | | %ok = %pair[0] |
643/// | | %new = %pair[1] |
644/// | | cf.cond_br %ok, end, loop(%new) |
645/// | +--------------------------------+
646/// | | |
647/// |----------- |
648/// v
649/// +--------------------------------+
650/// | end: |
651/// | <code after the AtomicRMWOp> |
652/// +--------------------------------+
653///
654struct GenericAtomicRMWOpLowering
655 : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
656 using Base::Base;
657
658 LogicalResult
659 matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
660 ConversionPatternRewriter &rewriter) const override {
661 auto loc = atomicOp.getLoc();
662 Type valueType = typeConverter->convertType(t: atomicOp.getResult().getType());
663
664 // Split the block into initial, loop, and ending parts.
665 auto *initBlock = rewriter.getInsertionBlock();
666 auto *loopBlock = rewriter.splitBlock(block: initBlock, before: Block::iterator(atomicOp));
667 loopBlock->addArgument(type: valueType, loc);
668
669 auto *endBlock =
670 rewriter.splitBlock(block: loopBlock, before: Block::iterator(atomicOp)++);
671
672 // Compute the loaded value and branch to the loop block.
673 rewriter.setInsertionPointToEnd(initBlock);
674 auto memRefType = cast<MemRefType>(Val: atomicOp.getMemref().getType());
675 auto dataPtr = getStridedElementPtr(
676 rewriter, loc, type: memRefType, memRefDesc: adaptor.getMemref(), indices: adaptor.getIndices());
677 Value init = rewriter.create<LLVM::LoadOp>(
678 location: loc, args: typeConverter->convertType(t: memRefType.getElementType()), args&: dataPtr);
679 rewriter.create<LLVM::BrOp>(location: loc, args&: init, args&: loopBlock);
680
681 // Prepare the body of the loop block.
682 rewriter.setInsertionPointToStart(loopBlock);
683
684 // Clone the GenericAtomicRMWOp region and extract the result.
685 auto loopArgument = loopBlock->getArgument(i: 0);
686 IRMapping mapping;
687 mapping.map(from: atomicOp.getCurrentValue(), to: loopArgument);
688 Block &entryBlock = atomicOp.body().front();
689 for (auto &nestedOp : entryBlock.without_terminator()) {
690 Operation *clone = rewriter.clone(op&: nestedOp, mapper&: mapping);
691 mapping.map(from: nestedOp.getResults(), to: clone->getResults());
692 }
693 Value result = mapping.lookup(from: entryBlock.getTerminator()->getOperand(idx: 0));
694
695 // Prepare the epilog of the loop block.
696 // Append the cmpxchg op to the end of the loop block.
697 auto successOrdering = LLVM::AtomicOrdering::acq_rel;
698 auto failureOrdering = LLVM::AtomicOrdering::monotonic;
699 auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
700 location: loc, args&: dataPtr, args&: loopArgument, args&: result, args&: successOrdering, args&: failureOrdering);
701 // Extract the %new_loaded and %ok values from the pair.
702 Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: cmpxchg, args: 0);
703 Value ok = rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: cmpxchg, args: 1);
704
705 // Conditionally branch to the end or back to the loop depending on %ok.
706 rewriter.create<LLVM::CondBrOp>(location: loc, args&: ok, args&: endBlock, args: ArrayRef<Value>(),
707 args&: loopBlock, args&: newLoaded);
708
709 rewriter.setInsertionPointToEnd(endBlock);
710
711 // The 'result' of the atomic_rmw op is the newly loaded value.
712 rewriter.replaceOp(op: atomicOp, newValues: {newLoaded});
713
714 return success();
715 }
716};
717
718/// Returns the LLVM type of the global variable given the memref type `type`.
719static Type
720convertGlobalMemrefTypeToLLVM(MemRefType type,
721 const LLVMTypeConverter &typeConverter) {
722 // LLVM type for a global memref will be a multi-dimension array. For
723 // declarations or uninitialized global memrefs, we can potentially flatten
724 // this to a 1D array. However, for memref.global's with an initial value,
725 // we do not intend to flatten the ElementsAttribute when going from std ->
726 // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
727 Type elementType = typeConverter.convertType(t: type.getElementType());
728 Type arrayTy = elementType;
729 // Shape has the outermost dim at index 0, so need to walk it backwards
730 for (int64_t dim : llvm::reverse(C: type.getShape()))
731 arrayTy = LLVM::LLVMArrayType::get(elementType: arrayTy, numElements: dim);
732 return arrayTy;
733}
734
/// GlobalMemrefOp is lowered to a LLVM Global Variable.
class GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  // Optional symbol-table cache; when present, the pattern keeps it in sync
  // by removing the old op and inserting the replacement.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit GlobalMemrefOpLowering(const LLVMTypeConverter &typeConverter,
                                  SymbolTableCollection *symbolTables = nullptr,
                                  PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::GlobalOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  /// Replaces a memref.global with an llvm.mlir.global of the corresponding
  /// (possibly nested) LLVM array type, carrying over linkage, constness,
  /// alignment, address space and the initial value (if any).
  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    // Only memrefs with an identity layout can be backed by a plain array.
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    // Public memref globals become externally visible LLVM globals; all
    // others get private linkage.
    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
    bool isExternal = global.isExternal();
    bool isUninitialized = global.isUninitialized();

    // Only globals that are neither external declarations nor explicitly
    // uninitialized carry an initializer attribute.
    Attribute initialValue = nullptr;
    if (!isExternal && !isUninitialized) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    // 0 means "no alignment requested".
    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");

    // Remove old operation from symbol table.
    SymbolTable *symbolTable = nullptr;
    if (symbolTables) {
      Operation *symbolTableOp =
          global->getParentWithTrait<OpTrait::SymbolTable>();
      symbolTable = &symbolTables->getSymbolTable(symbolTableOp);
      symbolTable->remove(global);
    }

    // Create new operation.
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);

    // Insert new operation into symbol table.
    if (symbolTable)
      symbolTable->insert(newGlobal, rewriter.getInsertionPoint());

    // An uninitialized (but defined) global still needs an initializer
    // region in the LLVM dialect; fill it with an undef of the array type.
    if (!isExternal && isUninitialized) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(newGlobal.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(newGlobal.getLoc(), undef);
    }
    return success();
  }
};
805
806/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
807/// the first element stashed into the descriptor. This reuses
808/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
809struct GetGlobalMemrefOpLowering
810 : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
811 using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;
812
813 /// Buffer "allocation" for memref.get_global op is getting the address of
814 /// the global variable referenced.
815 LogicalResult
816 matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter) const override {
818 auto loc = op.getLoc();
819 MemRefType memRefType = op.getType();
820 if (!isConvertibleAndHasIdentityMaps(type: memRefType))
821 return rewriter.notifyMatchFailure(arg&: op, msg: "incompatible memref type");
822
823 // Get actual sizes of the memref as values: static sizes are constant
824 // values and dynamic sizes are passed to 'alloc' as operands. In case of
825 // zero-dimensional memref, assume a scalar (size 1).
826 SmallVector<Value, 4> sizes;
827 SmallVector<Value, 4> strides;
828 Value sizeBytes;
829
830 this->getMemRefDescriptorSizes(loc, memRefType, dynamicSizes: adaptor.getOperands(),
831 rewriter, sizes, strides, size&: sizeBytes, sizeInBytes: !false);
832
833 MemRefType type = cast<MemRefType>(Val: op.getResult().getType());
834
835 // This is called after a type conversion, which would have failed if this
836 // call fails.
837 FailureOr<unsigned> maybeAddressSpace =
838 getTypeConverter()->getMemRefAddressSpace(type);
839 assert(succeeded(maybeAddressSpace) && "unsupported address space");
840 unsigned memSpace = *maybeAddressSpace;
841
842 Type arrayTy = convertGlobalMemrefTypeToLLVM(type, typeConverter: *getTypeConverter());
843 auto ptrTy = LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: memSpace);
844 auto addressOf =
845 rewriter.create<LLVM::AddressOfOp>(location: loc, args&: ptrTy, args: op.getName());
846
847 // Get the address of the first element in the array by creating a GEP with
848 // the address of the GV as the base, and (rank + 1) number of 0 indices.
849 auto gep = rewriter.create<LLVM::GEPOp>(
850 location: loc, args&: ptrTy, args&: arrayTy, args&: addressOf,
851 args: SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));
852
853 // We do not expect the memref obtained using `memref.get_global` to be
854 // ever deallocated. Set the allocated pointer to be known bad value to
855 // help debug if that ever happens.
856 auto intPtrType = getIntPtrType(addressSpace: memSpace);
857 Value deadBeefConst =
858 createIndexAttrConstant(builder&: rewriter, loc: op->getLoc(), resultType: intPtrType, value: 0xdeadbeef);
859 auto deadBeefPtr =
860 rewriter.create<LLVM::IntToPtrOp>(location: loc, args&: ptrTy, args&: deadBeefConst);
861
862 // Both allocated and aligned pointers are same. We could potentially stash
863 // a nullptr for the allocated pointer since we do not expect any dealloc.
864 // Create the MemRef descriptor.
865 auto memRefDescriptor = this->createMemRefDescriptor(
866 loc, memRefType, allocatedPtr: deadBeefPtr, alignedPtr: gep, sizes, strides, rewriter);
867
868 // Return the final value of the descriptor.
869 rewriter.replaceOp(op, newValues: {memRefDescriptor});
870 return success();
871 }
872};
873
874// Load operation is lowered to obtaining a pointer to the indexed element
875// and loading it.
876struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
877 using Base::Base;
878
879 LogicalResult
880 matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
881 ConversionPatternRewriter &rewriter) const override {
882 auto type = loadOp.getMemRefType();
883
884 // Per memref.load spec, the indices must be in-bounds:
885 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
886 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
887 Value dataPtr = getStridedElementPtr(rewriter, loc: loadOp.getLoc(), type,
888 memRefDesc: adaptor.getMemref(),
889 indices: adaptor.getIndices(), noWrapFlags: kNoWrapFlags);
890 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
891 op: loadOp, args: typeConverter->convertType(t: type.getElementType()), args&: dataPtr, args: 0,
892 args: false, args: loadOp.getNontemporal());
893 return success();
894 }
895};
896
897// Store operation is lowered to obtaining a pointer to the indexed element,
898// and storing the given value to it.
899struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
900 using Base::Base;
901
902 LogicalResult
903 matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
904 ConversionPatternRewriter &rewriter) const override {
905 auto type = op.getMemRefType();
906
907 // Per memref.store spec, the indices must be in-bounds:
908 // 0 <= idx < dim_size, and additionally all offsets are non-negative,
909 // hence inbounds and nuw are used when lowering to llvm.getelementptr.
910 Value dataPtr =
911 getStridedElementPtr(rewriter, loc: op.getLoc(), type, memRefDesc: adaptor.getMemref(),
912 indices: adaptor.getIndices(), noWrapFlags: kNoWrapFlags);
913 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, args: adaptor.getValue(), args&: dataPtr,
914 args: 0, args: false, args: op.getNontemporal());
915 return success();
916 }
917};
918
919// The prefetch operation is lowered in a way similar to the load operation
920// except that the llvm.prefetch operation is used for replacement.
921struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
922 using Base::Base;
923
924 LogicalResult
925 matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
926 ConversionPatternRewriter &rewriter) const override {
927 auto type = prefetchOp.getMemRefType();
928 auto loc = prefetchOp.getLoc();
929
930 Value dataPtr = getStridedElementPtr(
931 rewriter, loc, type, memRefDesc: adaptor.getMemref(), indices: adaptor.getIndices());
932
933 // Replace with llvm.prefetch.
934 IntegerAttr isWrite = rewriter.getI32IntegerAttr(value: prefetchOp.getIsWrite());
935 IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
936 IntegerAttr isData =
937 rewriter.getI32IntegerAttr(value: prefetchOp.getIsDataCache());
938 rewriter.replaceOpWithNewOp<LLVM::Prefetch>(op: prefetchOp, args&: dataPtr, args&: isWrite,
939 args&: localityHint, args&: isData);
940 return success();
941 }
942};
943
944struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
945 using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;
946
947 LogicalResult
948 matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
949 ConversionPatternRewriter &rewriter) const override {
950 Location loc = op.getLoc();
951 Type operandType = op.getMemref().getType();
952 if (isa<UnrankedMemRefType>(Val: operandType)) {
953 UnrankedMemRefDescriptor desc(adaptor.getMemref());
954 rewriter.replaceOp(op, newValues: {desc.rank(builder&: rewriter, loc)});
955 return success();
956 }
957 if (auto rankedMemRefType = dyn_cast<MemRefType>(Val&: operandType)) {
958 Type indexType = getIndexType();
959 rewriter.replaceOp(op,
960 newValues: {createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType,
961 value: rankedMemRefType.getRank())});
962 return success();
963 }
964 return failure();
965 }
966};
967
/// Lowers memref.cast. Ranked-to-ranked casts are a no-op on the descriptor;
/// ranked-to-unranked wraps the descriptor behind a stack-allocated pointer;
/// unranked-to-ranked loads the descriptor back from that pointer.
struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduce to bitcast in the ranked MemRef case and can be
    // used for type erasure. For now they must preserve underlying element type
    // and require source and result type to have the same rank. Therefore,
    // perform a sanity check that the underlying structs are the same. Once op
    // semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked cast is disallowed
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For ranked/ranked case, just keep the original descriptor.
    // (The struct-equality check above already established the converted
    // descriptor types match.)
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type
      // Set the rank in the destination from the memref type
      // Allocate space on the stack and copy the src memref descriptor
      // Set the ptr in the destination to the stack space
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the destination
      // type mismatches the unranked the type, it is undefined behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      // Unreachable: the unranked/unranked combination was rejected above,
      // and all other combinations are handled by the branches before this.
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};
1041
/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
class MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  // Optional symbol-table cache, forwarded to the runtime-function lookup so
  // the helper function is created/found through the cached tables.
  SymbolTableCollection *symbolTables = nullptr;

public:
  explicit MemRefCopyOpLowering(const LLVMTypeConverter &typeConverter,
                                SymbolTableCollection *symbolTables = nullptr,
                                PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<memref::CopyOp>(typeConverter, benefit),
        symbolTables(symbolTables) {}

  /// Fast path: both sides are contiguous, so the copy is a single
  /// llvm.memcpy of (number of elements) * (element size) bytes.
  /// Only called from matchAndRewrite once contiguity has been established,
  /// so the source is known to be a ranked MemRefType here.
  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    // Address the first live element on each side: aligned pointer advanced
    // by the descriptor's offset.
    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  /// Generic path: wrap both operands as unranked descriptors, promote them
  /// to the stack, and call the `memrefCopy` runtime function.
  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors; restored after the
    // call so the allocas do not accumulate.
    auto stackSaveOp = rewriter.create<LLVM::StackSaveOp>(loc, getPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType(), symbolTables);
    if (failed(copyFn))
      return failure();
    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  /// Dispatches between the memcpy fast path and the runtime-call fallback
  /// based on whether both operands are contiguous row-major memrefs.
  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
1186
/// Lowers memref.memory_space_cast. For ranked memrefs, the two pointers in
/// the descriptor are addrspacecast; for unranked memrefs a fresh underlying
/// descriptor is allocated, the pointers are cast, and the remaining
/// index-valued fields (offset, sizes, strides) are memcpy'd over.
struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      // Field 0 of the converted descriptor struct is the allocated pointer;
      // its type carries the target address space.
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      // Cast the allocated (field 0) and aligned (field 1) pointers; the
      // offset/sizes/strides fields are address-space independent.
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      // Skip the two leading pointer fields of the underlying descriptor;
      // everything after them is index-typed and copied verbatim.
      int64_t bytesToSkip =
          2 * llvm::divideCeil(
                  getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};
1293
1294/// Extracts allocated, aligned pointers and offset from a ranked or unranked
1295/// memref type. In unranked case, the fields are extracted from the underlying
1296/// ranked descriptor.
1297static void extractPointersAndOffset(Location loc,
1298 ConversionPatternRewriter &rewriter,
1299 const LLVMTypeConverter &typeConverter,
1300 Value originalOperand,
1301 Value convertedOperand,
1302 Value *allocatedPtr, Value *alignedPtr,
1303 Value *offset = nullptr) {
1304 Type operandType = originalOperand.getType();
1305 if (isa<MemRefType>(Val: operandType)) {
1306 MemRefDescriptor desc(convertedOperand);
1307 *allocatedPtr = desc.allocatedPtr(builder&: rewriter, loc);
1308 *alignedPtr = desc.alignedPtr(builder&: rewriter, loc);
1309 if (offset != nullptr)
1310 *offset = desc.offset(builder&: rewriter, loc);
1311 return;
1312 }
1313
1314 // These will all cause assert()s on unconvertible types.
1315 unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1316 type: cast<UnrankedMemRefType>(Val&: operandType));
1317 auto elementPtrType =
1318 LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace: memorySpace);
1319
1320 // Extract pointer to the underlying ranked memref descriptor and cast it to
1321 // ElemType**.
1322 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1323 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(builder&: rewriter, loc);
1324
1325 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1326 builder&: rewriter, loc, memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1327 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1328 builder&: rewriter, loc, typeConverter, memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1329 if (offset != nullptr) {
1330 *offset = UnrankedMemRefDescriptor::offset(
1331 builder&: rewriter, loc, typeConverter, memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1332 }
1333}
1334
1335struct MemRefReinterpretCastOpLowering
1336 : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
1337 using ConvertOpToLLVMPattern<
1338 memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;
1339
1340 LogicalResult
1341 matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
1342 ConversionPatternRewriter &rewriter) const override {
1343 Type srcType = castOp.getSource().getType();
1344
1345 Value descriptor;
1346 if (failed(Result: convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
1347 adaptor, descriptor: &descriptor)))
1348 return failure();
1349 rewriter.replaceOp(op: castOp, newValues: {descriptor});
1350 return success();
1351 }
1352
1353private:
1354 LogicalResult convertSourceMemRefToDescriptor(
1355 ConversionPatternRewriter &rewriter, Type srcType,
1356 memref::ReinterpretCastOp castOp,
1357 memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
1358 MemRefType targetMemRefType =
1359 cast<MemRefType>(Val: castOp.getResult().getType());
1360 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1361 Val: typeConverter->convertType(t: targetMemRefType));
1362 if (!llvmTargetDescriptorTy)
1363 return failure();
1364
1365 // Create descriptor.
1366 Location loc = castOp.getLoc();
1367 auto desc = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: llvmTargetDescriptorTy);
1368
1369 // Set allocated and aligned pointers.
1370 Value allocatedPtr, alignedPtr;
1371 extractPointersAndOffset(loc, rewriter, typeConverter: *getTypeConverter(),
1372 originalOperand: castOp.getSource(), convertedOperand: adaptor.getSource(),
1373 allocatedPtr: &allocatedPtr, alignedPtr: &alignedPtr);
1374 desc.setAllocatedPtr(builder&: rewriter, loc, ptr: allocatedPtr);
1375 desc.setAlignedPtr(builder&: rewriter, loc, ptr: alignedPtr);
1376
1377 // Set offset.
1378 if (castOp.isDynamicOffset(idx: 0))
1379 desc.setOffset(builder&: rewriter, loc, offset: adaptor.getOffsets()[0]);
1380 else
1381 desc.setConstantOffset(builder&: rewriter, loc, offset: castOp.getStaticOffset(idx: 0));
1382
1383 // Set sizes and strides.
1384 unsigned dynSizeId = 0;
1385 unsigned dynStrideId = 0;
1386 for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
1387 if (castOp.isDynamicSize(idx: i))
1388 desc.setSize(builder&: rewriter, loc, pos: i, size: adaptor.getSizes()[dynSizeId++]);
1389 else
1390 desc.setConstantSize(builder&: rewriter, loc, pos: i, size: castOp.getStaticSize(idx: i));
1391
1392 if (castOp.isDynamicStride(idx: i))
1393 desc.setStride(builder&: rewriter, loc, pos: i, stride: adaptor.getStrides()[dynStrideId++]);
1394 else
1395 desc.setConstantStride(builder&: rewriter, loc, pos: i, stride: castOp.getStaticStride(idx: i));
1396 }
1397 *descriptor = desc;
1398 return success();
1399 }
1400};
1401
1402struct MemRefReshapeOpLowering
1403 : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
1404 using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;
1405
1406 LogicalResult
1407 matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
1408 ConversionPatternRewriter &rewriter) const override {
1409 Type srcType = reshapeOp.getSource().getType();
1410
1411 Value descriptor;
1412 if (failed(Result: convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
1413 adaptor, descriptor: &descriptor)))
1414 return failure();
1415 rewriter.replaceOp(op: reshapeOp, newValues: {descriptor});
1416 return success();
1417 }
1418
1419private:
1420 LogicalResult
1421 convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
1422 Type srcType, memref::ReshapeOp reshapeOp,
1423 memref::ReshapeOp::Adaptor adaptor,
1424 Value *descriptor) const {
1425 auto shapeMemRefType = cast<MemRefType>(Val: reshapeOp.getShape().getType());
1426 if (shapeMemRefType.hasStaticShape()) {
1427 MemRefType targetMemRefType =
1428 cast<MemRefType>(Val: reshapeOp.getResult().getType());
1429 auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
1430 Val: typeConverter->convertType(t: targetMemRefType));
1431 if (!llvmTargetDescriptorTy)
1432 return failure();
1433
1434 // Create descriptor.
1435 Location loc = reshapeOp.getLoc();
1436 auto desc =
1437 MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: llvmTargetDescriptorTy);
1438
1439 // Set allocated and aligned pointers.
1440 Value allocatedPtr, alignedPtr;
1441 extractPointersAndOffset(loc, rewriter, typeConverter: *getTypeConverter(),
1442 originalOperand: reshapeOp.getSource(), convertedOperand: adaptor.getSource(),
1443 allocatedPtr: &allocatedPtr, alignedPtr: &alignedPtr);
1444 desc.setAllocatedPtr(builder&: rewriter, loc, ptr: allocatedPtr);
1445 desc.setAlignedPtr(builder&: rewriter, loc, ptr: alignedPtr);
1446
1447 // Extract the offset and strides from the type.
1448 int64_t offset;
1449 SmallVector<int64_t> strides;
1450 if (failed(Result: targetMemRefType.getStridesAndOffset(strides, offset)))
1451 return rewriter.notifyMatchFailure(
1452 arg&: reshapeOp, msg: "failed to get stride and offset exprs");
1453
1454 if (!isStaticStrideOrOffset(strideOrOffset: offset))
1455 return rewriter.notifyMatchFailure(arg&: reshapeOp,
1456 msg: "dynamic offset is unsupported");
1457
1458 desc.setConstantOffset(builder&: rewriter, loc, offset);
1459
1460 assert(targetMemRefType.getLayout().isIdentity() &&
1461 "Identity layout map is a precondition of a valid reshape op");
1462
1463 Type indexType = getIndexType();
1464 Value stride = nullptr;
1465 int64_t targetRank = targetMemRefType.getRank();
1466 for (auto i : llvm::reverse(C: llvm::seq<int64_t>(Begin: 0, End: targetRank))) {
1467 if (ShapedType::isStatic(dValue: strides[i])) {
1468 // If the stride for this dimension is dynamic, then use the product
1469 // of the sizes of the inner dimensions.
1470 stride =
1471 createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: strides[i]);
1472 } else if (!stride) {
1473 // `stride` is null only in the first iteration of the loop. However,
1474 // since the target memref has an identity layout, we can safely set
1475 // the innermost stride to 1.
1476 stride = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1);
1477 }
1478
1479 Value dimSize;
1480 // If the size of this dimension is dynamic, then load it at runtime
1481 // from the shape operand.
1482 if (!targetMemRefType.isDynamicDim(idx: i)) {
1483 dimSize = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType,
1484 value: targetMemRefType.getDimSize(idx: i));
1485 } else {
1486 Value shapeOp = reshapeOp.getShape();
1487 Value index = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: i);
1488 dimSize = rewriter.create<memref::LoadOp>(location: loc, args&: shapeOp, args&: index);
1489 Type indexType = getIndexType();
1490 if (dimSize.getType() != indexType)
1491 dimSize = typeConverter->materializeTargetConversion(
1492 builder&: rewriter, loc, resultType: indexType, inputs: dimSize);
1493 assert(dimSize && "Invalid memref element type");
1494 }
1495
1496 desc.setSize(builder&: rewriter, loc, pos: i, size: dimSize);
1497 desc.setStride(builder&: rewriter, loc, pos: i, stride);
1498
1499 // Prepare the stride value for the next dimension.
1500 stride = rewriter.create<LLVM::MulOp>(location: loc, args&: stride, args&: dimSize);
1501 }
1502
1503 *descriptor = desc;
1504 return success();
1505 }
1506
1507 // The shape is a rank-1 tensor with unknown length.
1508 Location loc = reshapeOp.getLoc();
1509 MemRefDescriptor shapeDesc(adaptor.getShape());
1510 Value resultRank = shapeDesc.size(builder&: rewriter, loc, pos: 0);
1511
1512 // Extract address space and element type.
1513 auto targetType = cast<UnrankedMemRefType>(Val: reshapeOp.getResult().getType());
1514 unsigned addressSpace =
1515 *getTypeConverter()->getMemRefAddressSpace(type: targetType);
1516
1517 // Create the unranked memref descriptor that holds the ranked one. The
1518 // inner descriptor is allocated on stack.
1519 auto targetDesc = UnrankedMemRefDescriptor::poison(
1520 builder&: rewriter, loc, descriptorType: typeConverter->convertType(t: targetType));
1521 targetDesc.setRank(builder&: rewriter, loc, value: resultRank);
1522 SmallVector<Value, 4> sizes;
1523 UnrankedMemRefDescriptor::computeSizes(builder&: rewriter, loc, typeConverter: *getTypeConverter(),
1524 values: targetDesc, addressSpaces: addressSpace, sizes);
1525 Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
1526 location: loc, args: getPtrType(), args: IntegerType::get(context: getContext(), width: 8), args&: sizes.front());
1527 targetDesc.setMemRefDescPtr(builder&: rewriter, loc, value: underlyingDescPtr);
1528
1529 // Extract pointers and offset from the source memref.
1530 Value allocatedPtr, alignedPtr, offset;
1531 extractPointersAndOffset(loc, rewriter, typeConverter: *getTypeConverter(),
1532 originalOperand: reshapeOp.getSource(), convertedOperand: adaptor.getSource(),
1533 allocatedPtr: &allocatedPtr, alignedPtr: &alignedPtr, offset: &offset);
1534
1535 // Set pointers and offset.
1536 auto elementPtrType =
1537 LLVM::LLVMPointerType::get(context: rewriter.getContext(), addressSpace);
1538
1539 UnrankedMemRefDescriptor::setAllocatedPtr(builder&: rewriter, loc, memRefDescPtr: underlyingDescPtr,
1540 elemPtrType: elementPtrType, allocatedPtr);
1541 UnrankedMemRefDescriptor::setAlignedPtr(builder&: rewriter, loc, typeConverter: *getTypeConverter(),
1542 memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType,
1543 alignedPtr);
1544 UnrankedMemRefDescriptor::setOffset(builder&: rewriter, loc, typeConverter: *getTypeConverter(),
1545 memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType,
1546 offset);
1547
1548 // Use the offset pointer as base for further addressing. Copy over the new
1549 // shape and compute strides. For this, we create a loop from rank-1 to 0.
1550 Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
1551 builder&: rewriter, loc, typeConverter: *getTypeConverter(), memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1552 Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
1553 builder&: rewriter, loc, typeConverter: *getTypeConverter(), sizeBasePtr: targetSizesBase, rank: resultRank);
1554 Value shapeOperandPtr = shapeDesc.alignedPtr(builder&: rewriter, loc);
1555 Value oneIndex = createIndexAttrConstant(builder&: rewriter, loc, resultType: getIndexType(), value: 1);
1556 Value resultRankMinusOne =
1557 rewriter.create<LLVM::SubOp>(location: loc, args&: resultRank, args&: oneIndex);
1558
1559 Block *initBlock = rewriter.getInsertionBlock();
1560 Type indexType = getTypeConverter()->getIndexType();
1561 Block::iterator remainingOpsIt = std::next(x: rewriter.getInsertionPoint());
1562
1563 Block *condBlock = rewriter.createBlock(parent: initBlock->getParent(), insertPt: {},
1564 argTypes: {indexType, indexType}, locs: {loc, loc});
1565
1566 // Move the remaining initBlock ops to condBlock.
1567 Block *remainingBlock = rewriter.splitBlock(block: initBlock, before: remainingOpsIt);
1568 rewriter.mergeBlocks(source: remainingBlock, dest: condBlock, argValues: ValueRange());
1569
1570 rewriter.setInsertionPointToEnd(initBlock);
1571 rewriter.create<LLVM::BrOp>(location: loc, args: ValueRange({resultRankMinusOne, oneIndex}),
1572 args&: condBlock);
1573 rewriter.setInsertionPointToStart(condBlock);
1574 Value indexArg = condBlock->getArgument(i: 0);
1575 Value strideArg = condBlock->getArgument(i: 1);
1576
1577 Value zeroIndex = createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 0);
1578 Value pred = rewriter.create<LLVM::ICmpOp>(
1579 location: loc, args: IntegerType::get(context: rewriter.getContext(), width: 1),
1580 args: LLVM::ICmpPredicate::sge, args&: indexArg, args&: zeroIndex);
1581
1582 Block *bodyBlock =
1583 rewriter.splitBlock(block: condBlock, before: rewriter.getInsertionPoint());
1584 rewriter.setInsertionPointToStart(bodyBlock);
1585
1586 // Copy size from shape to descriptor.
1587 auto llvmIndexPtrType = LLVM::LLVMPointerType::get(context: rewriter.getContext());
1588 Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
1589 location: loc, args&: llvmIndexPtrType,
1590 args: typeConverter->convertType(t: shapeMemRefType.getElementType()),
1591 args&: shapeOperandPtr, args&: indexArg);
1592 Value size = rewriter.create<LLVM::LoadOp>(location: loc, args&: indexType, args&: sizeLoadGep);
1593 UnrankedMemRefDescriptor::setSize(builder&: rewriter, loc, typeConverter: *getTypeConverter(),
1594 sizeBasePtr: targetSizesBase, index: indexArg, size);
1595
1596 // Write stride value and compute next one.
1597 UnrankedMemRefDescriptor::setStride(builder&: rewriter, loc, typeConverter: *getTypeConverter(),
1598 strideBasePtr: targetStridesBase, index: indexArg, stride: strideArg);
1599 Value nextStride = rewriter.create<LLVM::MulOp>(location: loc, args&: strideArg, args&: size);
1600
1601 // Decrement loop counter and branch back.
1602 Value decrement = rewriter.create<LLVM::SubOp>(location: loc, args&: indexArg, args&: oneIndex);
1603 rewriter.create<LLVM::BrOp>(location: loc, args: ValueRange({decrement, nextStride}),
1604 args&: condBlock);
1605
1606 Block *remainder =
1607 rewriter.splitBlock(block: bodyBlock, before: rewriter.getInsertionPoint());
1608
1609 // Hook up the cond exit to the remainder.
1610 rewriter.setInsertionPointToEnd(condBlock);
1611 rewriter.create<LLVM::CondBrOp>(location: loc, args&: pred, args&: bodyBlock, args: ValueRange(),
1612 args&: remainder, args: ValueRange());
1613
1614 // Reset position to beginning of new remainder block.
1615 rewriter.setInsertionPointToStart(remainder);
1616
1617 *descriptor = targetDesc;
1618 return success();
1619 }
1620};
1621
1622/// RessociatingReshapeOp must be expanded before we reach this stage.
1623/// Report that information.
1624template <typename ReshapeOp>
1625class ReassociatingReshapeOpConversion
1626 : public ConvertOpToLLVMPattern<ReshapeOp> {
1627public:
1628 using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
1629 using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;
1630
1631 LogicalResult
1632 matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
1633 ConversionPatternRewriter &rewriter) const override {
1634 return rewriter.notifyMatchFailure(
1635 reshapeOp,
1636 "reassociation operations should have been expanded beforehand");
1637 }
1638};
1639
1640/// Subviews must be expanded before we reach this stage.
1641/// Report that information.
1642struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
1643 using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;
1644
1645 LogicalResult
1646 matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
1647 ConversionPatternRewriter &rewriter) const override {
1648 return rewriter.notifyMatchFailure(
1649 arg&: subViewOp, msg: "subview operations should have been expanded beforehand");
1650 }
1651};
1652
1653/// Conversion pattern that transforms a transpose op into:
1654/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
1655/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
1656/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
1657/// and stride. Size and stride are permutations of the original values.
1658/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
1659/// The transpose op is replaced by the alloca'ed pointer.
1660class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
1661public:
1662 using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;
1663
1664 LogicalResult
1665 matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
1666 ConversionPatternRewriter &rewriter) const override {
1667 auto loc = transposeOp.getLoc();
1668 MemRefDescriptor viewMemRef(adaptor.getIn());
1669
1670 // No permutation, early exit.
1671 if (transposeOp.getPermutation().isIdentity())
1672 return rewriter.replaceOp(op: transposeOp, newValues: {viewMemRef}), success();
1673
1674 auto targetMemRef = MemRefDescriptor::poison(
1675 builder&: rewriter, loc,
1676 descriptorType: typeConverter->convertType(t: transposeOp.getIn().getType()));
1677
1678 // Copy the base and aligned pointers from the old descriptor to the new
1679 // one.
1680 targetMemRef.setAllocatedPtr(builder&: rewriter, loc,
1681 ptr: viewMemRef.allocatedPtr(builder&: rewriter, loc));
1682 targetMemRef.setAlignedPtr(builder&: rewriter, loc,
1683 ptr: viewMemRef.alignedPtr(builder&: rewriter, loc));
1684
1685 // Copy the offset pointer from the old descriptor to the new one.
1686 targetMemRef.setOffset(builder&: rewriter, loc, offset: viewMemRef.offset(builder&: rewriter, loc));
1687
1688 // Iterate over the dimensions and apply size/stride permutation:
1689 // When enumerating the results of the permutation map, the enumeration
1690 // index is the index into the target dimensions and the DimExpr points to
1691 // the dimension of the source memref.
1692 for (const auto &en :
1693 llvm::enumerate(First: transposeOp.getPermutation().getResults())) {
1694 int targetPos = en.index();
1695 int sourcePos = cast<AffineDimExpr>(Val: en.value()).getPosition();
1696 targetMemRef.setSize(builder&: rewriter, loc, pos: targetPos,
1697 size: viewMemRef.size(builder&: rewriter, loc, pos: sourcePos));
1698 targetMemRef.setStride(builder&: rewriter, loc, pos: targetPos,
1699 stride: viewMemRef.stride(builder&: rewriter, loc, pos: sourcePos));
1700 }
1701
1702 rewriter.replaceOp(op: transposeOp, newValues: {targetMemRef});
1703 return success();
1704 }
1705};
1706
/// Conversion pattern that transforms a memref.view op into:
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
///    and stride.
/// The view op is replaced by the descriptor.
///
/// Only contiguous (innermost-stride-1 or empty) targets with offset 0 are
/// supported; other shapes produce a warning and a match failure.
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or counting the proper dynamic
  // size. `dynamicSizes` holds only the dynamic dims, in order, so the number
  // of dynamic dims preceding `idx` is the index into it.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (ShapedType::isStatic(dValue: shape[idx]))
      return createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: shape[idx]);
    // Count the number of dynamic dims in range [0, idx]
    unsigned nDynamic =
        llvm::count_if(Range: shape.take_front(N: idx), P: ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant stride
  // or by computing the dynamic stride from the current `runningStride` and
  // `nextSize`. The caller should keep a running stride and update it with the
  // result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (ShapedType::isStatic(dValue: strides[idx]))
      return createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: strides[idx]);
    // Dynamic stride: product of the sizes of all inner dimensions, carried
    // through `runningStride`. `nextSize` is the size of the next-inner dim.
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(location: loc, args&: runningStride, args&: nextSize)
                 : nextSize;
    // Innermost dimension: the stride is 1 (contiguity is checked upfront).
    assert(!runningStride);
    return createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(t: viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(t: viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(type: targetElementTy) ||
        !LLVM::isCompatibleType(type: targetDescTy))
      return viewOp.emitWarning(message: "Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(Result: successStrides))
      return viewOp.emitWarning(message: "cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning(message: "cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::poison(builder&: rewriter, loc, descriptorType: targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(builder&: rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(Val: viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(builder&: rewriter, loc, ptr: allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload, advanced by the
    // byte shift operand (a GEP over the source element type).
    Value alignedPtr = sourceMemRef.alignedPtr(builder&: rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        location: loc, args: alignedPtr.getType(),
        args: typeConverter->convertType(t: srcMemRefType.getElementType()), args&: alignedPtr,
        args: adaptor.getByteShift());

    targetMemRef.setAlignedPtr(builder&: rewriter, loc, ptr: alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is
    // because of the type change: an offset on srcType* may not be
    // expressible as an offset on dstType*.
    targetMemRef.setOffset(
        builder&: rewriter, loc,
        offset: createIndexAttrConstant(builder&: rewriter, loc, resultType: indexType, value: offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(op: viewOp, newValues: {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides, walking innermost-first so the
    // running stride can be accumulated from the inner sizes.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, shape: viewMemRefType.getShape(),
                           dynamicSizes: adaptor.getSizes(), idx: i, indexType);
      targetMemRef.setSize(builder&: rewriter, loc, pos: i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, runningStride: stride, idx: i, indexType);
      targetMemRef.setStride(builder&: rewriter, loc, pos: i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(op: viewOp, newValues: {targetMemRef});
    return success();
  }
};
1823
1824//===----------------------------------------------------------------------===//
1825// AtomicRMWOpLowering
1826//===----------------------------------------------------------------------===//
1827
1828/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
1829/// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg.
1830static std::optional<LLVM::AtomicBinOp>
1831matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
1832 switch (atomicOp.getKind()) {
1833 case arith::AtomicRMWKind::addf:
1834 return LLVM::AtomicBinOp::fadd;
1835 case arith::AtomicRMWKind::addi:
1836 return LLVM::AtomicBinOp::add;
1837 case arith::AtomicRMWKind::assign:
1838 return LLVM::AtomicBinOp::xchg;
1839 case arith::AtomicRMWKind::maximumf:
1840 // TODO: remove this by end of 2025.
1841 LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
1842 "from fmax to fmaximum, expect more NaNs");
1843 return LLVM::AtomicBinOp::fmaximum;
1844 case arith::AtomicRMWKind::maxnumf:
1845 return LLVM::AtomicBinOp::fmax;
1846 case arith::AtomicRMWKind::maxs:
1847 return LLVM::AtomicBinOp::max;
1848 case arith::AtomicRMWKind::maxu:
1849 return LLVM::AtomicBinOp::umax;
1850 case arith::AtomicRMWKind::minimumf:
1851 // TODO: remove this by end of 2025.
1852 LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
1853 "from fmin to fminimum, expect more NaNs");
1854 return LLVM::AtomicBinOp::fminimum;
1855 case arith::AtomicRMWKind::minnumf:
1856 return LLVM::AtomicBinOp::fmin;
1857 case arith::AtomicRMWKind::mins:
1858 return LLVM::AtomicBinOp::min;
1859 case arith::AtomicRMWKind::minu:
1860 return LLVM::AtomicBinOp::umin;
1861 case arith::AtomicRMWKind::ori:
1862 return LLVM::AtomicBinOp::_or;
1863 case arith::AtomicRMWKind::andi:
1864 return LLVM::AtomicBinOp::_and;
1865 default:
1866 return std::nullopt;
1867 }
1868 llvm_unreachable("Invalid AtomicRMWKind");
1869}
1870
1871struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
1872 using Base::Base;
1873
1874 LogicalResult
1875 matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
1876 ConversionPatternRewriter &rewriter) const override {
1877 auto maybeKind = matchSimpleAtomicOp(atomicOp);
1878 if (!maybeKind)
1879 return failure();
1880 auto memRefType = atomicOp.getMemRefType();
1881 SmallVector<int64_t> strides;
1882 int64_t offset;
1883 if (failed(Result: memRefType.getStridesAndOffset(strides, offset)))
1884 return failure();
1885 auto dataPtr =
1886 getStridedElementPtr(rewriter, loc: atomicOp.getLoc(), type: memRefType,
1887 memRefDesc: adaptor.getMemref(), indices: adaptor.getIndices());
1888 rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
1889 op: atomicOp, args&: *maybeKind, args&: dataPtr, args: adaptor.getValue(),
1890 args: LLVM::AtomicOrdering::acq_rel);
1891 return success();
1892 }
1893};
1894
1895/// Unpack the pointer returned by a memref.extract_aligned_pointer_as_index.
1896class ConvertExtractAlignedPointerAsIndex
1897 : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
1898public:
1899 using ConvertOpToLLVMPattern<
1900 memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;
1901
1902 LogicalResult
1903 matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
1904 OpAdaptor adaptor,
1905 ConversionPatternRewriter &rewriter) const override {
1906 BaseMemRefType sourceTy = extractOp.getSource().getType();
1907
1908 Value alignedPtr;
1909 if (sourceTy.hasRank()) {
1910 MemRefDescriptor desc(adaptor.getSource());
1911 alignedPtr = desc.alignedPtr(builder&: rewriter, loc: extractOp->getLoc());
1912 } else {
1913 auto elementPtrTy = LLVM::LLVMPointerType::get(
1914 context: rewriter.getContext(), addressSpace: sourceTy.getMemorySpaceAsInt());
1915
1916 UnrankedMemRefDescriptor desc(adaptor.getSource());
1917 Value descPtr = desc.memRefDescPtr(builder&: rewriter, loc: extractOp->getLoc());
1918
1919 alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1920 builder&: rewriter, loc: extractOp->getLoc(), typeConverter: *getTypeConverter(), memRefDescPtr: descPtr,
1921 elemPtrType: elementPtrTy);
1922 }
1923
1924 rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
1925 op: extractOp, args: getTypeConverter()->getIndexType(), args&: alignedPtr);
1926 return success();
1927 }
1928};
1929
1930/// Materialize the MemRef descriptor represented by the results of
1931/// ExtractStridedMetadataOp.
1932class ExtractStridedMetadataOpLowering
1933 : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
1934public:
1935 using ConvertOpToLLVMPattern<
1936 memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;
1937
1938 LogicalResult
1939 matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
1940 OpAdaptor adaptor,
1941 ConversionPatternRewriter &rewriter) const override {
1942
1943 if (!LLVM::isCompatibleType(type: adaptor.getOperands().front().getType()))
1944 return failure();
1945
1946 // Create the descriptor.
1947 MemRefDescriptor sourceMemRef(adaptor.getSource());
1948 Location loc = extractStridedMetadataOp.getLoc();
1949 Value source = extractStridedMetadataOp.getSource();
1950
1951 auto sourceMemRefType = cast<MemRefType>(Val: source.getType());
1952 int64_t rank = sourceMemRefType.getRank();
1953 SmallVector<Value> results;
1954 results.reserve(N: 2 + rank * 2);
1955
1956 // Base buffer.
1957 Value baseBuffer = sourceMemRef.allocatedPtr(builder&: rewriter, loc);
1958 Value alignedBuffer = sourceMemRef.alignedPtr(builder&: rewriter, loc);
1959 MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
1960 builder&: rewriter, loc, typeConverter: *getTypeConverter(),
1961 type: cast<MemRefType>(Val: extractStridedMetadataOp.getBaseBuffer().getType()),
1962 memory: baseBuffer, alignedMemory: alignedBuffer);
1963 results.push_back(Elt: (Value)dstMemRef);
1964
1965 // Offset.
1966 results.push_back(Elt: sourceMemRef.offset(builder&: rewriter, loc));
1967
1968 // Sizes.
1969 for (unsigned i = 0; i < rank; ++i)
1970 results.push_back(Elt: sourceMemRef.size(builder&: rewriter, loc, pos: i));
1971 // Strides.
1972 for (unsigned i = 0; i < rank; ++i)
1973 results.push_back(Elt: sourceMemRef.stride(builder&: rewriter, loc, pos: i));
1974
1975 rewriter.replaceOp(op: extractStridedMetadataOp, newValues: results);
1976 return success();
1977 }
1978};
1979
1980} // namespace
1981
/// Populate `patterns` with all memref-to-LLVM lowering patterns. Patterns
/// that emit runtime calls (global/copy, alloc/dealloc) receive
/// `symbolTables` so they can look up or insert the callee declarations; the
/// alloc/dealloc patterns registered depend on the converter's configured
/// allocation lowering (aligned_alloc vs. malloc).
void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    SymbolTableCollection *symbolTables) {
  // Patterns that need no symbol-table access.
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(arg: converter);
  // clang-format on
  // Patterns that declare/call runtime functions need the symbol tables.
  patterns.add<GlobalMemrefOpLowering, MemRefCopyOpLowering>(arg: converter,
                                                             args&: symbolTables);
  // Select the alloc/dealloc lowering according to the configured strategy.
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(arg: converter,
                                                            args&: symbolTables);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(arg: converter, args&: symbolTables);
}
2019
2020namespace {
2021struct FinalizeMemRefToLLVMConversionPass
2022 : public impl::FinalizeMemRefToLLVMConversionPassBase<
2023 FinalizeMemRefToLLVMConversionPass> {
2024 using FinalizeMemRefToLLVMConversionPassBase::
2025 FinalizeMemRefToLLVMConversionPassBase;
2026
2027 void runOnOperation() override {
2028 Operation *op = getOperation();
2029 const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
2030 LowerToLLVMOptions options(&getContext(),
2031 dataLayoutAnalysis.getAtOrAbove(operation: op));
2032 options.allocLowering =
2033 (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
2034 : LowerToLLVMOptions::AllocLowering::Malloc);
2035
2036 options.useGenericFunctions = useGenericFunctions;
2037
2038 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
2039 options.overrideIndexBitwidth(bitwidth: indexBitwidth);
2040
2041 LLVMTypeConverter typeConverter(&getContext(), options,
2042 &dataLayoutAnalysis);
2043 RewritePatternSet patterns(&getContext());
2044 SymbolTableCollection symbolTables;
2045 populateFinalizeMemRefToLLVMConversionPatterns(converter: typeConverter, patterns,
2046 symbolTables: &symbolTables);
2047 LLVMConversionTarget target(getContext());
2048 target.addLegalOp<func::FuncOp>();
2049 if (failed(Result: applyPartialConversion(op, target, patterns: std::move(patterns))))
2050 signalPassFailure();
2051 }
2052};
2053
/// Implement the ConvertToLLVM dialect interface so the generic
/// convert-to-llvm infrastructure can lower memref ops using the patterns
/// defined in this file.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  // The produced IR uses the LLVM dialect; make sure it is loaded.
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for derived dialect interface to provide conversion patterns
  /// and mark dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(converter: typeConverter, patterns);
  }
};
2069
2070} // namespace
2071
/// Register the ConvertToLLVM interface on the memref dialect; the extension
/// runs lazily when the dialect is loaded into a context.
void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(extensionFn: +[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}
2077

// source code of mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp