//===- MemRefToLLVM.cpp - MemRef to LLVM dialect conversion ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
  return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
                           ModuleOp module) {
  bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

  if (useGenericFn)
    return LLVM::lookupOrCreateGenericFreeFn(module);

  return LLVM::lookupOrCreateFreeFn(module);
}
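
// Illustrative note: with `useGenericFunctions` set, the lookup above resolves
// to the MLIR runtime wrapper (e.g. `_mlir_memref_to_llvm_free`) instead of
// libc `free`, so deallocation also lowers on targets without a libc. The
// exact symbol name is determined by FunctionCallUtils.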

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
  AllocOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                converter) {}
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    return allocateBufferManuallyAlign(
        rewriter, loc, sizeBytes, op,
        getAlignment(rewriter, loc, cast<memref::AllocOp>(op)));
  }
};
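
// Schematic example of the malloc-based path (illustrative; the exact IR
// depends on the type converter and lowering options):
//
//   %m = memref.alloc(%n) : memref<?xf32>
//
// becomes roughly
//
//   %bytes = ...                           // %n * sizeof(f32)
//   %ptr = llvm.call @malloc(%bytes) : (i64) -> !llvm.ptr
//   ...                                    // descriptor filled with %ptr,
//                                          // offset 0, size %n, stride 1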
67
68struct AlignedAllocOpLowering : public AllocLikeOpLLVMLowering {
69 AlignedAllocOpLowering(const LLVMTypeConverter &converter)
70 : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
71 converter) {}
72 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
73 Location loc, Value sizeBytes,
74 Operation *op) const override {
75 Value ptr = allocateBufferAutoAlign(
76 rewriter, loc, sizeBytes, op, &defaultLayout,
77 alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
78 &defaultLayout));
79 if (!ptr)
80 return std::make_tuple(Value(), Value());
81 return std::make_tuple(ptr, ptr);
82 }
83
84private:
85 /// Default layout to use in absence of the corresponding analysis.
86 DataLayout defaultLayout;
87};
88
89struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
90 AllocaOpLowering(const LLVMTypeConverter &converter)
91 : AllocLikeOpLLVMLowering(memref::AllocaOp::getOperationName(),
92 converter) {
93 setRequiresNumElements();
94 }
95
96 /// Allocates the underlying buffer using the right call. `allocatedBytePtr`
97 /// is set to null for stack allocations. `accessAlignment` is set if
98 /// alignment is needed post allocation (for eg. in conjunction with malloc).
99 std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
100 Location loc, Value size,
101 Operation *op) const override {
102
103 // With alloca, one gets a pointer to the element type right away.
104 // For stack allocations.
105 auto allocaOp = cast<memref::AllocaOp>(op);
106 auto elementType =
107 typeConverter->convertType(allocaOp.getType().getElementType());
108 unsigned addrSpace =
109 *getTypeConverter()->getMemRefAddressSpace(type: allocaOp.getType());
110 auto elementPtrType =
111 LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);
112
113 auto allocatedElementPtr =
114 rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
115 allocaOp.getAlignment().value_or(0));
116
117 return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
118 }
119};
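
// Schematic example (illustrative): the allocation part of
//
//   %m = memref.alloca() : memref<4xf32>
//
// becomes
//
//   %c4 = llvm.mlir.constant(4 : index) : i64
//   %ptr = llvm.alloca %c4 x f32 : (i64) -> !llvm.ptr
//
// and %ptr is used as both the allocated and the aligned pointer.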
120
121struct AllocaScopeOpLowering
122 : public ConvertOpToLLVMPattern<memref::AllocaScopeOp> {
123 using ConvertOpToLLVMPattern<memref::AllocaScopeOp>::ConvertOpToLLVMPattern;
124
125 LogicalResult
126 matchAndRewrite(memref::AllocaScopeOp allocaScopeOp, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter) const override {
128 OpBuilder::InsertionGuard guard(rewriter);
129 Location loc = allocaScopeOp.getLoc();
130
131 // Split the current block before the AllocaScopeOp to create the inlining
132 // point.
133 auto *currentBlock = rewriter.getInsertionBlock();
134 auto *remainingOpsBlock =
135 rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
136 Block *continueBlock;
137 if (allocaScopeOp.getNumResults() == 0) {
138 continueBlock = remainingOpsBlock;
139 } else {
140 continueBlock = rewriter.createBlock(
141 remainingOpsBlock, allocaScopeOp.getResultTypes(),
142 SmallVector<Location>(allocaScopeOp->getNumResults(),
143 allocaScopeOp.getLoc()));
144 rewriter.create<LLVM::BrOp>(loc, ValueRange(), remainingOpsBlock);
145 }
146
147 // Inline body region.
148 Block *beforeBody = &allocaScopeOp.getBodyRegion().front();
149 Block *afterBody = &allocaScopeOp.getBodyRegion().back();
150 rewriter.inlineRegionBefore(allocaScopeOp.getBodyRegion(), continueBlock);
151
152 // Save stack and then branch into the body of the region.
153 rewriter.setInsertionPointToEnd(currentBlock);
154 auto stackSaveOp =
155 rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());
156 rewriter.create<LLVM::BrOp>(loc, ValueRange(), beforeBody);
157
158 // Replace the alloca_scope return with a branch that jumps out of the body.
159 // Stack restore before leaving the body region.
160 rewriter.setInsertionPointToEnd(afterBody);
161 auto returnOp =
162 cast<memref::AllocaScopeReturnOp>(afterBody->getTerminator());
163 auto branchOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
164 returnOp, returnOp.getResults(), continueBlock);
165
166 // Insert stack restore before jumping out the body of the region.
167 rewriter.setInsertionPoint(branchOp);
168 rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);
169
170 // Replace the op with values return from the body region.
171 rewriter.replaceOp(allocaScopeOp, continueBlock->getArguments());
172
173 return success();
174 }
175};
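
// Schematic control flow produced for `memref.alloca_scope` (illustrative):
//
//   ^current:
//     %sp = llvm.intr.stacksave : !llvm.ptr
//     llvm.br ^body
//   ^body:                      // inlined alloca_scope region
//     ...
//     llvm.intr.stackrestore %sp
//     llvm.br ^continue(...)
//   ^continue(...):             // receives the alloca_scope results, if any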

struct AssumeAlignmentOpLowering
    : public ConvertOpToLLVMPattern<memref::AssumeAlignmentOp> {
  using ConvertOpToLLVMPattern<
      memref::AssumeAlignmentOp>::ConvertOpToLLVMPattern;
  explicit AssumeAlignmentOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::AssumeAlignmentOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    unsigned alignment = op.getAlignment();
    auto loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getMemref().getType());
    Value ptr = getStridedElementPtr(loc, srcMemRefType, memref, /*indices=*/{},
                                     rewriter);

    // Emit llvm.assume(memref & (alignment - 1) == 0).
    //
    // This relies on LLVM's CSE optimization (potentially after SROA), since
    // after CSE all memref instances should get de-duplicated into the same
    // pointer SSA value.
    MemRefDescriptor memRefDescriptor(memref);
    auto intPtrType =
        getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace());
    Value zero = createIndexAttrConstant(rewriter, loc, intPtrType, 0);
    Value mask =
        createIndexAttrConstant(rewriter, loc, intPtrType, alignment - 1);
    Value ptrValue = rewriter.create<LLVM::PtrToIntOp>(loc, intPtrType, ptr);
    rewriter.create<LLVM::AssumeOp>(
        loc, rewriter.create<LLVM::ICmpOp>(
                 loc, LLVM::ICmpPredicate::eq,
                 rewriter.create<LLVM::AndOp>(loc, ptrValue, mask), zero));

    rewriter.eraseOp(op);
    return success();
  }
};
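
// Schematic example (illustrative): `memref.assume_alignment %m, 16` on an
// element pointer %p lowers to
//
//   %int  = llvm.ptrtoint %p : !llvm.ptr to i64
//   %and  = llvm.and %int, %mask        // %mask = 16 - 1 = 15
//   %cond = llvm.icmp "eq" %and, %zero
//   llvm.intr.assume %cond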

// A `dealloc` is converted into a call to `free` on the underlying data buffer.
// The memref descriptor being an SSA value, there is no need to clean it up
// in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
  using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(const LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Insert the `free` declaration if it is not already present.
    LLVM::LLVMFuncOp freeFunc =
        getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
    Value allocatedPtr;
    if (auto unrankedTy =
            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
      auto elementPtrTy = LLVM::LLVMPointerType::get(
          rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
          rewriter, op.getLoc(),
          UnrankedMemRefDescriptor(adaptor.getMemref())
              .memRefDescPtr(rewriter, op.getLoc()),
          elementPtrTy);
    } else {
      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
                         .allocatedPtr(rewriter, op.getLoc());
    }
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
    return success();
  }
};
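
// Schematic example (illustrative): for a ranked descriptor %desc,
//
//   memref.dealloc %m : memref<?xf32>
//
// becomes
//
//   %ptr = llvm.extractvalue %desc[0]
//       : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
//   llvm.call @free(%ptr) : (!llvm.ptr) -> ()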

// A `dim` is converted to a constant for static sizes and to an access to the
// size stored in the memref descriptor for dynamic sizes.
struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
  using ConvertOpToLLVMPattern<memref::DimOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::DimOp dimOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type operandType = dimOp.getSource().getType();
    if (isa<UnrankedMemRefType>(operandType)) {
      FailureOr<Value> extractedSize = extractSizeOfUnrankedMemRef(
          operandType, dimOp, adaptor.getOperands(), rewriter);
      if (failed(extractedSize))
        return failure();
      rewriter.replaceOp(dimOp, {*extractedSize});
      return success();
    }
    if (isa<MemRefType>(operandType)) {
      rewriter.replaceOp(
          dimOp, {extractSizeOfRankedMemRef(operandType, dimOp,
                                            adaptor.getOperands(), rewriter)});
      return success();
    }
    llvm_unreachable("expected MemRefType or UnrankedMemRefType");
  }

private:
  FailureOr<Value>
  extractSizeOfUnrankedMemRef(Type operandType, memref::DimOp dimOp,
                              OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    auto unrankedMemRefType = cast<UnrankedMemRefType>(operandType);
    auto scalarMemRefType =
        MemRefType::get({}, unrankedMemRefType.getElementType());
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(unrankedMemRefType);
    if (failed(maybeAddressSpace)) {
      dimOp.emitOpError("memref memory space must be convertible to an integer "
                        "address space");
      return failure();
    }
    unsigned addressSpace = *maybeAddressSpace;

    // Extract pointer to the underlying ranked descriptor and bitcast it to a
    // memref<element_type> descriptor pointer to minimize the number of GEP
    // operations.
    UnrankedMemRefDescriptor unrankedDesc(adaptor.getSource());
    Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

    Type elementType = typeConverter->convertType(scalarMemRefType);

    // Get pointer to offset field of memref<element_type> descriptor.
    auto indexPtrTy =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, elementType, underlyingRankedDesc,
        ArrayRef<LLVM::GEPArg>{0, 2});

    // The size value that we have to extract can be obtained using a GEPOp
    // with `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1),
        adaptor.getIndex());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr,
        idxPlusOne);
    return rewriter
        .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr)
        .getResult();
  }

  std::optional<int64_t> getConstantDimIndex(memref::DimOp dimOp) const {
    if (auto idx = dimOp.getConstantIndex())
      return idx;

    if (auto constantOp = dimOp.getIndex().getDefiningOp<LLVM::ConstantOp>())
      return cast<IntegerAttr>(constantOp.getValue()).getValue().getSExtValue();

    return std::nullopt;
  }

  Value extractSizeOfRankedMemRef(Type operandType, memref::DimOp dimOp,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();

    // Take advantage if the index is constant.
    MemRefType memRefType = cast<MemRefType>(operandType);
    Type indexType = getIndexType();
    if (std::optional<int64_t> index = getConstantDimIndex(dimOp)) {
      int64_t i = *index;
      if (i >= 0 && i < memRefType.getRank()) {
        if (memRefType.isDynamicDim(i)) {
          // Extract the dynamic size from the memref descriptor.
          MemRefDescriptor descriptor(adaptor.getSource());
          return descriptor.size(rewriter, loc, i);
        }
        // Use a constant for the static size.
        int64_t dimSize = memRefType.getDimSize(i);
        return createIndexAttrConstant(rewriter, loc, indexType, dimSize);
      }
    }
    Value index = adaptor.getIndex();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(adaptor.getSource());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};
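
// Schematic example (illustrative): `%d = memref.dim %m, %c1` on a
// memref<?x?xf32> with descriptor %desc reads the dynamic size directly:
//
//   %d = llvm.extractvalue %desc[3, 1]
//       : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>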

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    MemRefType type = op.getMemRefType();
    return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
///      +---------------------------------+
///      |   <code before the AtomicRMWOp> |
///      |   <compute initial %loaded>     |
///      |   cf.br loop(%loaded)           |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +-----------------------------------+
///  |   | loop(%loaded):                    |
///  |   |   <body contents>                 |
///  |   |   %pair = cmpxchg                 |
///  |   |   %ok = %pair[0]                  |
///  |   |   %new = %pair[1]                 |
///  |   |   cf.cond_br %ok, end, loop(%new) |
///  |   +-----------------------------------+
///  |          |        |
///  |-----------        |
///                      v
///      +--------------------------------+
///      | end:                           |
///      |   <code after the AtomicRMWOp> |
///      +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
    loopBlock->addArgument(valueType, loc);

    auto *endBlock =
        rewriter.splitBlock(loopBlock, Block::iterator(atomicOp)++);

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = cast<MemRefType>(atomicOp.getMemref().getType());
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
                                        adaptor.getIndices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(
        loc, typeConverter->convertType(memRefType.getElementType()), dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result.
    auto loopArgument = loopBlock->getArgument(0);
    IRMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, dataPtr, loopArgument, result, successOrdering, failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 0);
    Value ok = rewriter.create<LLVM::ExtractValueOp>(loc, cmpxchg, 1);

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              const LLVMTypeConverter &typeConverter) {
  // The LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we could potentially flatten
  // this to a 1D array. However, for memref.global's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from the
  // memref dialect to the LLVM dialect, so the LLVM type needs to be a
  // multi-dimensional array.
  Type elementType = typeConverter.convertType(type.getElementType());
  Type arrayTy = elementType;
  // The shape has the outermost dimension at index 0, so walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
  return arrayTy;
}
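
// For example, memref<4x8xf32> converts to !llvm.array<4 x array<8 x f32>>,
// and a rank-0 memref<f32> converts to plain f32.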

/// GlobalMemrefOp is lowered to an LLVM Global Variable.
struct GlobalMemrefOpLowering
    : public ConvertOpToLLVMPattern<memref::GlobalOp> {
  using ConvertOpToLLVMPattern<memref::GlobalOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp global, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = global.getType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = llvm::cast<ElementsAttr>(*global.getInitialValue());
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getSplatValue<Attribute>();
    }

    uint64_t alignment = global.getAlignment().value_or(0);
    FailureOr<unsigned> addressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(addressSpace))
      return global.emitOpError(
          "memory space cannot be converted to an integer address space");
    auto newGlobal = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        global, arrayTy, global.getConstant(), linkage, global.getSymName(),
        initialValue, alignment, *addressSpace);
    if (!global.isExternal() && global.isUninitialized()) {
      rewriter.createBlock(&newGlobal.getInitializerRegion());
      Value undef[] = {
          rewriter.create<LLVM::UndefOp>(global.getLoc(), arrayTy)};
      rewriter.create<LLVM::ReturnOp>(global.getLoc(), undef);
    }
    return success();
  }
};
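
// Schematic example (illustrative):
//
//   memref.global "private" constant @c : memref<2xf32> = dense<[1.0, 2.0]>
//
// becomes roughly
//
//   llvm.mlir.global private constant @c(dense<[1.0, 2.0]> : tensor<2xf32>)
//       : !llvm.array<2 x f32>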

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This leverages
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
  GetGlobalMemrefOpLowering(const LLVMTypeConverter &converter)
      : AllocLikeOpLLVMLowering(memref::GetGlobalOp::getOperationName(),
                                converter) {}

  /// Buffer "allocation" for memref.get_global op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                          Location loc, Value sizeBytes,
                                          Operation *op) const override {
    auto getGlobalOp = cast<memref::GetGlobalOp>(op);
    MemRefType type = cast<MemRefType>(getGlobalOp.getResult().getType());

    // This is called after a type conversion, which would have failed if this
    // call fails.
    FailureOr<unsigned> maybeAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(type);
    if (failed(maybeAddressSpace))
      return std::make_tuple(Value(), Value());
    unsigned memSpace = *maybeAddressSpace;

    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
    auto addressOf =
        rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    auto gep = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, arrayTy, addressOf,
        SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

    // We do not expect the memref obtained using `memref.get_global` to ever
    // be deallocated. Set the allocated pointer to a known bad value to help
    // debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst =
        createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};
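
// Schematic example (illustrative):
//
//   %0 = memref.get_global @c : memref<2xf32>
//
// becomes roughly
//
//   %addr = llvm.mlir.addressof @c : !llvm.ptr
//   %gep  = llvm.getelementptr %addr[0, 0]
//       : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x f32>
//
// with the 0xdeadbeef sentinel stored as the descriptor's allocated pointer
// and %gep as its aligned pointer.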

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
        loadOp, typeConverter->convertType(type.getElementType()), dataPtr, 0,
        false, loadOp.getNontemporal());
    return success();
  }
};
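
// Schematic example (illustrative): `%v = memref.load %m[%i] : memref<?xf32>`
// with descriptor %desc becomes roughly
//
//   %base = llvm.extractvalue %desc[1]                  // aligned pointer
//   %addr = llvm.getelementptr %base[%i] : (!llvm.ptr, i64) -> !llvm.ptr, f32
//   %v    = llvm.load %addr : !llvm.ptr -> f32
//
// For higher ranks, getStridedElementPtr folds the offset and strides into
// the index computation.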

// Store operation is lowered to obtaining a pointer to the indexed element,
// and storing the given value to it.
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = op.getMemRefType();

    Value dataPtr = getStridedElementPtr(op.getLoc(), type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
                                               0, false, op.getNontemporal());
    return success();
  }
};

// The prefetch operation is lowered in a way similar to the load operation
// except that the llvm.prefetch operation is used for replacement.
struct PrefetchOpLowering : public LoadStoreOpLowering<memref::PrefetchOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::PrefetchOp prefetchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = prefetchOp.getMemRefType();
    auto loc = prefetchOp.getLoc();

    Value dataPtr = getStridedElementPtr(loc, type, adaptor.getMemref(),
                                         adaptor.getIndices(), rewriter);

    // Replace with llvm.prefetch.
    IntegerAttr isWrite = rewriter.getI32IntegerAttr(prefetchOp.getIsWrite());
    IntegerAttr localityHint = prefetchOp.getLocalityHintAttr();
    IntegerAttr isData =
        rewriter.getI32IntegerAttr(prefetchOp.getIsDataCache());
    rewriter.replaceOpWithNewOp<LLVM::Prefetch>(prefetchOp, dataPtr, isWrite,
                                                localityHint, isData);
    return success();
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
  using ConvertOpToLLVMPattern<memref::RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type operandType = op.getMemref().getType();
    if (dyn_cast<UnrankedMemRefType>(operandType)) {
      UnrankedMemRefDescriptor desc(adaptor.getMemref());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = dyn_cast<MemRefType>(operandType)) {
      Type indexType = getIndexType();
      rewriter.replaceOp(op,
                         {createIndexAttrConstant(rewriter, loc, indexType,
                                                  rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
  using ConvertOpToLLVMPattern<memref::CastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(memref::CastOp memRefCastOp) const override {
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // memref::CastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return success(typeConverter->convertType(srcType) ==
                     typeConverter->convertType(dstType));

    // At least one of the operands is an unranked memref type.
    assert(isa<UnrankedMemRefType>(srcType) ||
           isa<UnrankedMemRefType>(dstType));

    // Unranked-to-unranked casts are disallowed.
    return !(isa<UnrankedMemRefType>(srcType) &&
             isa<UnrankedMemRefType>(dstType))
               ? success()
               : failure();
  }

  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
               ConversionPatternRewriter &rewriter) const override {
    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
    auto loc = memRefCastOp.getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});

    if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = cast<MemRefType>(srcType);
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
          loc, adaptor.getSource(), rewriter);

      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, ptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
      rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

    } else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
      // Casting from unranked type to ranked.
      // The operation is assumed to be performing a correct cast. If the
      // destination type mismatches the unranked type, it is undefined
      // behavior.
      UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);

      // struct = LoadOp ptr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
      rewriter.replaceOp(memRefCastOp, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};
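
// Schematic example of the ranked-to-unranked case (illustrative):
//
//   %u = memref.cast %m : memref<4xf32> to memref<*xf32>
//
// becomes roughly
//
//   %p = llvm.alloca ...                 // stack copy of %m's descriptor
//   %r = llvm.mlir.constant(1 : index) : i64
//   %u = ...                             // !llvm.struct<(i64, ptr)> built
//                                        // from (%r, %p)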

/// Pattern to lower a `memref.copy` to llvm.
///
/// For memrefs with identity layouts, the copy is lowered to the llvm
/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
  using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = dyn_cast<MemRefType>(op.getSource().getType());

    MemRefDescriptor srcDesc(adaptor.getSource());

    // Compute number of elements.
    Value numElements = rewriter.create<LLVM::ConstantOp>(
        loc, getIndexType(), rewriter.getIndexAttr(1));
    for (int pos = 0; pos < srcType.getRank(); ++pos) {
      auto size = srcDesc.size(rewriter, loc, pos);
      numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
    }

    // Get element size.
    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    // Compute total.
    Value totalSize =
        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);

    Type elementType = typeConverter->convertType(srcType.getElementType());

    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
    Value srcOffset = srcDesc.offset(rewriter, loc);
    Value srcPtr = rewriter.create<LLVM::GEPOp>(
        loc, srcBasePtr.getType(), elementType, srcBasePtr, srcOffset);
    MemRefDescriptor targetDesc(adaptor.getTarget());
    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
    Value targetOffset = targetDesc.offset(rewriter, loc);
    Value targetPtr = rewriter.create<LLVM::GEPOp>(
        loc, targetBasePtr.getType(), elementType, targetBasePtr, targetOffset);
    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                    /*isVolatile=*/false);
    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
                             ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    // First make sure we have an unranked memref descriptor representation.
    auto makeUnranked = [&, this](Value ranked, MemRefType type) {
      auto rank = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    type.getRank());
      auto *typeConverter = getTypeConverter();
      auto ptr =
          typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

      auto unrankedType =
          UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
      return UnrankedMemRefDescriptor::pack(
          rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
    };

    // Save stack position before promoting descriptors.
    auto stackSaveOp =
        rewriter.create<LLVM::StackSaveOp>(loc, getVoidPtrType());

    auto srcMemRefType = dyn_cast<MemRefType>(srcType);
    Value unrankedSource =
        srcMemRefType ? makeUnranked(adaptor.getSource(), srcMemRefType)
                      : adaptor.getSource();
    auto targetMemRefType = dyn_cast<MemRefType>(targetType);
    Value unrankedTarget =
        targetMemRefType ? makeUnranked(adaptor.getTarget(), targetMemRefType)
                         : adaptor.getTarget();

    // Now promote the unranked descriptors to the stack.
    auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                 rewriter.getIndexAttr(1));
    auto promote = [&](Value desc) {
      auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
      auto allocated =
          rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
      rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
      return allocated;
    };

    auto sourcePtr = promote(unrankedSource);
    auto targetPtr = promote(unrankedTarget);

    // Derive size from llvm.getelementptr which will account for any
    // potential alignment.
    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
    auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
        op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
    rewriter.create<LLVM::CallOp>(loc, copyFn,
                                  ValueRange{elemSize, sourcePtr, targetPtr});

    // Restore stack used for descriptors.
    rewriter.create<LLVM::StackRestoreOp>(loc, stackSaveOp);

    rewriter.eraseOp(op);

    return success();
  }

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcType = cast<BaseMemRefType>(op.getSource().getType());
    auto targetType = cast<BaseMemRefType>(op.getTarget().getType());

    auto isContiguousMemrefType = [&](BaseMemRefType type) {
      auto memrefType = dyn_cast<mlir::MemRefType>(type);
      // We can use memcpy for memrefs if they have an identity layout or are
      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
      // special case handled by memrefCopy.
      return memrefType &&
             (memrefType.getLayout().isIdentity() ||
              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
               memref::isStaticShapeAndContiguousRowMajor(memrefType)));
    };

    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);

    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
  }
};
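
// Illustrative dispatch: a copy between two contiguous memref<?xf32> values
// lowers to
//
//   llvm.intr.memcpy %dst, %src, %totalBytes
//
// with %totalBytes = numElements * sizeof(f32); all other layouts fall back
// to the `memrefCopy` runtime call with stack-promoted unranked descriptors.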

struct MemorySpaceCastOpLowering
    : public ConvertOpToLLVMPattern<memref::MemorySpaceCastOp> {
  using ConvertOpToLLVMPattern<
      memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    Type resultType = op.getDest().getType();
    if (auto resultTypeR = dyn_cast<MemRefType>(resultType)) {
      auto resultDescType =
          cast<LLVM::LLVMStructType>(typeConverter->convertType(resultTypeR));
      Type newPtrType = resultDescType.getBody()[0];

      SmallVector<Value> descVals;
      MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                               descVals);
      descVals[0] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[0]);
      descVals[1] =
          rewriter.create<LLVM::AddrSpaceCastOp>(loc, newPtrType, descVals[1]);
      Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                            resultTypeR, descVals);
      rewriter.replaceOp(op, result);
      return success();
    }
    if (auto resultTypeU = dyn_cast<UnrankedMemRefType>(resultType)) {
      // Since the type converter won't be doing this for us, get the address
      // space.
      auto sourceType = cast<UnrankedMemRefType>(op.getSource().getType());
      FailureOr<unsigned> maybeSourceAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(sourceType);
      if (failed(maybeSourceAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer source address space");
      unsigned sourceAddrSpace = *maybeSourceAddrSpace;
      FailureOr<unsigned> maybeResultAddrSpace =
          getTypeConverter()->getMemRefAddressSpace(resultTypeU);
      if (failed(maybeResultAddrSpace))
        return rewriter.notifyMatchFailure(loc,
                                           "non-integer result address space");
      unsigned resultAddrSpace = *maybeResultAddrSpace;

      UnrankedMemRefDescriptor sourceDesc(adaptor.getSource());
      Value rank = sourceDesc.rank(rewriter, loc);
      Value sourceUnderlyingDesc = sourceDesc.memRefDescPtr(rewriter, loc);

      // Create and allocate storage for the new memref descriptor.
      auto result = UnrankedMemRefDescriptor::undef(
          rewriter, loc, typeConverter->convertType(resultTypeU));
      result.setRank(rewriter, loc, rank);
      SmallVector<Value, 1> sizes;
      UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                             result, resultAddrSpace, sizes);
      Value resultUnderlyingSize = sizes.front();
      Value resultUnderlyingDesc = rewriter.create<LLVM::AllocaOp>(
          loc, getVoidPtrType(), rewriter.getI8Type(), resultUnderlyingSize);
      result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

      // Copy pointers, performing address space casts.
      auto sourceElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
      auto resultElemPtrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

      Value allocatedPtr = sourceDesc.allocatedPtr(
          rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
      Value alignedPtr =
          sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                sourceUnderlyingDesc, sourceElemPtrType);
      allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, allocatedPtr);
      alignedPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
          loc, resultElemPtrType, alignedPtr);

      result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                             resultElemPtrType, allocatedPtr);
      result.setAlignedPtr(rewriter, loc, *getTypeConverter(),
                           resultUnderlyingDesc, resultElemPtrType, alignedPtr);

      // Copy all the index-valued operands.
      Value sourceIndexVals =
          sourceDesc.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                                   sourceUnderlyingDesc, sourceElemPtrType);
      Value resultIndexVals =
          result.offsetBasePtr(rewriter, loc, *getTypeConverter(),
                               resultUnderlyingDesc, resultElemPtrType);

      int64_t bytesToSkip =
          2 *
          ceilDiv(getTypeConverter()->getPointerBitwidth(resultAddrSpace), 8);
      Value bytesToSkipConst = rewriter.create<LLVM::ConstantOp>(
          loc, getIndexType(), rewriter.getIndexAttr(bytesToSkip));
      Value copySize = rewriter.create<LLVM::SubOp>(
          loc, getIndexType(), resultUnderlyingSize, bytesToSkipConst);
      rewriter.create<LLVM::MemcpyOp>(loc, resultIndexVals, sourceIndexVals,
                                      copySize, /*isVolatile=*/false);

      rewriter.replaceOp(op, ValueRange{result});
      return success();
    }
    return rewriter.notifyMatchFailure(loc, "unexpected memref type");
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (isa<MemRefType>(operandType)) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  // These will all cause assert()s on unconvertible types.
  unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
      cast<UnrankedMemRefType>(operandType));
  auto elementPtrType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

  // Extract pointer to the underlying ranked memref descriptor and cast it to
  // ElemType**.
  UnrankedMemRefDescriptor unrankedDesc(convertedOperand);
  Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc);

  *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
      rewriter, loc, underlyingDescPtr, elementPtrType);
  *alignedPtr = UnrankedMemRefDescriptor::alignedPtr(
      rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  if (offset != nullptr) {
    *offset = UnrankedMemRefDescriptor::offset(
        rewriter, loc, typeConverter, underlyingDescPtr, elementPtrType);
  }
}

struct MemRefReinterpretCastOpLowering
    : public ConvertOpToLLVMPattern<memref::ReinterpretCastOp> {
  using ConvertOpToLLVMPattern<
      memref::ReinterpretCastOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp castOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = castOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(castOp, {descriptor});
    return success();
  }

private:
  LogicalResult convertSourceMemRefToDescriptor(
      ConversionPatternRewriter &rewriter, Type srcType,
      memref::ReinterpretCastOp castOp,
      memref::ReinterpretCastOp::Adaptor adaptor, Value *descriptor) const {
    MemRefType targetMemRefType =
        cast<MemRefType>(castOp.getResult().getType());
    auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
        typeConverter->convertType(targetMemRefType));
    if (!llvmTargetDescriptorTy)
      return failure();

    // Create descriptor.
    Location loc = castOp.getLoc();
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

    // Set allocated and aligned pointers.
    Value allocatedPtr, alignedPtr;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             castOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr);
    desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
    desc.setAlignedPtr(rewriter, loc, alignedPtr);

    // Set offset.
    if (castOp.isDynamicOffset(0))
      desc.setOffset(rewriter, loc, adaptor.getOffsets()[0]);
    else
      desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0));

    // Set sizes and strides.
    unsigned dynSizeId = 0;
    unsigned dynStrideId = 0;
    for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      if (castOp.isDynamicSize(i))
        desc.setSize(rewriter, loc, i, adaptor.getSizes()[dynSizeId++]);
      else
        desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i));

      if (castOp.isDynamicStride(i))
        desc.setStride(rewriter, loc, i, adaptor.getStrides()[dynStrideId++]);
      else
        desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i));
    }
    *descriptor = desc;
    return success();
  }
};
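
// Schematic example (illustrative):
//
//   %r = memref.reinterpret_cast %m to offset: [0], sizes: [%s], strides: [1]
//       : memref<?xf32> to memref<?xf32, strided<[1]>>
//
// copies %m's allocated/aligned pointers into a fresh descriptor and then
// writes offset 0, size %s, and stride 1 via llvm.insertvalue.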

struct MemRefReshapeOpLowering
    : public ConvertOpToLLVMPattern<memref::ReshapeOp> {
  using ConvertOpToLLVMPattern<memref::ReshapeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ReshapeOp reshapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type srcType = reshapeOp.getSource().getType();

    Value descriptor;
    if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp,
                                               adaptor, &descriptor)))
      return failure();
    rewriter.replaceOp(reshapeOp, {descriptor});
    return success();
  }

private:
  LogicalResult
  convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter,
                                  Type srcType, memref::ReshapeOp reshapeOp,
                                  memref::ReshapeOp::Adaptor adaptor,
                                  Value *descriptor) const {
    auto shapeMemRefType = cast<MemRefType>(reshapeOp.getShape().getType());
    if (shapeMemRefType.hasStaticShape()) {
      MemRefType targetMemRefType =
          cast<MemRefType>(reshapeOp.getResult().getType());
      auto llvmTargetDescriptorTy = dyn_cast_or_null<LLVM::LLVMStructType>(
          typeConverter->convertType(targetMemRefType));
      if (!llvmTargetDescriptorTy)
        return failure();

      // Create descriptor.
      Location loc = reshapeOp.getLoc();
      auto desc =
          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);

      // Set allocated and aligned pointers.
      Value allocatedPtr, alignedPtr;
      extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                               reshapeOp.getSource(), adaptor.getSource(),
                               &allocatedPtr, &alignedPtr);
      desc.setAllocatedPtr(rewriter, loc, allocatedPtr);
      desc.setAlignedPtr(rewriter, loc, alignedPtr);

      // Extract the offset and strides from the type.
      int64_t offset;
      SmallVector<int64_t> strides;
      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
        return rewriter.notifyMatchFailure(
            reshapeOp, "failed to get stride and offset exprs");

      if (!isStaticStrideOrOffset(offset))
        return rewriter.notifyMatchFailure(reshapeOp,
                                           "dynamic offset is unsupported");

      desc.setConstantOffset(rewriter, loc, offset);

      assert(targetMemRefType.getLayout().isIdentity() &&
             "Identity layout map is a precondition of a valid reshape op");

      Type indexType = getIndexType();
      Value stride = nullptr;
      int64_t targetRank = targetMemRefType.getRank();
      for (auto i : llvm::reverse(llvm::seq<int64_t>(0, targetRank))) {
        if (!ShapedType::isDynamic(strides[i])) {
          // If the stride for this dimension is static, use the constant
          // value directly.
          stride =
              createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
        } else if (!stride) {
          // `stride` is null only in the first iteration of the loop. However,
          // since the target memref has an identity layout, we can safely set
          // the innermost stride to 1.
          stride = createIndexAttrConstant(rewriter, loc, indexType, 1);
        }

        Value dimSize;
        // If the size of this dimension is dynamic, then load it at runtime
        // from the shape operand.
        if (!targetMemRefType.isDynamicDim(i)) {
          dimSize = createIndexAttrConstant(rewriter, loc, indexType,
                                            targetMemRefType.getDimSize(i));
        } else {
          Value shapeOp = reshapeOp.getShape();
          Value index = createIndexAttrConstant(rewriter, loc, indexType, i);
          dimSize = rewriter.create<memref::LoadOp>(loc, shapeOp, index);
          if (dimSize.getType() != indexType)
            dimSize = typeConverter->materializeTargetConversion(
                rewriter, loc, indexType, dimSize);
          assert(dimSize && "Invalid memref element type");
        }

        desc.setSize(rewriter, loc, i, dimSize);
        desc.setStride(rewriter, loc, i, stride);

        // Prepare the stride value for the next dimension.
        stride = rewriter.create<LLVM::MulOp>(loc, stride, dimSize);
      }

      *descriptor = desc;
      return success();
    }

    // The shape is a rank-1 tensor with unknown length.
    Location loc = reshapeOp.getLoc();
    MemRefDescriptor shapeDesc(adaptor.getShape());
    Value resultRank = shapeDesc.size(rewriter, loc, 0);

    // Extract address space and element type.
    auto targetType =
        cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
    unsigned addressSpace =
        *getTypeConverter()->getMemRefAddressSpace(targetType);

    // Create the unranked memref descriptor that holds the ranked one. The
    // inner descriptor is allocated on stack.
    auto targetDesc = UnrankedMemRefDescriptor::undef(
        rewriter, loc, typeConverter->convertType(targetType));
    targetDesc.setRank(rewriter, loc, resultRank);
    SmallVector<Value, 4> sizes;
    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
                                           targetDesc, addressSpace, sizes);
    Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>(
        loc, getVoidPtrType(), IntegerType::get(getContext(), 8),
        sizes.front());
    targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);

    // Extract pointers and offset from the source memref.
    Value allocatedPtr, alignedPtr, offset;
    extractPointersAndOffset(loc, rewriter, *getTypeConverter(),
                             reshapeOp.getSource(), adaptor.getSource(),
                             &allocatedPtr, &alignedPtr, &offset);

    // Set pointers and offset.
    auto elementPtrType =
        LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

    UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
                                              elementPtrType, allocatedPtr);
    UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
                                            underlyingDescPtr, elementPtrType,
                                            alignedPtr);
    UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
                                        underlyingDescPtr, elementPtrType,
                                        offset);

    // Use the offset pointer as base for further addressing. Copy over the new
    // shape and compute strides. For this, we create a loop from rank-1 to 0.
    Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr(
        rewriter, loc, *getTypeConverter(), underlyingDescPtr, elementPtrType);
    Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr(
        rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank);
    Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc);
    Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1);
    Value resultRankMinusOne =
        rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);

    Block *initBlock = rewriter.getInsertionBlock();
    Type indexType = getTypeConverter()->getIndexType();
    Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());

    Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
                                            {indexType, indexType}, {loc, loc});

    // Move the remaining initBlock ops to condBlock.
    Block *remainingBlock = rewriter.splitBlock(initBlock, remainingOpsIt);
    rewriter.mergeBlocks(remainingBlock, condBlock, ValueRange());

    rewriter.setInsertionPointToEnd(initBlock);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}),
                                condBlock);
    rewriter.setInsertionPointToStart(condBlock);
    Value indexArg = condBlock->getArgument(0);
    Value strideArg = condBlock->getArgument(1);

    Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0);
    Value pred = rewriter.create<LLVM::ICmpOp>(
        loc, IntegerType::get(rewriter.getContext(), 1),
        LLVM::ICmpPredicate::sge, indexArg, zeroIndex);

    Block *bodyBlock =
        rewriter.splitBlock(condBlock, rewriter.getInsertionPoint());
    rewriter.setInsertionPointToStart(bodyBlock);

    // Copy size from shape to descriptor.
    auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
    Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
        loc, llvmIndexPtrType,
        typeConverter->convertType(shapeMemRefType.getElementType()),
        shapeOperandPtr, indexArg);
    Value size = rewriter.create<LLVM::LoadOp>(loc, indexType, sizeLoadGep);
    UnrankedMemRefDescriptor::setSize(rewriter, loc, *getTypeConverter(),
                                      targetSizesBase, indexArg, size);

    // Write stride value and compute next one.
    UnrankedMemRefDescriptor::setStride(rewriter, loc, *getTypeConverter(),
                                        targetStridesBase, indexArg, strideArg);
    Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size);

    // Decrement loop counter and branch back.
    Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex);
    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}),
                                condBlock);

    Block *remainder =
        rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint());

    // Hook up the cond exit to the remainder.
    rewriter.setInsertionPointToEnd(condBlock);
    rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, std::nullopt,
                                    remainder, std::nullopt);

    // Reset position to beginning of new remainder block.
    rewriter.setInsertionPointToStart(remainder);

    *descriptor = targetDesc;
    return success();
  }
};

/// ReassociatingReshapeOp must be expanded before we reach this stage.
/// Report that information.
template <typename ReshapeOp>
class ReassociatingReshapeOpConversion
    : public ConvertOpToLLVMPattern<ReshapeOp> {
public:
  using ConvertOpToLLVMPattern<ReshapeOp>::ConvertOpToLLVMPattern;
  using ReshapeOpAdaptor = typename ReshapeOp::Adaptor;

  LogicalResult
  matchAndRewrite(ReshapeOp reshapeOp, typename ReshapeOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        reshapeOp,
        "reassociation operations should have been expanded beforehand");
  }
};

/// Subviews must be expanded before we reach this stage.
/// Report that information.
struct SubViewOpLowering : public ConvertOpToLLVMPattern<memref::SubViewOp> {
  using ConvertOpToLLVMPattern<memref::SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return rewriter.notifyMatchFailure(
        subViewOp, "subview operations should have been expanded beforehand");
  }
};
1349
/// Conversion pattern that transforms a transpose op into:
///   1. An `llvm.mlir.undef` operation to create a new memref descriptor.
///   2. Copies of the allocated and aligned pointers and of the offset from
///      the source descriptor to the new one.
///   3. Updates to the new descriptor's sizes and strides, permuted according
///      to the transpose's permutation map.
/// The transpose op is replaced by the new descriptor.
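///
/// For example (a sketch; layout attributes abbreviated):
///
///   %t = memref.transpose %m (i, j) -> (j, i)
///       : memref<?x?xf32> to memref<?x?xf32, strided<[1, ?]>>
///
/// produces a descriptor with the same pointers and offset whose size[0] and
/// stride[0] are the source's size[1] and stride[1], and vice versa.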
class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
public:
  using ConvertOpToLLVMPattern<memref::TransposeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::TransposeOp transposeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = transposeOp.getLoc();
    MemRefDescriptor viewMemRef(adaptor.getIn());

    // No permutation, early exit.
    if (transposeOp.getPermutation().isIdentity())
      return rewriter.replaceOp(transposeOp, {viewMemRef}), success();

    auto targetMemRef = MemRefDescriptor::undef(
        rewriter, loc,
        typeConverter->convertType(transposeOp.getIn().getType()));

    // Copy the base and aligned pointers from the old descriptor to the new
    // one.
    targetMemRef.setAllocatedPtr(rewriter, loc,
                                 viewMemRef.allocatedPtr(rewriter, loc));
    targetMemRef.setAlignedPtr(rewriter, loc,
                               viewMemRef.alignedPtr(rewriter, loc));

    // Copy the offset from the old descriptor to the new one.
    targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

    // Iterate over the dimensions and apply size/stride permutation:
    // When enumerating the results of the permutation map, the enumeration
    // index is the index into the target dimensions and the DimExpr points to
    // the dimension of the source memref.
    for (const auto &en :
         llvm::enumerate(transposeOp.getPermutation().getResults())) {
      int targetPos = en.index();
      int sourcePos = cast<AffineDimExpr>(en.value()).getPosition();
      targetMemRef.setSize(rewriter, loc, targetPos,
                           viewMemRef.size(rewriter, loc, sourcePos));
      targetMemRef.setStride(rewriter, loc, targetPos,
                             viewMemRef.stride(rewriter, loc, sourcePos));
    }

    rewriter.replaceOp(transposeOp, {targetMemRef});
    return success();
  }
};

/// Conversion pattern that transforms a view op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
///   2. Updates to the descriptor to introduce the data pointer, offset,
///      sizes, and strides.
/// The view op is replaced by the descriptor.
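///
/// For example, a sketch of the effect (types abbreviated):
///
///   %v = memref.view %src[%byte_shift][%size]
///       : memref<2048xi8> to memref<?xf32>
///
/// yields a descriptor whose aligned pointer is the source's aligned pointer
/// advanced by %byte_shift bytes, with offset 0 and the requested sizes and
/// strides.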
struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
  using ConvertOpToLLVMPattern<memref::ViewOp>::ConvertOpToLLVMPattern;

  // Build and return the value for the idx^th shape dimension, either by
  // returning the constant shape dimension or by selecting the proper operand
  // among the dynamic sizes.
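  // For example, with shape [2, ?, 4, ?] and idx == 3, the result is
  // dynamicSizes[1], since exactly one dynamic dimension precedes index 3.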
  Value getSize(ConversionPatternRewriter &rewriter, Location loc,
                ArrayRef<int64_t> shape, ValueRange dynamicSizes, unsigned idx,
                Type indexType) const {
    assert(idx < shape.size());
    if (!ShapedType::isDynamic(shape[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]);
    // Count the number of dynamic dims in range [0, idx).
    unsigned nDynamic =
        llvm::count_if(shape.take_front(idx), ShapedType::isDynamic);
    return dynamicSizes[nDynamic];
  }

  // Build and return the idx^th stride, either by returning the constant
  // stride or by computing the dynamic stride from the current
  // `runningStride` and `nextSize`. The caller should keep a running stride
  // and update it with the result returned by this function.
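  // For example, for a rank-3 result with static innermost stride 1, walking
  // idx = 2, 1, 0 yields strides 1, size[2], and size[2] * size[1], i.e. a
  // contiguous row-major layout.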
  Value getStride(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<int64_t> strides, Value nextSize,
                  Value runningStride, unsigned idx, Type indexType) const {
    assert(idx < strides.size());
    if (!ShapedType::isDynamic(strides[idx]))
      return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]);
    if (nextSize)
      return runningStride
                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
                 : nextSize;
    assert(!runningStride);
    return createIndexAttrConstant(rewriter, loc, indexType, 1);
  }

  LogicalResult
  matchAndRewrite(memref::ViewOp viewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = viewOp.getLoc();

    auto viewMemRefType = viewOp.getType();
    auto targetElementTy =
        typeConverter->convertType(viewMemRefType.getElementType());
    auto targetDescTy = typeConverter->convertType(viewMemRefType);
    if (!targetDescTy || !targetElementTy ||
        !LLVM::isCompatibleType(targetElementTy) ||
        !LLVM::isCompatibleType(targetDescTy))
      return viewOp.emitWarning("target descriptor type not converted to LLVM"),
             failure();

    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
    if (failed(successStrides))
      return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
    assert(offset == 0 && "expected offset to be 0");

    // Target memref must be contiguous in memory (innermost stride is 1), or
    // empty (special case when at least one of the memref dimensions is 0).
    if (!strides.empty() && (strides.back() != 1 && strides.back() != 0))
      return viewOp.emitWarning("cannot cast to non-contiguous shape"),
             failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Field 1: Copy the allocated pointer, used for malloc/free.
    Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
    auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
    targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

    // Field 2: Copy the actual aligned pointer to payload.
    Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
    alignedPtr = rewriter.create<LLVM::GEPOp>(
        loc, alignedPtr.getType(),
        typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
        adaptor.getByteShift());

    targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

    Type indexType = getIndexType();
    // Field 3: The offset in the resulting type must be 0. This is because of
    // the type change: an offset on srcType* may not be expressible as an
    // offset on dstType*.
    targetMemRef.setOffset(
        rewriter, loc,
        createIndexAttrConstant(rewriter, loc, indexType, offset));

    // Early exit for 0-D corner case.
    if (viewMemRefType.getRank() == 0)
      return rewriter.replaceOp(viewOp, {targetMemRef}), success();

    // Fields 4 and 5: Update sizes and strides.
    Value stride = nullptr, nextSize = nullptr;
    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
      // Update size.
      Value size = getSize(rewriter, loc, viewMemRefType.getShape(),
                           adaptor.getSizes(), i, indexType);
      targetMemRef.setSize(rewriter, loc, i, size);
      // Update stride.
      stride =
          getStride(rewriter, loc, strides, nextSize, stride, i, indexType);
      targetMemRef.setStride(rewriter, loc, i, stride);
      nextSize = size;
    }

    rewriter.replaceOp(viewOp, {targetMemRef});
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AtomicRMWOpLowering
//===----------------------------------------------------------------------===//

/// Try to match the kind of a memref.atomic_rmw to determine whether to use a
/// lowering to llvm.atomicrmw or fall back to llvm.cmpxchg.
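///
/// For example, `memref.atomic_rmw addf` maps directly to `llvm.atomicrmw`
/// with the `fadd` bin-op. Kinds for which this returns `std::nullopt` have no
/// direct `llvm.atomicrmw` counterpart and are expected to be expressed as a
/// cmpxchg loop instead (see GenericAtomicRMWOpLowering).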
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
  switch (atomicOp.getKind()) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
  llvm_unreachable("Invalid AtomicRMWKind");
}

struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeKind = matchSimpleAtomicOp(atomicOp);
    if (!maybeKind)
      return failure();
    auto memRefType = atomicOp.getMemRefType();
    SmallVector<int64_t> strides;
    int64_t offset;
    if (failed(getStridesAndOffset(memRefType, strides, offset)))
      return failure();
    auto dataPtr =
        getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
                             adaptor.getIndices(), rewriter);
    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
        atomicOp, *maybeKind, dataPtr, adaptor.getValue(),
        LLVM::AtomicOrdering::acq_rel);
    return success();
  }
};

/// Lower memref.extract_aligned_pointer_as_index by converting the source
/// descriptor's aligned pointer to an index.
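///
/// A sketch of the rewrite:
///
///   %i = memref.extract_aligned_pointer_as_index %m : memref<?xf32> -> index
///
/// becomes an `llvm.ptrtoint` of the aligned pointer field of %m's descriptor.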
class ConvertExtractAlignedPointerAsIndex
    : public ConvertOpToLLVMPattern<memref::ExtractAlignedPointerAsIndexOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractAlignedPointerAsIndexOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefDescriptor desc(adaptor.getSource());
    rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(
        extractOp, getTypeConverter()->getIndexType(),
        desc.alignedPtr(rewriter, extractOp->getLoc()));
    return success();
  }
};

/// Materialize the MemRef descriptor represented by the results of
/// ExtractStridedMetadataOp.
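///
/// For example (result types abbreviated):
///
///   %base, %offset, %sizes:2, %strides:2 =
///       memref.extract_strided_metadata %m : memref<?x?xf32>
///
/// is replaced so that %base becomes a rank-0 descriptor wrapping the source
/// pointers, while the offset, sizes, and strides are read directly from the
/// source descriptor's fields.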
class ExtractStridedMetadataOpLowering
    : public ConvertOpToLLVMPattern<memref::ExtractStridedMetadataOp> {
public:
  using ConvertOpToLLVMPattern<
      memref::ExtractStridedMetadataOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!LLVM::isCompatibleType(adaptor.getOperands().front().getType()))
      return failure();

    // Create the descriptor.
    MemRefDescriptor sourceMemRef(adaptor.getSource());
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();

    auto sourceMemRefType = cast<MemRefType>(source.getType());
    int64_t rank = sourceMemRefType.getRank();
    SmallVector<Value> results;
    results.reserve(2 + rank * 2);

    // Base buffer.
    Value baseBuffer = sourceMemRef.allocatedPtr(rewriter, loc);
    Value alignedBuffer = sourceMemRef.alignedPtr(rewriter, loc);
    MemRefDescriptor dstMemRef = MemRefDescriptor::fromStaticShape(
        rewriter, loc, *getTypeConverter(),
        cast<MemRefType>(extractStridedMetadataOp.getBaseBuffer().getType()),
        baseBuffer, alignedBuffer);
    results.push_back((Value)dstMemRef);

    // Offset.
    results.push_back(sourceMemRef.offset(rewriter, loc));

    // Sizes.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.size(rewriter, loc, i));
    // Strides.
    for (unsigned i = 0; i < rank; ++i)
      results.push_back(sourceMemRef.stride(rewriter, loc, i));

    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

} // namespace

void mlir::populateFinalizeMemRefToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<
      AllocaOpLowering,
      AllocaScopeOpLowering,
      AtomicRMWOpLowering,
      AssumeAlignmentOpLowering,
      ConvertExtractAlignedPointerAsIndex,
      DimOpLowering,
      ExtractStridedMetadataOpLowering,
      GenericAtomicRMWOpLowering,
      GlobalMemrefOpLowering,
      GetGlobalMemrefOpLowering,
      LoadOpLowering,
      MemRefCastOpLowering,
      MemRefCopyOpLowering,
      MemorySpaceCastOpLowering,
      MemRefReinterpretCastOpLowering,
      MemRefReshapeOpLowering,
      PrefetchOpLowering,
      RankOpLowering,
      ReassociatingReshapeOpConversion<memref::ExpandShapeOp>,
      ReassociatingReshapeOpConversion<memref::CollapseShapeOp>,
      StoreOpLowering,
      SubViewOpLowering,
      TransposeOpLowering,
      ViewOpLowering>(converter);
  // clang-format on
  auto allocLowering = converter.getOptions().allocLowering;
  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
}
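
// A minimal sketch of driving these patterns outside the pass below, assuming
// the caller already owns a context `ctx` and a root operation `op` (the names
// are illustrative, not part of this file):
//
//   LowerToLLVMOptions options(ctx);
//   LLVMTypeConverter typeConverter(ctx, options);
//   RewritePatternSet patterns(ctx);
//   populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
//   LLVMConversionTarget target(*ctx);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     /* handle conversion failure */;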

namespace {
struct FinalizeMemRefToLLVMConversionPass
    : public impl::FinalizeMemRefToLLVMConversionPassBase<
          FinalizeMemRefToLLVMConversionPass> {
  using FinalizeMemRefToLLVMConversionPassBase::
      FinalizeMemRefToLLVMConversionPassBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
    LowerToLLVMOptions options(&getContext(),
                               dataLayoutAnalysis.getAtOrAbove(op));
    options.allocLowering =
        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                         : LowerToLLVMOptions::AllocLowering::Malloc);

    options.useGenericFunctions = useGenericFunctions;

    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter typeConverter(&getContext(), options,
                                    &dataLayoutAnalysis);
    RewritePatternSet patterns(&getContext());
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalOp<func::FuncOp>();
    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      signalPassFailure();
  }
};

/// Implement the interface to convert MemRef to LLVM.
struct MemRefToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
  void loadDependentDialects(MLIRContext *context) const final {
    context->loadDialect<LLVM::LLVMDialect>();
  }

  /// Hook for the derived dialect interface to provide conversion patterns
  /// and mark the dialect legal for the conversion target.
  void populateConvertToLLVMConversionPatterns(
      ConversionTarget &target, LLVMTypeConverter &typeConverter,
      RewritePatternSet &patterns) const final {
    populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  }
};

} // namespace

void mlir::registerConvertMemRefToLLVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
    dialect->addInterfaces<MemRefToLLVMDialectInterface>();
  });
}
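
// A short usage sketch for the registration above (hypothetical tool setup;
// not part of this file):
//
//   DialectRegistry registry;
//   registry.insert<memref::MemRefDialect>();
//   mlir::registerConvertMemRefToLLVMInterface(registry);
//   MLIRContext context(registry);
//   // Contexts built from `registry` can now convert memref ops to LLVM via
//   // the generic convert-to-llvm infrastructure.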