//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/MathExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

static constexpr LLVM::GEPNoWrapFlags kNoWrapFlags =
    LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw;
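
// A ranked memref lowers to an LLVM struct descriptor. As a rough sketch
// (the index width depends on the target configuration), memref<?x?xf32>
// becomes:
//   !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// i.e. allocated pointer, aligned pointer, offset, sizes, and strides. Many
// of the patterns below read and write these fields through MemRefDescriptor.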

namespace {

static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
          ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(b, module);

  return LLVM::lookupOrCreateFreeFn(b, module);
}

static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                     Operation *module, Type indexType) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAllocFn(b, module, indexType);

  return LLVM::lookupOrCreateMallocFn(b, module, indexType);
}

static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(OpBuilder &b, const LLVMTypeConverter *typeConverter,
                  Operation *module, Type indexType) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericAlignedAllocFn(b, module, indexType);

  return LLVM::lookupOrCreateAlignedAllocFn(b, module, indexType);
}

/// Computes the aligned value for 'input' as follows:
///   bumped = input + alignment - 1
///   aligned = bumped - bumped % alignment
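/// For example, input = 10 with alignment = 8 gives bumped = 17 and
/// aligned = 16.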
static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                           Value input, Value alignment) {
  Value one = rewriter.create<LLVM::ConstantOp>(loc, alignment.getType(),
                                                rewriter.getIndexAttr(1));
  Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
  Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
  Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
  return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
}

/// Computes the byte size for the MemRef element type.
static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
                                        MemRefType memRefType, Operation *op,
                                        const DataLayout *defaultLayout) {
  const DataLayout *layout = defaultLayout;
  if (const DataLayoutAnalysis *analysis =
          typeConverter->getDataLayoutAnalysis()) {
    layout = &analysis->getAbove(op);
  }
  Type elementType = memRefType.getElementType();
  if (auto memRefElementType = dyn_cast<MemRefType>(elementType))
    return typeConverter->getMemRefDescriptorSize(memRefElementType, *layout);
  if (auto memRefElementType = dyn_cast<UnrankedMemRefType>(elementType))
    return typeConverter->getUnrankedMemRefDescriptorSize(memRefElementType,
                                                          *layout);
  return layout->getTypeSize(elementType);
}

static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                 Location loc, Value allocatedPtr,
                                 MemRefType memRefType, Type elementPtrType,
                                 const LLVMTypeConverter &typeConverter) {
  auto allocatedPtrTy = cast<LLVM::LLVMPointerType>(allocatedPtr.getType());
  FailureOr<unsigned> maybeMemrefAddrSpace =
      typeConverter.getMemRefAddressSpace(memRefType);
  assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
  unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
  if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
    allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
        loc,
        LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
        allocatedPtr);
  return allocatedPtr;
}

struct AllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    Value alignment = getAlignment(rewriter, loc, op);
    if (alignment) {
      // Adjust the allocation size to consider alignment.
      sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment);
    }

    // Allocate the underlying buffer.
    Type elementPtrType = this->getElementPtrType(memRefType);
    assert(elementPtrType && "could not compute element ptr type");
    auto results =
        rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);

    Value allocatedPtr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());
    Value alignedPtr = allocatedPtr;
    if (alignment) {
      // Compute the aligned pointer.
      Value allocatedInt =
          rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr);
      Value alignmentInt =
          createAligned(rewriter, loc, allocatedInt, alignment);
      alignedPtr =
          rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, alignmentInt);
    }

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// Computes the alignment for the given memory allocation op.
  template <typename OpType>
  Value getAlignment(ConversionPatternRewriter &rewriter, Location loc,
                     OpType op) const {
    MemRefType memRefType = op.getType();
    Value alignment;
    if (auto alignmentAttr = op.getAlignment()) {
      Type indexType = getIndexType();
      alignment =
          createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
    } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
      // In the case where no alignment is specified, we may want to override
      // `malloc`'s behavior. `malloc` typically aligns at the size of the
      // biggest scalar on a target HW. For non-scalars, use the natural
      // alignment of the LLVM type given by the LLVM DataLayout.
      alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter);
    }
    return alignment;
  }
};

struct AlignedAllocOpLowering : public ConvertOpToLLVMPattern<memref::AllocOp> {
  using ConvertOpToLLVMPattern<memref::AllocOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get or insert the aligned_alloc function into the module.
    FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
        rewriter, getTypeConverter(),
        op->getParentWithTrait<OpTrait::SymbolTable>(), getIndexType());
    if (failed(allocFuncOp))
      return failure();

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    int64_t alignment = alignedAllocationGetAlignment(op, &defaultLayout);

    Value allocAlignment =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);

    // Function aligned_alloc requires the size to be a multiple of the
    // alignment; we pad the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment, op, &defaultLayout))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    auto results = rewriter.create<LLVM::CallOp>(
        loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));

    Value ptr =
        castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
                            elementPtrType, *getTypeConverter());

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, ptr, ptr, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;

  /// Computes the alignment for aligned_alloc used to allocate the buffer for
  /// the memory allocation op.
  ///
  /// aligned_alloc requires the alignment to be a power of two, and the
  /// allocation size to be a multiple of the alignment.
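  /// For example, an f32 element (4 bytes) with no explicit alignment
  /// attribute yields max(kMinAlignedAllocAlignment, 4) = 16.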
  int64_t alignedAllocationGetAlignment(memref::AllocOp op,
                                        const DataLayout *defaultLayout) const {
    if (std::optional<uint64_t> alignment = op.getAlignment())
      return *alignment;

    // Whenever we don't have alignment set, we will use an alignment
    // consistent with the element type; since the alignment has to be a
    // power of two, we will bump to the next power of two if it isn't.
    unsigned eltSizeBytes = getMemRefEltSizeInBytes(
        getTypeConverter(), op.getType(), op, defaultLayout);
    return std::max(kMinAlignedAllocAlignment,
                    llvm::PowerOf2Ceil(eltSizeBytes));
  }

  /// Returns true if the memref size in bytes is known to be a multiple of
  /// factor.
  bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor, Operation *op,
                              const DataLayout *defaultLayout) const {
    uint64_t sizeDivisor =
        getMemRefEltSizeInBytes(getTypeConverter(), type, op, defaultLayout);
    for (unsigned i = 0, e = type.getRank(); i < e; i++) {
      if (type.isDynamicDim(i))
        continue;
      sizeDivisor = sizeDivisor * type.getDimSize(i);
    }
    return sizeDivisor % factor == 0;
  }

private:
  /// Default layout to use in absence of the corresponding analysis.
  DataLayout defaultLayout;
};

struct AllocaOpLowering : public ConvertOpToLLVMPattern<memref::AllocaOp> {
  using ConvertOpToLLVMPattern<memref::AllocaOp>::ConvertOpToLLVMPattern;

  /// Lowers `memref.alloca` to an `llvm.alloca` of the element type, which
  /// allocates the buffer directly on the stack; the allocated and aligned
  /// pointers of the resulting descriptor are the same.
  LogicalResult
  matchAndRewrite(memref::AllocaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value size;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, size,
                                   /*sizeInBytes=*/false);

    // With alloca, one gets a pointer to the element type right away, so
    // there is no need for a separate buffer allocation.
    auto elementType =
        typeConverter->convertType(op.getType().getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(op.getType());
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned addrSpace = *maybeAddressSpace;
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, elementType, size, op.getAlignment().value_or(0));

    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, allocatedElementPtr, allocatedElementPtr, sizes,
        strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

struct AllocaScopeOpLowering
    : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
  using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    OpBuilder::InsertionGuard guard(rewriter);
    Location loc = allocaScopeOp.getLoc();

    // Split the current block before the AllocaScopeOp to create the inlining
    // point.
    auto *currentBlock = rewriter.getInsertionBlock();
    auto *remainingOpsBlock =
        rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
    Block *continueBlock;
    if (allocaScopeOp.getNumResults() == 0) {
      continueBlock = remainingOpsBlock;
    } else {
      continueBlock = rewriter.createBlock(
          remainingOpsBlock, allocaScopeOp.getResultTypes(),
          SmallVector<Location>(allocaScopeOp->getNumResults(),
                                allocaScopeOp.getLoc()));
      rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
    }

    // Inline the body region.
    Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
    Block *afterBody = &allocaScopeOp.getBodyRegion().back();
    rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);

    // Save the stack and then branch into the body of the region.
    rewriter.setInsertionPointToEnd(currentBlock);
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
    rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);

    // Replace the alloca_scope return with a branch that jumps out of the
    // body, restoring the stack before leaving the body region.
    rewriter.setInsertionPointToEnd(afterBody);
    auto returnOp =
        cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
    auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
        returnOp, returnOp.getResults(), continueBlock);

    // Insert the stack restore before jumping out of the body of the region.
    rewriter.setInsertionPoint(branchOp);
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    // Replace the op with the values returned from the body region.
    rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());

    return success();
  }
};

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(rewriter, loc, srcMemRefType, memref,
                                     /*indices=*/{});

    // Emit llvm.assume(true) ["align"(memref, alignment)].
    // This is more direct than ptrtoint-based checks, is explicitly supported,
    // and works with non-integral address spaces.
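    // The emitted sequence looks roughly like this (the exact assembly
    // syntax of the operand bundle may differ):
    //   %true = llvm.mlir.constant(true) : i1
    //   llvm.intr.assume %true ["align"(%ptr, %alignment : !llvm.ptr, i64)]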
    Value trueCond =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getBoolAttr(true));
    Value alignmentConst =
        createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
    rewriter.create<LLVM::AssumeOp>(loc, trueCond, LLVM::AssumeAlignTag(), ptr,
                                    alignmentConst);
    rewriter.replaceOp(op, memref);
    return success();
  }
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    FailureOr<LLVM::LLVMFuncOp> freeFunc = getFreeFn(
        rewriter, getTypeConverter(), op->getParentOfType<ModuleOp>());
    if (failed(freeFunc))
      return failure();
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
                                              allocatedPtr);
    return success();
  }
};

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an "
                        "integer address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract the pointer to the underlying ranked descriptor and bitcast it
    // to a memref<element_type> descriptor pointer to minimize the number of
    // GEP operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get the pointer to the offset field of the memref<element_type>
    // descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using GEPOp with
    // `dimOp.index() + 1` as the index argument.
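    // This works because, within the underlying descriptor, the offset field
    // is immediately followed by the sizes array: treating the memory from
    // the offset field onwards as an array of index-typed values, entry 0 is
    // the offset and entry `i + 1` is size `i`.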
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage of the case where the index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | loop(%loaded):                 |
///  |   |   <body contents>              |
///  |   |   %pair = cmpxchg              |
///  |   |   %ok = %pair[0]               |
///  |   |   %new = %pair[1]              |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +--------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
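/// For reference, the source op looks roughly like:
///
///   %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
///   ^bb0(%current_value : f32):
///     %cst = arith.constant 1.0 : f32
///     %new = arith.addf %current_value, %cst : f32
///     memref.atomic_yield %new : f32
///   }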
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType =
        typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(
        rewriter, loc, memRefType, adaptor.getMemref(), adaptor.getIndices());
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially
  // flatten this to a 1D array. However, for memref.global's with an initial
  // value, we do not intend to flatten the ElementsAttribute when going from
  // the memref dialect to the LLVM dialect, so the LLVM type needs to be a
  // multi-dimensional array.
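  // For example, memref<4x8xf32> converts to !llvm.array<4 x array<8 x f32>>.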
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dim at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}

/// GlobalMemrefOp is lowered to an LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a MemRef descriptor with the pointer to
/// the first element stashed into the descriptor, reusing the MemRef
/// descriptor construction from `AllocLikeOpLowering`.
struct GetGlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GetGlobalOp> {
  using ConvertOpToLLVMPattern<memref::GetGlobalOp>::ConvertOpToLLVMPattern;

  /// Buffer "allocation" for the memref.get_global op is getting the address
  /// of the referenced global variable.
  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    MemRefType memRefType = op.getType();
    if (!isConvertibleAndHasIdentityMaps(memRefType))
      return rewriter.notifyMatchFailure(op, "incompatible memref type");

    // Get actual sizes of the memref as values: static sizes are constant
    // values and dynamic sizes are passed to 'alloc' as operands. In case of
    // zero-dimensional memref, assume a scalar (size 1).
    SmallVector<Value, 4> sizes;
    SmallVector<Value, 4> strides;
    Value sizeBytes;

    this->getMemRefDescriptorSizes(loc, memRefType, adaptor.getOperands(),
                                   rewriter, sizes, strides, sizeBytes,
                                   /*sizeInBytes=*/true);

    MemRefType type = cast<MemRefType>(op.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    assert(succeeded(maybeAddressSpace) && "unsupported address space");
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, op.getName());

    // Get the address of the first element in the array by creating a GEP
    // with the address of the GV as the base, and (rank + 1) number of 0
    // indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType,
                                0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both the allocated and aligned pointers are the same. We could
    // potentially stash a nullptr for the allocated pointer since we do not
    // expect any dealloc.
    // Create the MemRef descriptor.
    auto memRefDescriptor = this->createMemRefDescriptor(
        loc, memRefType, deadBeefPtr, gep, sizes, strides, rewriter);

    // Return the final value of the descriptor.
    rewriter.replaceOp(op, {memRefDescriptor});
    return success();
  }
};

// The load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    // Per the memref.load spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
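    // For example, %v = memref.load %m[%i] : memref<?xf32> lowers roughly to
    // (modulo the exact descriptor accesses and index types):
    //   %p = llvm.getelementptr inbounds|nuw %aligned[%i] : ... -> !llvm.ptr
    //   %v = llvm.load %p : !llvm.ptr -> f32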
    Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
                                         adaptor.getMemref(),
                                         adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};

// The store operation is lowered to obtaining a pointer to the indexed
// element, and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    // Per the memref.store spec, the indices must be in-bounds:
    // 0 <= idx < dim_size, and additionally all offsets are non-negative,
    // hence inbounds and nuw are used when lowering to llvm.getelementptr.
    Value dataPtr =
        getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), kNoWrapFlags);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(
        rewriter, loc, type, adaptor.getMemref(), adaptor.getIndices());

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can
    // be used for type erasure. For now it must preserve the underlying
    // element type and requires the source and result types to have the same
    // rank. Therefore, perform a sanity check that the underlying structs are
    // the same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      if (typeConverter->convertType(srcType) !=
          typeConverter->convertType(dstType))
        return failure();

    // Unranked to unranked casts are disallowed.
    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
      return failure();

    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
      return success();
    }

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // - Set the rank in the destination from the memref type.
      // - Allocate space on the stack and copy the src memref descriptor.
      // - Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // poison = PoisonOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::poison(rewriter, loc, targetStructType);
      // d1 = InsertValueOp poison, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be doing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }

    return success();
  }
};

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a
/// call to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute the number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get the element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute the total size in bytes.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr,
        targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save the stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive the size from llvm.getelementptr, which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        rewriter, op->getParentOfType<ModuleOp>(), getIndexType(),
        sourcePtr.getType());
    if (failed(copyFn))
      return failure();
    rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore the stack used for the descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };
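
    // For example, copying between two memref<4x8xf32> values takes the
    // memcpy path, while a strided source such as memref<4xf32, strided<[2]>>
    // falls back to the memrefCopy runtime call.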
    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::poison(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(
          rewriter, loc, *getTypeConverter(), result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy the pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType,
                           alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

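      // The underlying descriptor begins with the allocated and aligned
      // pointers, followed by the offset, sizes, and strides. Skip the two
      // pointers and memcpy only the index-typed tail of the descriptor.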
1221 int64_t bytesToSkip =
1222 2 * llvm::divideCeil(
1223 getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
1224 Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
1225 loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
1226 Value copySize = rewriter.create<LLVM::SubOp>(
1227 loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
1228 rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
1229 copySize, /*isVolatile=*/false);
1230
1231 rewriter.replaceOp(op, ValueRange{result});
1232 return success();
1233 }
1234 return rewriter.notifyMatchFailure(arg&: loc, msg: "unexpected memref type");
1235 }
1236};
1237
1238/// Extracts allocated, aligned pointers and offset from a ranked or unranked
1239/// memref type. In unranked case, the fields are extracted from the underlying
1240/// ranked descriptor.
1241static void extractPointersAndOffset(Location loc,
1242 ConversionPatternRewriter &rewriter,
1243 const LLVMTypeConverter &typeConverter,
1244 Value originalOperand,
1245 Value convertedOperand,
1246 Value *allocatedPtr, Value *alignedPtr,
1247 Value *offset = nullptr) {
1248 Type operandType = originalOperand.getType();
1249 if (isa<MemRefType>(Val: operandType)) {
1250 MemRefDescriptor desc(convertedOperand);
1251 *allocatedPtr = desc.allocatedPtr(builder&: rewriter, loc);
1252 *alignedPtr = desc.alignedPtr(builder&: rewriter, loc);
1253 if (offset != nullptr)
1254 *offset = desc.offset(builder&: rewriter, loc);
1255 return;
1256 }
1257
1258 // These will all cause assert()s on unconvertible types.
1259 unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
1260 type: cast<UnrankedMemRefType>(operandType));
1261 auto elementPtrType =
1262 LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);
1263
1264 // Extract pointer to the underlying ranked memref descriptor and cast it to
1265 // ElemType**.
1266 UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
1267 Value underlyingDescPtr = unrankedDesc.memRefDescPtr(builder&: rewriter, loc);
1268
1269 *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
1270 builder&: rewriter, loc, memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1271 *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
1272 builder&: rewriter, loc, typeConverter, memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1273 if (offset != nullptr) {
1274 *offset = UnrankedMemRefDescriptor::offset(
1275 builder&: rewriter, loc, typeConverter, memRefDescPtr: underlyingDescPtr, elemPtrType: elementPtrType);
1276 }
1277}
1278
struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};

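/// Lowers `memref.reshape`. When the shape operand is statically shaped, the
/// result is ranked and its descriptor is populated field by field. Otherwise
/// the result is unranked: the sizes and strides are written through a
/// stack-allocated descriptor by an explicit control-flow loop over the
/// runtime rank.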
struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::poison(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // The stride for this dimension is statically known; materialize it
          // as a constant.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. Since
          // the target memref has an identity layout, we can safely set the
          // innermost stride to 1; outer strides are the running product of
          // the inner sizes, computed at the bottom of the loop.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // Use the static size if it is known; otherwise load the size at
        // runtime from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 memref with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::poison(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the
    // new shape and compute strides. For this, we create a loop from rank-1
    // to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

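    // The loop is materialized below as explicit blocks and branches (there
    // is no structured control flow at this level). Roughly (sketch):
    //
    //   init:
    //     llvm.br ^cond(rank - 1, 1)            // (dim index, stride)
    //   ^cond(%i, %stride):
    //     %p = llvm.icmp "sge" %i, 0
    //     llvm.cond_br %p, ^body, ^remainder
    //   ^body:
    //     sizes[%i]   = load shape[%i]
    //     strides[%i] = %stride
    //     llvm.br ^cond(%i - 1, %stride * sizes[%i])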
    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOps (expand_shape / collapse_shape) must be expanded
/// before we reach this stage; report a match failure otherwise.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage; report a match
/// failure otherwise.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};

/// Conversion pattern that transforms a transpose op into:
/// 1. A `llvm.mlir.poison` operation to create a new memref descriptor.
/// 2. Updates to the descriptor that copy the data pointers and offset and
/// permute the sizes and strides according to the permutation map.
/// The transpose op is replaced by the new descriptor.
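/// For illustration (a hand-written sketch, not verified compiler output): a
/// 2-D transpose such as
///   %t = memref.transpose %m (d0, d1) -> (d1, d0)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
/// yields a descriptor whose size/stride for result dimension i are read from
/// source dimension perm(i), i.e. the size and stride pairs are swapped.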
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::poison(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // when enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
/// 1. A `llvm.mlir.poison` operation to create a memref descriptor.
/// 2. Updates to the descriptor to introduce the data ptr, offset, size
/// and stride.
/// The view op is replaced by the descriptor.
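/// For illustration (a hand-written sketch): a typical use reinterprets a
/// byte buffer, advancing the aligned pointer by a byte shift:
///   %v = memref.view %buf[%byte_shift][%size]
///       : memref<2048xi8> to memref<?x16xf32>
/// The sizes and strides of the result are rebuilt from the static shape and
/// the dynamic size operands via the helpers below.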
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper dynamic
  // size operand.
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }
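
  // Worked example (illustrative): for shape [2, ?, 4, ?] with dynamic size
  // operands (%a, %b), idx = 3 sees one dynamic dim in [0, 3) and returns %b,
  // while idx = 2 returns the constant 4.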

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }
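
  // Worked example (illustrative): walking from the innermost dimension
  // outwards for a contiguous shape [A, B, C], the running stride evolves as
  // 1 -> C -> B*C, giving strides [B*C, C, 1]: each stride is the product of
  // the sizes of the dimensions inner to it.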

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::poison(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is
    // because of the type change: an offset on srcType* may not be
    // expressible as an offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or to fall back to llvm.cmpxchg.
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}
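
// For illustration: `memref.atomic_rmw addf %v, %m[%i]` maps to
// `llvm.atomicrmw fadd` in the pattern below. Kinds with no direct LLVM
// binop equivalent (e.g. mulf) return std::nullopt and are rejected by that
// pattern; such updates are expected to be expressed as
// memref.generic_atomic_rmw, which lowers to an llvm.cmpxchg loop.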

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(memRefType.getStridesAndOffset(strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(rewriter, atomicOp.getLoc(), memRefType,
                             adaptor.getMemref(), adaptor.getIndices());
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

/// Lowers memref.extract_aligned_pointer_as_index: the aligned pointer is
/// read from the (ranked or unranked) descriptor and converted to an index
/// via llvm.ptrtoint.
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    BaseMemRefType sourceTy = extractOp.getSource().getType();

    Value alignedPtr;
    if (sourceTy.hasRank()) {
      MemRefDescriptor desc(adaptor.getSource());
      alignedPtr = desc.alignedPtr(rewriter, extractOp->getLoc());
    } else {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), sourceTy.getMemorySpaceAsInt());

      UnrankedMemRefDescriptor desc(adaptor.getSource());
      Value descPtr = desc.memRefDescPtr(rewriter, extractOp->getLoc());

      alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
          rewriter, extractOp->getLoc(), *getTypeConverter(), descPtr,
          elementPtrTy);
    }

    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(), alignedPtr);
    return success();
  }
};

/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
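/// For illustration (a hand-written sketch): for a 2-D source,
///   %base, %offset, %sizes:2, %strides:2 =
///       memref.extract_strided_metadata %m
///       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
/// each result is read straight out of the source descriptor; the base buffer
/// is repackaged as a 0-D memref descriptor.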
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

} // namespace

void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
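
// Note: these patterns are normally exercised through the pass defined below,
// e.g. `mlir-opt --finalize-memref-to-llvm` (flag name assumed from the
// generated pass definition).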

namespace {
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for the derived dialect interface to provide conversion patterns
  /// and mark the dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};

} // namespace

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}