//===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MemorySlot-related interfaces for LLVM dialect
// operations.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Interfaces/DataLayoutInterfaces.h"
19#include "mlir/Interfaces/MemorySlotInterfaces.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23#define DEBUG_TYPE "sroa"
24
25using namespace mlir;
26
//===----------------------------------------------------------------------===//
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//
30
/// Returns the memory slot produced by this alloca, provided it is promotable.
/// Only entry-block allocas are reported; an alloca elsewhere may execute
/// multiple times or conditionally, which promotion cannot model.
llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
  if (!getOperation()->getBlock()->isEntryBlock())
    return {};

  return {MemorySlot{getResult(), getElemType()}};
}
37
/// Materializes the value a promoted slot holds before any store: memory that
/// was never written is undefined, hence an `llvm.mlir.undef`.
Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
                                      RewriterBase &rewriter) {
  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
}
42
43void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
44 BlockArgument argument,
45 RewriterBase &rewriter) {
46 for (Operation *user : getOperation()->getUsers())
47 if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
48 rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
49 declareOp.getVarInfo(),
50 declareOp.getLocationExpr());
51}
52
/// Final cleanup after promotion: drops the placeholder default value if it
/// ended up unused, then erases the alloca itself.
void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                             Value defaultValue,
                                             RewriterBase &rewriter) {
  if (defaultValue && defaultValue.use_empty())
    rewriter.eraseOp(defaultValue.getDefiningOp());
  rewriter.eraseOp(*this);
}
60
/// Reports this alloca as destructurable when it allocates exactly one element
/// of an aggregate type whose subelement layout is known.
SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
  // Multi-element allocas cannot be split into per-field allocas.
  if (!mlir::matchPattern(getArraySize(), m_One()))
    return {};

  auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
  if (!destructurable)
    return {};

  // The type must expose a map from index attribute to subelement type.
  std::optional<DenseMap<Attribute, Type>> destructuredType =
      destructurable.getSubelementIndexMap();
  if (!destructuredType)
    return {};

  return {DestructurableMemorySlot{{getResult(), getElemType()},
                                   *destructuredType}};
}
77
78DenseMap<Attribute, MemorySlot>
79LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
80 const SmallPtrSetImpl<Attribute> &usedIndices,
81 RewriterBase &rewriter) {
82 assert(slot.ptr == getResult());
83 rewriter.setInsertionPointAfter(*this);
84
85 auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
86 DenseMap<Attribute, MemorySlot> slotMap;
87 for (Attribute index : usedIndices) {
88 Type elemType = destructurableType.getTypeAtIndex(index);
89 assert(elemType && "used index must exist");
90 auto subAlloca = rewriter.create<LLVM::AllocaOp>(
91 getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
92 getArraySize());
93 slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
94 }
95
96 return slotMap;
97}
98
/// Called once every access has been rewired to the subslot allocas; the
/// aggregate alloca is then dead and can be erased.
void LLVM::AllocaOp::handleDestructuringComplete(
    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
  assert(slot.ptr == getResult());
  rewriter.eraseOp(*this);
}
104
//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//
108
/// A load reads the slot iff its address operand is the slot pointer.
bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
  return getAddr() == slot.ptr;
}
112
/// A load never writes memory.
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114
/// Loads never store, so this interface method must never be reached
/// (`storesTo` always returns false for loads).
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                              Value reachingDef, const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}
119
/// A store never reads memory.
bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
121
/// A store writes the slot iff its address operand is the slot pointer.
bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
  return getAddr() == slot.ptr;
}
125
126/// Checks if `type` can be used in any kind of conversion sequences.
127static bool isSupportedTypeForConversion(Type type) {
128 // Aggregate types are not bitcastable.
129 if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130 return false;
131
132 // LLVM vector types are only used for either pointers or target specific
133 // types. These types cannot be casted in the general case, thus the memory
134 // optimizations do not support them.
135 if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136 return false;
137
138 // Scalable types are not supported.
139 if (auto vectorType = dyn_cast<VectorType>(type))
140 return !vectorType.isScalable();
141 return true;
142}
143
144/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145/// truncations. Checks for narrowing or widening conversion compatibility
146/// depending on `narrowingConversion`.
147static bool areConversionCompatible(const DataLayout &layout, Type targetType,
148 Type srcType, bool narrowingConversion) {
149 if (targetType == srcType)
150 return true;
151
152 if (!isSupportedTypeForConversion(type: targetType) ||
153 !isSupportedTypeForConversion(type: srcType))
154 return false;
155
156 uint64_t targetSize = layout.getTypeSize(t: targetType);
157 uint64_t srcSize = layout.getTypeSize(t: srcType);
158
159 // Pointer casts will only be sane when the bitsize of both pointer types is
160 // the same.
161 if (isa<LLVM::LLVMPointerType>(targetType) &&
162 isa<LLVM::LLVMPointerType>(srcType))
163 return targetSize == srcSize;
164
165 if (narrowingConversion)
166 return targetSize <= srcSize;
167 return targetSize >= srcSize;
168}
169
170/// Checks if `dataLayout` describes a little endian layout.
171static bool isBigEndian(const DataLayout &dataLayout) {
172 auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
173 return endiannessStr && endiannessStr == "big";
174}
175
176/// Converts a value to an integer type of the same size.
177/// Assumes that the type can be converted.
178static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
179 const DataLayout &dataLayout) {
180 Type type = val.getType();
181 assert(isSupportedTypeForConversion(type) &&
182 "expected value to have a convertible type");
183
184 if (isa<IntegerType>(Val: type))
185 return val;
186
187 uint64_t typeBitSize = dataLayout.getTypeSizeInBits(t: type);
188 IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
189
190 if (isa<LLVM::LLVMPointerType>(type))
191 return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
192 return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
193}
194
/// Converts a value with an integer type to `targetType`, which must be of the
/// same bit size.
static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
                                         Value val, Type targetType) {
  assert(isa<IntegerType>(val.getType()) &&
         "expected value to have an integer type");
  assert(isSupportedTypeForConversion(targetType) &&
         "expected the target type to be supported for conversions");
  // No-op when the types already agree.
  if (val.getType() == targetType)
    return val;
  // Pointers need inttoptr; every other supported target type can be bitcast.
  if (isa<LLVM::LLVMPointerType>(targetType))
    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
}
208
209/// Constructs operations that convert `srcValue` into a new value of type
210/// `targetType`. Assumes the types have the same bitsize.
211static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
212 Value srcValue, Type targetType,
213 const DataLayout &dataLayout) {
214 Type srcType = srcValue.getType();
215 assert(areConversionCompatible(dataLayout, targetType, srcType,
216 /*narrowingConversion=*/true) &&
217 "expected that the compatibility was checked before");
218
219 // Nothing has to be done if the types are already the same.
220 if (srcType == targetType)
221 return srcValue;
222
223 // In the special case of casting one pointer to another, we want to generate
224 // an address space cast. Bitcasts of pointers are not allowed and using
225 // pointer to integer conversions are not equivalent due to the loss of
226 // provenance.
227 if (isa<LLVM::LLVMPointerType>(targetType) &&
228 isa<LLVM::LLVMPointerType>(srcType))
229 return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
230 srcValue);
231
232 // For all other castable types, casting through integers is necessary.
233 Value replacement = castToSameSizedInt(rewriter, loc, val: srcValue, dataLayout);
234 return castIntValueToSameSizedType(rewriter, loc, val: replacement, targetType);
235}
236
237/// Constructs operations that convert `srcValue` into a new value of type
238/// `targetType`. Performs bit-level extraction if the source type is larger
239/// than the target type. Assumes that this conversion is possible.
240static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
241 Value srcValue, Type targetType,
242 const DataLayout &dataLayout) {
243 // Get the types of the source and target values.
244 Type srcType = srcValue.getType();
245 assert(areConversionCompatible(dataLayout, targetType, srcType,
246 /*narrowingConversion=*/true) &&
247 "expected that the compatibility was checked before");
248
249 uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(t: srcType);
250 uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(t: targetType);
251 if (srcTypeSize == targetTypeSize)
252 return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
253
254 // First, cast the value to a same-sized integer type.
255 Value replacement = castToSameSizedInt(rewriter, loc, val: srcValue, dataLayout);
256
257 // Truncate the integer if the size of the target is less than the value.
258 if (isBigEndian(dataLayout)) {
259 uint64_t shiftAmount = srcTypeSize - targetTypeSize;
260 auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
261 loc, rewriter.getIntegerAttr(srcType, shiftAmount));
262 replacement =
263 rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
264 }
265
266 replacement = rewriter.create<LLVM::TruncOp>(
267 loc, rewriter.getIntegerType(targetTypeSize), replacement);
268
269 // Now cast the integer to the actual target type if required.
270 return castIntValueToSameSizedType(rewriter, loc, val: replacement, targetType);
271}
272
273/// Constructs operations that insert the bits of `srcValue` into the
274/// "beginning" of `reachingDef` (beginning is endianness dependent).
275/// Assumes that this conversion is possible.
276static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
277 Value srcValue, Value reachingDef,
278 const DataLayout &dataLayout) {
279
280 assert(areConversionCompatible(dataLayout, reachingDef.getType(),
281 srcValue.getType(),
282 /*narrowingConversion=*/false) &&
283 "expected that the compatibility was checked before");
284 uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(t: srcValue.getType());
285 uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(t: reachingDef.getType());
286 if (slotTypeSize == valueTypeSize)
287 return castSameSizedTypes(rewriter, loc, srcValue, targetType: reachingDef.getType(),
288 dataLayout);
289
290 // In the case where the store only overwrites parts of the memory,
291 // bit fiddling is required to construct the new value.
292
293 // First convert both values to integers of the same size.
294 Value defAsInt = castToSameSizedInt(rewriter, loc, val: reachingDef, dataLayout);
295 Value valueAsInt = castToSameSizedInt(rewriter, loc, val: srcValue, dataLayout);
296 // Extend the value to the size of the reaching definition.
297 valueAsInt =
298 rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
299 uint64_t sizeDifference = slotTypeSize - valueTypeSize;
300 if (isBigEndian(dataLayout)) {
301 // On big endian systems, a store to the base pointer overwrites the most
302 // significant bits. To accomodate for this, the stored value needs to be
303 // shifted into the according position.
304 Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
305 loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
306 valueAsInt =
307 rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
308 }
309
310 // Construct the mask that is used to erase the bits that are overwritten by
311 // the store.
312 APInt maskValue;
313 if (isBigEndian(dataLayout)) {
314 // Build a mask that has the most significant bits set to zero.
315 // Note: This is the same as 2^sizeDifference - 1
316 maskValue = APInt::getAllOnes(numBits: sizeDifference).zext(width: slotTypeSize);
317 } else {
318 // Build a mask that has the least significant bits set to zero.
319 // Note: This is the same as -(2^valueTypeSize)
320 maskValue = APInt::getAllOnes(numBits: valueTypeSize).zext(width: slotTypeSize);
321 maskValue.flipAllBits();
322 }
323
324 // Mask out the affected bits ...
325 Value mask = rewriter.create<LLVM::ConstantOp>(
326 loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
327 Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
328
329 // ... and combine the result with the new value.
330 Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
331
332 return castIntValueToSameSizedType(rewriter, loc, val: combined,
333 targetType: reachingDef.getType());
334}
335
/// Computes the slot's new value after this store: the stored bits are merged
/// into the previous contents (`reachingDef`), with bit fiddling when the
/// store only overwrites part of the slot.
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                               Value reachingDef,
                               const DataLayout &dataLayout) {
  assert(reachingDef && reachingDef.getType() == slot.elemType &&
         "expected the reaching definition's type to match the slot's type");
  return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
                             dataLayout);
}
344
345bool LLVM::LoadOp::canUsesBeRemoved(
346 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
347 SmallVectorImpl<OpOperand *> &newBlockingUses,
348 const DataLayout &dataLayout) {
349 if (blockingUses.size() != 1)
350 return false;
351 Value blockingUse = (*blockingUses.begin())->get();
352 // If the blocking use is the slot ptr itself, there will be enough
353 // context to reconstruct the result of the load at removal time, so it can
354 // be removed (provided it is not volatile).
355 return blockingUse == slot.ptr && getAddr() == slot.ptr &&
356 areConversionCompatible(dataLayout, getResult().getType(),
357 slot.elemType, /*narrowingConversion=*/true) &&
358 !getVolatile_();
359}
360
361DeletionKind LLVM::LoadOp::removeBlockingUses(
362 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
363 RewriterBase &rewriter, Value reachingDefinition,
364 const DataLayout &dataLayout) {
365 // `canUsesBeRemoved` checked this blocking use must be the loaded slot
366 // pointer.
367 Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
368 getResult().getType(), dataLayout);
369 rewriter.replaceAllUsesWith(getResult(), newResult);
370 return DeletionKind::Delete;
371}
372
373bool LLVM::StoreOp::canUsesBeRemoved(
374 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
375 SmallVectorImpl<OpOperand *> &newBlockingUses,
376 const DataLayout &dataLayout) {
377 if (blockingUses.size() != 1)
378 return false;
379 Value blockingUse = (*blockingUses.begin())->get();
380 // If the blocking use is the slot ptr itself, dropping the store is
381 // fine, provided we are currently promoting its target value. Don't allow a
382 // store OF the slot pointer, only INTO the slot pointer.
383 return blockingUse == slot.ptr && getAddr() == slot.ptr &&
384 getValue() != slot.ptr &&
385 areConversionCompatible(dataLayout, slot.elemType,
386 getValue().getType(),
387 /*narrowingConversion=*/false) &&
388 !getVolatile_();
389}
390
/// The stored value was already merged into the reaching definition by
/// `getStored`; the store itself can simply be deleted.
DeletionKind LLVM::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
397
398/// Checks if `slot` can be accessed through the provided access type.
399static bool isValidAccessType(const MemorySlot &slot, Type accessType,
400 const DataLayout &dataLayout) {
401 return dataLayout.getTypeSize(t: accessType) <=
402 dataLayout.getTypeSize(t: slot.elemType);
403}
404
405LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
406 const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
407 const DataLayout &dataLayout) {
408 return success(getAddr() != slot.ptr ||
409 isValidAccessType(slot, getType(), dataLayout));
410}
411
412LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
413 const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
414 const DataLayout &dataLayout) {
415 return success(getAddr() != slot.ptr ||
416 isValidAccessType(slot, getValue().getType(), dataLayout));
417}
418
/// Returns the subslot's type at the requested index, or a null type when the
/// slot's element type exposes no entry for that index.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
                           Attribute index) {
  auto subelementIndexMap =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
  if (!subelementIndexMap)
    return {};
  assert(!subelementIndexMap->empty());

  // Note: `lookup` returns a null type when no entry was found.
  return subelementIndexMap->lookup(index);
}
431
432bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
433 SmallPtrSetImpl<Attribute> &usedIndices,
434 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
435 const DataLayout &dataLayout) {
436 if (getVolatile_())
437 return false;
438
439 // A load always accesses the first element of the destructured slot.
440 auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
441 Type subslotType = getTypeAtIndex(slot, index);
442 if (!subslotType)
443 return false;
444
445 // The access can only be replaced when the subslot is read within its bounds.
446 if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
447 return false;
448
449 usedIndices.insert(index);
450 return true;
451}
452
453DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
454 DenseMap<Attribute, MemorySlot> &subslots,
455 RewriterBase &rewriter,
456 const DataLayout &dataLayout) {
457 auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
458 auto it = subslots.find(index);
459 assert(it != subslots.end());
460
461 rewriter.modifyOpInPlace(
462 *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
463 return DeletionKind::Keep;
464}
465
466bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
467 SmallPtrSetImpl<Attribute> &usedIndices,
468 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
469 const DataLayout &dataLayout) {
470 if (getVolatile_())
471 return false;
472
473 // Storing the pointer to memory cannot be dealt with.
474 if (getValue() == slot.ptr)
475 return false;
476
477 // A store always accesses the first element of the destructured slot.
478 auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
479 Type subslotType = getTypeAtIndex(slot, index);
480 if (!subslotType)
481 return false;
482
483 // The access can only be replaced when the subslot is read within its bounds.
484 if (dataLayout.getTypeSize(getValue().getType()) >
485 dataLayout.getTypeSize(subslotType))
486 return false;
487
488 usedIndices.insert(index);
489 return true;
490}
491
492DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
493 DenseMap<Attribute, MemorySlot> &subslots,
494 RewriterBase &rewriter,
495 const DataLayout &dataLayout) {
496 auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
497 auto it = subslots.find(index);
498 assert(it != subslots.end());
499
500 rewriter.modifyOpInPlace(
501 *this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
502 return DeletionKind::Keep;
503}
504
//===----------------------------------------------------------------------===//
// Interfaces for discardable OPs
//===----------------------------------------------------------------------===//
508
509/// Conditions the deletion of the operation to the removal of all its uses.
510static bool forwardToUsers(Operation *op,
511 SmallVectorImpl<OpOperand *> &newBlockingUses) {
512 for (Value result : op->getResults())
513 for (OpOperand &use : result.getUses())
514 newBlockingUses.push_back(Elt: &use);
515 return true;
516}
517
/// A bitcast is discardable as soon as all of its own users can be removed.
bool LLVM::BitcastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}
524
/// Once its users are gone, the bitcast itself is simply deleted.
DeletionKind LLVM::BitcastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
529
/// An address space cast is discardable as soon as all of its own users can be
/// removed.
bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}
536
/// Once its users are gone, the address space cast itself is simply deleted.
DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
541
/// Lifetime markers are pure annotations and never block slot promotion.
bool LLVM::LifetimeStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
548
/// Lifetime markers are deleted outright when their pointer goes away.
DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
553
/// Lifetime markers are pure annotations and never block slot promotion.
bool LLVM::LifetimeEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
560
/// Lifetime markers are deleted outright when their pointer goes away.
DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
565
/// Invariant markers are pure annotations and never block slot promotion.
bool LLVM::InvariantStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
572
/// Invariant markers are deleted outright when their pointer goes away.
DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
577
/// Invariant markers are pure annotations and never block slot promotion.
bool LLVM::InvariantEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
584
/// Invariant markers are deleted outright when their pointer goes away.
DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
589
/// Debug declarations never block promotion; debug info is preserved via
/// `visitReplacedValues` instead.
bool LLVM::DbgDeclareOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
596
/// The declaration is deleted once its declared pointer goes away.
DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
601
602bool LLVM::DbgValueOp::canUsesBeRemoved(
603 const SmallPtrSetImpl<OpOperand *> &blockingUses,
604 SmallVectorImpl<OpOperand *> &newBlockingUses,
605 const DataLayout &dataLayout) {
606 // There is only one operand that we can remove the use of.
607 if (blockingUses.size() != 1)
608 return false;
609
610 return (*blockingUses.begin())->get() == getValue();
611}
612
/// Replaces the tracked value with undef instead of deleting the intrinsic:
/// keeping the local-variable record lets the debugger report the variable as
/// optimized out.
DeletionKind LLVM::DbgValueOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  // Rewriter by default is after '*this', but we need it before '*this'.
  rewriter.setInsertionPoint(*this);

  // Rather than dropping the debug value, replace it with undef to preserve the
  // debug local variable info. This allows the debugger to inform the user that
  // the variable has been optimized out.
  auto undef =
      rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
  rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
  return DeletionKind::Keep;
}
626
/// Declarations need to observe the promoted definitions to emit dbg.values.
bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
628
/// Emits a dbg.value right after each new definition of the promoted slot so
/// the debugger keeps tracking the declared variable.
void LLVM::DbgDeclareOp::visitReplacedValues(
    ArrayRef<std::pair<Operation *, Value>> definitions,
    RewriterBase &rewriter) {
  for (auto [op, value] : definitions) {
    rewriter.setInsertionPointAfter(op);
    rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
                                      getLocationExpr());
  }
}
638
//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//
642
643static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
644 return llvm::all_of(gepOp.getIndices(), [](auto index) {
645 auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
646 return indexAttr && indexAttr.getValue() == 0;
647 });
648}
649
/// A GEP can be removed as long as it is a no-op (all-zero indices) and all of
/// its own users can be removed.
bool LLVM::GEPOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (!hasAllZeroIndices(*this))
    return false;
  return forwardToUsers(*this, newBlockingUses);
}
659
/// Once its users are gone, the no-op GEP itself is simply deleted.
DeletionKind LLVM::GEPOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
664
665/// Returns the amount of bytes the provided GEP elements will offset the
666/// pointer by. Returns nullopt if no constant offset could be computed.
667static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
668 LLVM::GEPOp gep) {
669 // Collects all indices.
670 SmallVector<uint64_t> indices;
671 for (auto index : gep.getIndices()) {
672 auto constIndex = dyn_cast<IntegerAttr>(index);
673 if (!constIndex)
674 return {};
675 int64_t gepIndex = constIndex.getInt();
676 // Negative indices are not supported.
677 if (gepIndex < 0)
678 return {};
679 indices.push_back(gepIndex);
680 }
681
682 Type currentType = gep.getElemType();
683 uint64_t offset = indices[0] * dataLayout.getTypeSize(t: currentType);
684
685 for (uint64_t index : llvm::drop_begin(RangeOrContainer&: indices)) {
686 bool shouldCancel =
687 TypeSwitch<Type, bool>(currentType)
688 .Case(caseFn: [&](LLVM::LLVMArrayType arrayType) {
689 offset +=
690 index * dataLayout.getTypeSize(t: arrayType.getElementType());
691 currentType = arrayType.getElementType();
692 return false;
693 })
694 .Case(caseFn: [&](LLVM::LLVMStructType structType) {
695 ArrayRef<Type> body = structType.getBody();
696 assert(index < body.size() && "expected valid struct indexing");
697 for (uint32_t i : llvm::seq(Size: index)) {
698 if (!structType.isPacked())
699 offset = llvm::alignTo(
700 Value: offset, Align: dataLayout.getTypeABIAlignment(t: body[i]));
701 offset += dataLayout.getTypeSize(t: body[i]);
702 }
703
704 // Align for the current type as well.
705 if (!structType.isPacked())
706 offset = llvm::alignTo(
707 Value: offset, Align: dataLayout.getTypeABIAlignment(t: body[index]));
708 currentType = body[index];
709 return false;
710 })
711 .Default(defaultFn: [&](Type type) {
712 LLVM_DEBUG(llvm::dbgs()
713 << "[sroa] Unsupported type for offset computations"
714 << type << "\n");
715 return true;
716 });
717
718 if (shouldCancel)
719 return std::nullopt;
720 }
721
722 return offset;
723}
724
namespace {
/// A struct that stores both the index into the aggregate type of the slot as
/// well as the corresponding byte offset in memory.
struct SubslotAccessInfo {
  /// The parent slot's index that the access falls into.
  uint32_t index;
  /// The offset into the subslot of the access.
  uint64_t subslotOffset;
};
} // namespace
735
736/// Computes subslot access information for an access into `slot` with the given
737/// offset.
738/// Returns nullopt when the offset is out-of-bounds or when the access is into
739/// the padding of `slot`.
740static std::optional<SubslotAccessInfo>
741getSubslotAccessInfo(const DestructurableMemorySlot &slot,
742 const DataLayout &dataLayout, LLVM::GEPOp gep) {
743 std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
744 if (!offset)
745 return {};
746
747 // Helper to check that a constant index is in the bounds of the GEP index
748 // representation. LLVM dialects's GEP arguments have a limited bitwidth, thus
749 // this additional check is necessary.
750 auto isOutOfBoundsGEPIndex = [](uint64_t index) {
751 return index >= (1 << LLVM::kGEPConstantBitWidth);
752 };
753
754 Type type = slot.elemType;
755 if (*offset >= dataLayout.getTypeSize(t: type))
756 return {};
757 return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
758 .Case(caseFn: [&](LLVM::LLVMArrayType arrayType)
759 -> std::optional<SubslotAccessInfo> {
760 // Find which element of the array contains the offset.
761 uint64_t elemSize = dataLayout.getTypeSize(t: arrayType.getElementType());
762 uint64_t index = *offset / elemSize;
763 if (isOutOfBoundsGEPIndex(index))
764 return {};
765 return SubslotAccessInfo{.index: static_cast<uint32_t>(index),
766 .subslotOffset: *offset - (index * elemSize)};
767 })
768 .Case(caseFn: [&](LLVM::LLVMStructType structType)
769 -> std::optional<SubslotAccessInfo> {
770 uint64_t distanceToStart = 0;
771 // Walk over the elements of the struct to find in which of
772 // them the offset is.
773 for (auto [index, elem] : llvm::enumerate(First: structType.getBody())) {
774 uint64_t elemSize = dataLayout.getTypeSize(t: elem);
775 if (!structType.isPacked()) {
776 distanceToStart = llvm::alignTo(
777 Value: distanceToStart, Align: dataLayout.getTypeABIAlignment(t: elem));
778 // If the offset is in padding, cancel the rewrite.
779 if (offset < distanceToStart)
780 return {};
781 }
782
783 if (offset < distanceToStart + elemSize) {
784 if (isOutOfBoundsGEPIndex(index))
785 return {};
786 // The offset is within this element, stop iterating the
787 // struct and return the index.
788 return SubslotAccessInfo{.index: static_cast<uint32_t>(index),
789 .subslotOffset: *offset - distanceToStart};
790 }
791
792 // The offset is not within this element, continue walking
793 // over the struct.
794 distanceToStart += elemSize;
795 }
796
797 return {};
798 });
799}
800
801/// Constructs a byte array type of the given size.
802static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
803 unsigned size) {
804 auto byteType = IntegerType::get(context, 8);
805 return LLVM::LLVMArrayType::get(context, byteType, size);
806}
807
/// A GEP on the slot pointer is safe only when its byte offset is statically
/// known and strictly inside the slot; the derived pointer must then itself
/// only be used within the remaining bytes of the slot.
LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  // GEPs based on other pointers do not concern this slot.
  if (getBase() != slot.ptr)
    return success();
  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
  if (!gepOffset)
    return failure();
  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
  // Check that the access is strictly inside the slot.
  if (*gepOffset >= slotSize)
    return failure();
  // Every access that remains in bounds of the remaining slot is considered
  // legal.
  mustBeSafelyUsed.emplace_back<MemorySlot>(
      {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
  return success();
}
826
827bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
828 SmallPtrSetImpl<Attribute> &usedIndices,
829 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
830 const DataLayout &dataLayout) {
831 if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
832 return false;
833
834 if (getBase() != slot.ptr)
835 return false;
836 std::optional<SubslotAccessInfo> accessInfo =
837 getSubslotAccessInfo(slot, dataLayout, *this);
838 if (!accessInfo)
839 return false;
840 auto indexAttr =
841 IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
842 assert(slot.elementPtrs.contains(indexAttr));
843 usedIndices.insert(indexAttr);
844
845 // The remainder of the subslot should be accesses in-bounds. Thus, we create
846 // a dummy slot with the size of the remainder.
847 Type subslotType = slot.elementPtrs.lookup(indexAttr);
848 uint64_t slotSize = dataLayout.getTypeSize(subslotType);
849 LLVM::LLVMArrayType remainingSlotType =
850 getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
851 mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
852
853 return true;
854}
855
856DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
857 DenseMap<Attribute, MemorySlot> &subslots,
858 RewriterBase &rewriter,
859 const DataLayout &dataLayout) {
860 std::optional<SubslotAccessInfo> accessInfo =
861 getSubslotAccessInfo(slot, dataLayout, *this);
862 assert(accessInfo && "expected access info to be checked before");
863 auto indexAttr =
864 IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
865 const MemorySlot &newSlot = subslots.at(indexAttr);
866
867 auto byteType = IntegerType::get(rewriter.getContext(), 8);
868 auto newPtr = rewriter.createOrFold<LLVM::GEPOp>(
869 getLoc(), getResult().getType(), byteType, newSlot.ptr,
870 ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
871 rewriter.replaceAllUsesWith(getResult(), newPtr);
872 return DeletionKind::Delete;
873}
874
875//===----------------------------------------------------------------------===//
876// Utilities for memory intrinsics
877//===----------------------------------------------------------------------===//
878
879namespace {
880
881/// Returns the length of the given memory intrinsic in bytes if it can be known
882/// at compile-time on a best-effort basis, nothing otherwise.
883template <class MemIntr>
884std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
885 APInt memIntrLen;
886 if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
887 return {};
888 if (memIntrLen.getBitWidth() > 64)
889 return {};
890 return memIntrLen.getZExtValue();
891}
892
893/// Returns the length of the given memory intrinsic in bytes if it can be known
894/// at compile-time on a best-effort basis, nothing otherwise.
895/// Because MemcpyInlineOp has its length encoded as an attribute, this requires
896/// specialized handling.
897template <>
898std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
899 APInt memIntrLen = op.getLen();
900 if (memIntrLen.getBitWidth() > 64)
901 return {};
902 return memIntrLen.getZExtValue();
903}
904
905} // namespace
906
907/// Returns whether one can be sure the memory intrinsic does not write outside
908/// of the bounds of the given slot, on a best-effort basis.
909template <class MemIntr>
910static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
911 const DataLayout &dataLayout) {
912 if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
913 op.getDst() != slot.ptr)
914 return false;
915
916 std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
917 return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(t: slot.elemType);
918}
919
920/// Checks whether all indices are i32. This is used to check GEPs can index
921/// into them.
922static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
923 Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
924 return llvm::all_of(Range: llvm::make_first_range(c: slot.elementPtrs),
925 P: [&](Attribute index) {
926 auto intIndex = dyn_cast<IntegerAttr>(index);
927 return intIndex && intIndex.getType() == i32;
928 });
929}
930
931//===----------------------------------------------------------------------===//
932// Interfaces for memset
933//===----------------------------------------------------------------------===//
934
// A memset never reads from the slot; it only writes to its destination.
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
936
937bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
938 return getDst() == slot.ptr;
939}
940
941Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
942 Value reachingDef,
943 const DataLayout &dataLayout) {
944 // TODO: Support non-integer types.
945 return TypeSwitch<Type, Value>(slot.elemType)
946 .Case([&](IntegerType intType) -> Value {
947 if (intType.getWidth() == 8)
948 return getVal();
949
950 assert(intType.getWidth() % 8 == 0);
951
952 // Build the memset integer by repeatedly shifting the value and
953 // or-ing it with the previous value.
954 uint64_t coveredBits = 8;
955 Value currentValue =
956 rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
957 while (coveredBits < intType.getWidth()) {
958 Value shiftBy =
959 rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
960 Value shifted =
961 rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
962 currentValue =
963 rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
964 coveredBits *= 2;
965 }
966
967 return currentValue;
968 })
969 .Default([](Type) -> Value {
970 llvm_unreachable(
971 "getStored should not be called on memset to unsupported type");
972 });
973}
974
975bool LLVM::MemsetOp::canUsesBeRemoved(
976 const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
977 SmallVectorImpl<OpOperand *> &newBlockingUses,
978 const DataLayout &dataLayout) {
979 // TODO: Support non-integer types.
980 bool canConvertType =
981 TypeSwitch<Type, bool>(slot.elemType)
982 .Case([](IntegerType intType) {
983 return intType.getWidth() % 8 == 0 && intType.getWidth() > 0;
984 })
985 .Default([](Type) { return false; });
986 if (!canConvertType)
987 return false;
988
989 if (getIsVolatile())
990 return false;
991
992 return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType);
993}
994
DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // The stored value is materialized separately by `getStored`; the memset
  // itself can simply be erased.
  return DeletionKind::Delete;
}
1001
LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  // A memset is a safe access iff it provably writes only within the slot.
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}
1007
1008bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
1009 SmallPtrSetImpl<Attribute> &usedIndices,
1010 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1011 const DataLayout &dataLayout) {
1012 if (&slot.elemType.getDialect() != getOperation()->getDialect())
1013 return false;
1014
1015 if (getIsVolatile())
1016 return false;
1017
1018 if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1019 return false;
1020
1021 if (!areAllIndicesI32(slot))
1022 return false;
1023
1024 return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
1025}
1026
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    RewriterBase &rewriter,
                                    const DataLayout &dataLayout) {
  std::optional<DenseMap<Attribute, Type>> types =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();

  // The length is guaranteed to be a compile-time constant by `canRewire`.
  IntegerAttr memsetLenAttr;
  bool successfulMatch =
      matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
  (void)successfulMatch;
  assert(successfulMatch);

  // Packed structs have no inter-element padding; for all other element
  // containers, ABI alignment must be applied between elements.
  bool packed = false;
  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
    packed = structType.isPacked();

  Type i32 = IntegerType::get(getContext(), 32);
  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
  // Running byte offset from the start of the original slot.
  uint64_t covered = 0;
  for (size_t i = 0; i < types->size(); i++) {
    // Create indices on the fly to get elements in the right order.
    Attribute index = IntegerAttr::get(i32, i);
    Type elemType = types->at(index);
    uint64_t typeSize = dataLayout.getTypeSize(elemType);

    if (!packed)
      covered =
          llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));

    // Elements that start at or past the end of the memset are untouched.
    if (covered >= memsetLen)
      break;

    // If this subslot is used, apply a new memset to it.
    // Otherwise, only compute its offset within the original memset.
    if (subslots.contains(index)) {
      // The new memset is truncated if the original one ends inside this
      // element.
      uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);

      Value newMemsetSizeValue =
          rewriter
              .create<LLVM::ConstantOp>(
                  getLen().getLoc(),
                  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
              .getResult();

      rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
                                      getVal(), newMemsetSizeValue,
                                      getIsVolatile());
    }

    covered += typeSize;
  }

  return DeletionKind::Delete;
}
1082
1083//===----------------------------------------------------------------------===//
1084// Interfaces for memcpy/memmove
1085//===----------------------------------------------------------------------===//
1086
1087template <class MemcpyLike>
1088static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
1089 return op.getSrc() == slot.ptr;
1090}
1091
1092template <class MemcpyLike>
1093static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
1094 return op.getDst() == slot.ptr;
1095}
1096
1097template <class MemcpyLike>
1098static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
1099 RewriterBase &rewriter) {
1100 return rewriter.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
1101}
1102
1103template <class MemcpyLike>
1104static bool
1105memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
1106 const SmallPtrSetImpl<OpOperand *> &blockingUses,
1107 SmallVectorImpl<OpOperand *> &newBlockingUses,
1108 const DataLayout &dataLayout) {
1109 // If source and destination are the same, memcpy behavior is undefined and
1110 // memmove is a no-op. Because there is no memory change happening here,
1111 // simplifying such operations is left to canonicalization.
1112 if (op.getDst() == op.getSrc())
1113 return false;
1114
1115 if (op.getIsVolatile())
1116 return false;
1117
1118 return getStaticMemIntrLen(op) == dataLayout.getTypeSize(t: slot.elemType);
1119}
1120
/// Removes the memcpy-like operation once the slot is promoted: a copy out of
/// the slot becomes a store of the promoted value to the destination, while a
/// copy into the slot can simply be dropped.
template <class MemcpyLike>
static DeletionKind
memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
                         const SmallPtrSetImpl<OpOperand *> &blockingUses,
                         RewriterBase &rewriter, Value reachingDefinition) {
  if (op.loadsFrom(slot))
    rewriter.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
                                   op.getDst());
  return DeletionKind::Delete;
}
1131
1132template <class MemcpyLike>
1133static LogicalResult
1134memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
1135 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
1136 DataLayout dataLayout = DataLayout::closest(op);
1137 // While rewiring memcpy-like intrinsics only supports full copies, partial
1138 // copies are still safe accesses so it is enough to only check for writes
1139 // within bounds.
1140 return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
1141}
1142
1143template <class MemcpyLike>
1144static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1145 SmallPtrSetImpl<Attribute> &usedIndices,
1146 SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
1147 const DataLayout &dataLayout) {
1148 if (op.getIsVolatile())
1149 return false;
1150
1151 if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
1152 return false;
1153
1154 if (!areAllIndicesI32(slot))
1155 return false;
1156
1157 // Only full copies are supported.
1158 if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(t: slot.elemType))
1159 return false;
1160
1161 if (op.getSrc() == slot.ptr)
1162 for (Attribute index : llvm::make_first_range(c: slot.elementPtrs))
1163 usedIndices.insert(Ptr: index);
1164
1165 return true;
1166}
1167
1168namespace {
1169
1170template <class MemcpyLike>
1171void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
1172 MemcpyLike toReplace, Value dst, Value src,
1173 Type toCpy, bool isVolatile) {
1174 Value memcpySize = rewriter.create<LLVM::ConstantOp>(
1175 toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
1176 layout.getTypeSize(toCpy)));
1177 rewriter.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
1178 isVolatile);
1179}
1180
1181template <>
1182void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
1183 LLVM::MemcpyInlineOp toReplace, Value dst,
1184 Value src, Type toCpy, bool isVolatile) {
1185 Type lenType = IntegerType::get(toReplace->getContext(),
1186 toReplace.getLen().getBitWidth());
1187 rewriter.create<LLVM::MemcpyInlineOp>(
1188 toReplace.getLoc(), dst, src,
1189 IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
1190}
1191
1192} // namespace
1193
1194/// Rewires a memcpy-like operation. Only copies to or from the full slot are
1195/// supported.
1196template <class MemcpyLike>
1197static DeletionKind
1198memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
1199 DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter,
1200 const DataLayout &dataLayout) {
1201 if (subslots.empty())
1202 return DeletionKind::Delete;
1203
1204 assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
1205 bool isDst = slot.ptr == op.getDst();
1206
1207#ifndef NDEBUG
1208 size_t slotsTreated = 0;
1209#endif
1210
1211 // It was previously checked that index types are consistent, so this type can
1212 // be fetched now.
1213 Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
1214 for (size_t i = 0, e = slot.elementPtrs.size(); i != e; i++) {
1215 Attribute index = IntegerAttr::get(indexType, i);
1216 if (!subslots.contains(Val: index))
1217 continue;
1218 const MemorySlot &subslot = subslots.at(Val: index);
1219
1220#ifndef NDEBUG
1221 slotsTreated++;
1222#endif
1223
1224 // First get a pointer to the equivalent of this subslot from the source
1225 // pointer.
1226 SmallVector<LLVM::GEPArg> gepIndices{
1227 0, static_cast<int32_t>(
1228 cast<IntegerAttr>(index).getValue().getZExtValue())};
1229 Value subslotPtrInOther = rewriter.create<LLVM::GEPOp>(
1230 op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
1231 isDst ? op.getSrc() : op.getDst(), gepIndices);
1232
1233 // Then create a new memcpy out of this source pointer.
1234 createMemcpyLikeToReplace(rewriter, dataLayout, op,
1235 isDst ? subslot.ptr : subslotPtrInOther,
1236 isDst ? subslotPtrInOther : subslot.ptr,
1237 subslot.elemType, op.getIsVolatile());
1238 }
1239
1240 assert(subslots.size() == slotsTreated);
1241
1242 return DeletionKind::Delete;
1243}
1244
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1248
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1252
// Forwards to the shared memcpy-like implementation.
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
1258
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1266
// Forwards to the shared memcpy-like implementation.
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                  reachingDefinition);
}
1274
// Forwards to the shared memcpy-like implementation.
LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1280
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1288
// Forwards to the shared memcpy-like implementation.
DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    RewriterBase &rewriter,
                                    const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
1295
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1299
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1303
// Forwards to the shared memcpy-like implementation.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      RewriterBase &rewriter, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
1309
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1317
// Forwards to the shared memcpy-like implementation.
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                  reachingDefinition);
}
1325
// Forwards to the shared memcpy-like implementation.
LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1331
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemcpyInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1340
// Forwards to the shared memcpy-like implementation.
DeletionKind
LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             RewriterBase &rewriter,
                             const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
1348
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1352
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1356
// Forwards to the shared memcpy-like implementation.
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
1362
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemmoveOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1370
// Forwards to the shared memcpy-like implementation.
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                  reachingDefinition);
}
1378
// Forwards to the shared memcpy-like implementation.
LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1384
// Forwards to the shared memcpy-like implementation.
bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                                const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1392
// Forwards to the shared memcpy-like implementation.
DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     RewriterBase &rewriter,
                                     const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
1399
1400//===----------------------------------------------------------------------===//
1401// Interfaces for destructurable types
1402//===----------------------------------------------------------------------===//
1403
1404std::optional<DenseMap<Attribute, Type>>
1405LLVM::LLVMStructType::getSubelementIndexMap() {
1406 Type i32 = IntegerType::get(getContext(), 32);
1407 DenseMap<Attribute, Type> destructured;
1408 for (const auto &[index, elemType] : llvm::enumerate(getBody()))
1409 destructured.insert({IntegerAttr::get(i32, index), elemType});
1410 return destructured;
1411}
1412
1413Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
1414 auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1415 if (!indexAttr || !indexAttr.getType().isInteger(32))
1416 return {};
1417 int32_t indexInt = indexAttr.getInt();
1418 ArrayRef<Type> body = getBody();
1419 if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
1420 return {};
1421 return body[indexInt];
1422}
1423
1424std::optional<DenseMap<Attribute, Type>>
1425LLVM::LLVMArrayType::getSubelementIndexMap() const {
1426 constexpr size_t maxArraySizeForDestructuring = 16;
1427 if (getNumElements() > maxArraySizeForDestructuring)
1428 return {};
1429 int32_t numElements = getNumElements();
1430
1431 Type i32 = IntegerType::get(getContext(), 32);
1432 DenseMap<Attribute, Type> destructured;
1433 for (int32_t index = 0; index < numElements; ++index)
1434 destructured.insert({IntegerAttr::get(i32, index), getElementType()});
1435 return destructured;
1436}
1437
1438Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
1439 auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
1440 if (!indexAttr || !indexAttr.getType().isInteger(32))
1441 return {};
1442 int32_t indexInt = indexAttr.getInt();
1443 if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
1444 return {};
1445 return getElementType();
1446}
1447

source code of mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp