//===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements MemorySlot-related interfaces for LLVM dialect
// operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "sroa"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Interfaces for AllocaOp
//===----------------------------------------------------------------------===//

llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
  if (!getOperation()->getBlock()->isEntryBlock())
    return {};

  return {MemorySlot{getResult(), getElemType()}};
}

Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
                                      OpBuilder &builder) {
  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
}

void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                         BlockArgument argument,
                                         OpBuilder &builder) {
  for (Operation *user : getOperation()->getUsers())
    if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
                                       declareOp.getVarInfo(),
                                       declareOp.getLocationExpr());
}

std::optional<PromotableAllocationOpInterface>
LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                        Value defaultValue,
                                        OpBuilder &builder) {
  if (defaultValue && defaultValue.use_empty())
    defaultValue.getDefiningOp()->erase();
  this->erase();
  return std::nullopt;
}

SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
  if (!mlir::matchPattern(getArraySize(), m_One()))
    return {};

  auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType());
  if (!destructurable)
    return {};

  std::optional<DenseMap<Attribute, Type>> destructuredType =
      destructurable.getSubelementIndexMap();
  if (!destructuredType)
    return {};

  return {DestructurableMemorySlot{{getResult(), getElemType()},
                                   *destructuredType}};
}

DenseMap<Attribute, MemorySlot> LLVM::AllocaOp::destructure(
    const DestructurableMemorySlot &slot,
    const SmallPtrSetImpl<Attribute> &usedIndices, OpBuilder &builder,
    SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators) {
  assert(slot.ptr == getResult());
  builder.setInsertionPointAfter(*this);

  auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
  DenseMap<Attribute, MemorySlot> slotMap;
  for (Attribute index : usedIndices) {
    Type elemType = destructurableType.getTypeAtIndex(index);
    assert(elemType && "used index must exist");
    auto subAlloca = builder.create<LLVM::AllocaOp>(
        getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
        getArraySize());
    newAllocators.push_back(subAlloca);
    slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
  }

  return slotMap;
}

std::optional<DestructurableAllocationOpInterface>
LLVM::AllocaOp::handleDestructuringComplete(
    const DestructurableMemorySlot &slot, OpBuilder &builder) {
  assert(slot.ptr == getResult());
  this->erase();
  return std::nullopt;
}

//===----------------------------------------------------------------------===//
// Interfaces for LoadOp/StoreOp
//===----------------------------------------------------------------------===//

bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
  return getAddr() == slot.ptr;
}

bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }

Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                              Value reachingDef, const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}

bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
  return getAddr() == slot.ptr;
}

/// Checks if `type` can be used in any kind of conversion sequence.
static bool isSupportedTypeForConversion(Type type) {
  // Aggregate types are not bitcastable.
  if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
    return false;

  if (auto vectorType = dyn_cast<VectorType>(type)) {
    // Vectors of pointers cannot be cast.
    if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
      return false;
    // Scalable types are not supported.
    return !vectorType.isScalable();
  }
  return true;
}

/// Checks that `srcType` can be converted to `targetType` by a sequence of
/// casts and truncations. Checks for narrowing or widening conversion
/// compatibility depending on `narrowingConversion`.
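/// For example, a narrowing conversion allows an i64 value to be converted to
/// an i32 (targetSize <= srcSize), while a widening conversion allows an i32
/// value to be converted into an i64 slot type. Pointer-to-pointer conversions
/// are only compatible when both pointer types have the same bitsize.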
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
                                    Type srcType, bool narrowingConversion) {
  if (targetType == srcType)
    return true;

  if (!isSupportedTypeForConversion(targetType) ||
      !isSupportedTypeForConversion(srcType))
    return false;

  uint64_t targetSize = layout.getTypeSize(targetType);
  uint64_t srcSize = layout.getTypeSize(srcType);

  // Pointer casts will only be sane when the bitsize of both pointer types is
  // the same.
  if (isa<LLVM::LLVMPointerType>(targetType) &&
      isa<LLVM::LLVMPointerType>(srcType))
    return targetSize == srcSize;

  if (narrowingConversion)
    return targetSize <= srcSize;
  return targetSize >= srcSize;
}

/// Checks if `dataLayout` describes a big endian layout.
static bool isBigEndian(const DataLayout &dataLayout) {
  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
  return endiannessStr && endiannessStr == "big";
}

/// Converts a value to an integer type of the same size.
/// Assumes that the type can be converted.
static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
                                const DataLayout &dataLayout) {
  Type type = val.getType();
  assert(isSupportedTypeForConversion(type) &&
         "expected value to have a convertible type");

  if (isa<IntegerType>(type))
    return val;

  uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
  IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);

  if (isa<LLVM::LLVMPointerType>(type))
    return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
  return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
}

/// Converts a value with an integer type to `targetType`.
static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
                                         Value val, Type targetType) {
  assert(isa<IntegerType>(val.getType()) &&
         "expected value to have an integer type");
  assert(isSupportedTypeForConversion(targetType) &&
         "expected the target type to be supported for conversions");
  if (val.getType() == targetType)
    return val;
  if (isa<LLVM::LLVMPointerType>(targetType))
    return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
  return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
}

/// Constructs operations that convert `srcValue` into a new value of type
/// `targetType`. Assumes the types have the same bitsize.
static Value castSameSizedTypes(OpBuilder &builder, Location loc,
                                Value srcValue, Type targetType,
                                const DataLayout &dataLayout) {
  Type srcType = srcValue.getType();
  assert(areConversionCompatible(dataLayout, targetType, srcType,
                                 /*narrowingConversion=*/true) &&
         "expected that the compatibility was checked before");

  // Nothing has to be done if the types are already the same.
  if (srcType == targetType)
    return srcValue;

  // In the special case of casting one pointer to another, we want to generate
  // an address space cast. Bitcasts of pointers are not allowed and
  // pointer-to-integer conversions are not equivalent due to the loss of
  // provenance.
  if (isa<LLVM::LLVMPointerType>(targetType) &&
      isa<LLVM::LLVMPointerType>(srcType))
    return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
                                                       srcValue);

  // For all other castable types, casting through integers is necessary.
  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
}

/// Constructs operations that convert `srcValue` into a new value of type
/// `targetType`. Performs bit-level extraction if the source type is larger
/// than the target type. Assumes that this conversion is possible.
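/// For example, extracting an i8 result from an i32 reaching definition first
/// casts the source to an integer (a no-op for i32), shifts it right by 24
/// bits on big-endian targets so the addressed byte ends up in the low bits,
/// and then truncates it to i8.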
static Value createExtractAndCast(OpBuilder &builder, Location loc,
                                  Value srcValue, Type targetType,
                                  const DataLayout &dataLayout) {
  // Get the types of the source and target values.
  Type srcType = srcValue.getType();
  assert(areConversionCompatible(dataLayout, targetType, srcType,
                                 /*narrowingConversion=*/true) &&
         "expected that the compatibility was checked before");

  uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
  uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
  if (srcTypeSize == targetTypeSize)
    return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);

  // First, cast the value to a same-sized integer type.
  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);

  // Truncate the integer if the size of the target is less than the value.
  if (isBigEndian(dataLayout)) {
    uint64_t shiftAmount = srcTypeSize - targetTypeSize;
    auto shiftConstant = builder.create<LLVM::ConstantOp>(
        loc, builder.getIntegerAttr(srcType, shiftAmount));
    replacement =
        builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
  }

  replacement = builder.create<LLVM::TruncOp>(
      loc, builder.getIntegerType(targetTypeSize), replacement);

  // Now cast the integer to the actual target type if required.
  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
}

/// Constructs operations that insert the bits of `srcValue` into the
/// "beginning" of `reachingDef` (beginning is endianness dependent).
/// Assumes that this conversion is possible.
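/// For example, storing an i8 into an i32 slot zero-extends the byte to i32,
/// shifts it left by 24 bits on big-endian targets, masks the overwritten byte
/// out of the reaching definition (mask 0x00ffffff on big-endian, 0xffffff00
/// on little-endian), and ORs the two values together.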
static Value createInsertAndCast(OpBuilder &builder, Location loc,
                                 Value srcValue, Value reachingDef,
                                 const DataLayout &dataLayout) {

  assert(areConversionCompatible(dataLayout, reachingDef.getType(),
                                 srcValue.getType(),
                                 /*narrowingConversion=*/false) &&
         "expected that the compatibility was checked before");
  uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
  uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
  if (slotTypeSize == valueTypeSize)
    return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
                              dataLayout);

  // In the case where the store only overwrites parts of the memory,
  // bit fiddling is required to construct the new value.

  // First convert both values to integers of the same size.
  Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
  Value valueAsInt = castToSameSizedInt(builder, loc, srcValue, dataLayout);
  // Extend the value to the size of the reaching definition.
  valueAsInt =
      builder.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
  uint64_t sizeDifference = slotTypeSize - valueTypeSize;
  if (isBigEndian(dataLayout)) {
    // On big endian systems, a store to the base pointer overwrites the most
    // significant bits. To account for this, the stored value needs to be
    // shifted into the corresponding position.
    Value bigEndianShift = builder.create<LLVM::ConstantOp>(
        loc, builder.getIntegerAttr(defAsInt.getType(), sizeDifference));
    valueAsInt =
        builder.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
  }

  // Construct the mask that is used to erase the bits that are overwritten by
  // the store.
  APInt maskValue;
  if (isBigEndian(dataLayout)) {
    // Build a mask that has the most significant bits set to zero.
    // Note: This is the same as 2^sizeDifference - 1
    maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
  } else {
    // Build a mask that has the least significant bits set to zero.
    // Note: This is the same as -(2^valueTypeSize)
    maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
    maskValue.flipAllBits();
  }

  // Mask out the affected bits ...
  Value mask = builder.create<LLVM::ConstantOp>(
      loc, builder.getIntegerAttr(defAsInt.getType(), maskValue));
  Value masked = builder.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);

  // ... and combine the result with the new value.
  Value combined = builder.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);

  return castIntValueToSameSizedType(builder, loc, combined,
                                     reachingDef.getType());
}

Value LLVM::StoreOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                               Value reachingDef,
                               const DataLayout &dataLayout) {
  assert(reachingDef && reachingDef.getType() == slot.elemType &&
         "expected the reaching definition's type to match the slot's type");
  return createInsertAndCast(builder, getLoc(), getValue(), reachingDef,
                             dataLayout);
}

bool LLVM::LoadOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (blockingUses.size() != 1)
    return false;
  Value blockingUse = (*blockingUses.begin())->get();
  // If the blocking use is the slot ptr itself, there will be enough
  // context to reconstruct the result of the load at removal time, so it can
  // be removed (provided it is not volatile).
  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
         areConversionCompatible(dataLayout, getResult().getType(),
                                 slot.elemType, /*narrowingConversion=*/true) &&
         !getVolatile_();
}

DeletionKind LLVM::LoadOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
  // pointer.
  Value newResult = createExtractAndCast(builder, getLoc(), reachingDefinition,
                                         getResult().getType(), dataLayout);
  getResult().replaceAllUsesWith(newResult);
  return DeletionKind::Delete;
}

bool LLVM::StoreOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  if (blockingUses.size() != 1)
    return false;
  Value blockingUse = (*blockingUses.begin())->get();
  // If the blocking use is the slot ptr itself, dropping the store is
  // fine, provided we are currently promoting its target value. Don't allow a
  // store OF the slot pointer, only INTO the slot pointer.
  return blockingUse == slot.ptr && getAddr() == slot.ptr &&
         getValue() != slot.ptr &&
         areConversionCompatible(dataLayout, slot.elemType,
                                 getValue().getType(),
                                 /*narrowingConversion=*/false) &&
         !getVolatile_();
}

DeletionKind LLVM::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}

/// Checks if `slot` can be accessed through the provided access type.
static bool isValidAccessType(const MemorySlot &slot, Type accessType,
                              const DataLayout &dataLayout) {
  return dataLayout.getTypeSize(accessType) <=
         dataLayout.getTypeSize(slot.elemType);
}

LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(getAddr() != slot.ptr ||
                 isValidAccessType(slot, getType(), dataLayout));
}

LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(getAddr() != slot.ptr ||
                 isValidAccessType(slot, getValue().getType(), dataLayout));
}

/// Returns the subslot's type at the requested index.
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
                           Attribute index) {
  auto subelementIndexMap =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();
  if (!subelementIndexMap)
    return {};
  assert(!subelementIndexMap->empty());

  // Note: Returns a null type when no entry was found.
  return subelementIndexMap->lookup(index);
}

bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
                             SmallPtrSetImpl<Attribute> &usedIndices,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                             const DataLayout &dataLayout) {
  if (getVolatile_())
    return false;

  // A load always accesses the first element of the destructured slot.
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  Type subslotType = getTypeAtIndex(slot, index);
  if (!subslotType)
    return false;

  // The access can only be replaced when the subslot is read within its
  // bounds.
  if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
    return false;

  usedIndices.insert(index);
  return true;
}

DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
                                  DenseMap<Attribute, MemorySlot> &subslots,
                                  OpBuilder &builder,
                                  const DataLayout &dataLayout) {
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  auto it = subslots.find(index);
  assert(it != subslots.end());

  getAddrMutable().set(it->getSecond().ptr);
  return DeletionKind::Keep;
}

bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
                              SmallPtrSetImpl<Attribute> &usedIndices,
                              SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                              const DataLayout &dataLayout) {
  if (getVolatile_())
    return false;

  // Storing the pointer to memory cannot be dealt with.
  if (getValue() == slot.ptr)
    return false;

  // A store always accesses the first element of the destructured slot.
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  Type subslotType = getTypeAtIndex(slot, index);
  if (!subslotType)
    return false;

  // The access can only be replaced when the subslot is written within its
  // bounds.
  if (dataLayout.getTypeSize(getValue().getType()) >
      dataLayout.getTypeSize(subslotType))
    return false;

  usedIndices.insert(index);
  return true;
}

DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
                                   DenseMap<Attribute, MemorySlot> &subslots,
                                   OpBuilder &builder,
                                   const DataLayout &dataLayout) {
  auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
  auto it = subslots.find(index);
  assert(it != subslots.end());

  getAddrMutable().set(it->getSecond().ptr);
  return DeletionKind::Keep;
}

//===----------------------------------------------------------------------===//
// Interfaces for discardable OPs
//===----------------------------------------------------------------------===//

/// Conditions the deletion of the operation to the removal of all its uses.
static bool forwardToUsers(Operation *op,
                           SmallVectorImpl<OpOperand *> &newBlockingUses) {
  for (Value result : op->getResults())
    for (OpOperand &use : result.getUses())
      newBlockingUses.push_back(&use);
  return true;
}

bool LLVM::BitcastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::BitcastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::LifetimeStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}

DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::LifetimeEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}

DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::InvariantStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}

DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::InvariantEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}

DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::LaunderInvariantGroupOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::LaunderInvariantGroupOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::StripInvariantGroupOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::StripInvariantGroupOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::DbgDeclareOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}

DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

bool LLVM::DbgValueOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // There is only one operand that we can remove the use of.
  if (blockingUses.size() != 1)
    return false;

  return (*blockingUses.begin())->get() == getValue();
}

DeletionKind LLVM::DbgValueOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  // builder by default is after '*this', but we need it before '*this'.
  builder.setInsertionPoint(*this);

  // Rather than dropping the debug value, replace it with undef to preserve
  // the debug local variable info. This allows the debugger to inform the user
  // that the variable has been optimized out.
  auto undef =
      builder.create<UndefOp>(getValue().getLoc(), getValue().getType());
  getValueMutable().assign(undef);
  return DeletionKind::Keep;
}

bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }

void LLVM::DbgDeclareOp::visitReplacedValues(
    ArrayRef<std::pair<Operation *, Value>> definitions, OpBuilder &builder) {
  for (auto [op, value] : definitions) {
    builder.setInsertionPointAfter(op);
    builder.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
                                     getLocationExpr());
  }
}

//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//

static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
  return llvm::all_of(gepOp.getIndices(), [](auto index) {
    auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index);
    return indexAttr && indexAttr.getValue() == 0;
  });
}

bool LLVM::GEPOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  // GEP can be removed as long as it is a no-op and its users can be removed.
  if (!hasAllZeroIndices(*this))
    return false;
  return forwardToUsers(*this, newBlockingUses);
}

DeletionKind LLVM::GEPOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder) {
  return DeletionKind::Delete;
}

/// Returns the amount of bytes the provided GEP elements will offset the
/// pointer by. Returns nullopt if no constant offset could be computed.
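/// For example, a GEP with element type !llvm.struct<(i32, i64)> and constant
/// indices [1, 1] offsets the pointer by one full struct plus the offset of
/// the second field, i.e. 16 + 8 = 24 bytes on a layout where i64 is 8-byte
/// aligned.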
static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
                                               LLVM::GEPOp gep) {
  // Collects all indices.
  SmallVector<uint64_t> indices;
  for (auto index : gep.getIndices()) {
    auto constIndex = dyn_cast<IntegerAttr>(index);
    if (!constIndex)
      return {};
    int64_t gepIndex = constIndex.getInt();
    // Negative indices are not supported.
    if (gepIndex < 0)
      return {};
    indices.push_back(gepIndex);
  }

  Type currentType = gep.getElemType();
  uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);

  for (uint64_t index : llvm::drop_begin(indices)) {
    bool shouldCancel =
        TypeSwitch<Type, bool>(currentType)
            .Case([&](LLVM::LLVMArrayType arrayType) {
              offset +=
                  index * dataLayout.getTypeSize(arrayType.getElementType());
              currentType = arrayType.getElementType();
              return false;
            })
            .Case([&](LLVM::LLVMStructType structType) {
              ArrayRef<Type> body = structType.getBody();
              assert(index < body.size() && "expected valid struct indexing");
              for (uint32_t i : llvm::seq(index)) {
                if (!structType.isPacked())
                  offset = llvm::alignTo(
                      offset, dataLayout.getTypeABIAlignment(body[i]));
                offset += dataLayout.getTypeSize(body[i]);
              }

              // Align for the current type as well.
              if (!structType.isPacked())
                offset = llvm::alignTo(
                    offset, dataLayout.getTypeABIAlignment(body[index]));
              currentType = body[index];
              return false;
            })
            .Default([&](Type type) {
              LLVM_DEBUG(llvm::dbgs()
                         << "[sroa] Unsupported type for offset computations"
                         << type << "\n");
              return true;
            });

    if (shouldCancel)
      return std::nullopt;
  }

  return offset;
}

namespace {
/// A struct that stores both the index into the aggregate type of the slot as
/// well as the corresponding byte offset in memory.
struct SubslotAccessInfo {
  /// The parent slot's index that the access falls into.
  uint32_t index;
  /// The offset into the subslot of the access.
  uint64_t subslotOffset;
};
} // namespace

/// Computes subslot access information for an access into `slot` with the
/// given offset.
/// Returns nullopt when the offset is out-of-bounds or when the access is into
/// the padding of `slot`.
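/// For example, for a slot of type !llvm.struct<(i32, i32)>, a byte offset of
/// 6 resolves to index 1 with a subslot offset of 2.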
static std::optional<SubslotAccessInfo>
getSubslotAccessInfo(const DestructurableMemorySlot &slot,
                     const DataLayout &dataLayout, LLVM::GEPOp gep) {
  std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
  if (!offset)
    return {};

  // Helper to check that a constant index is in the bounds of the GEP index
  // representation. LLVM dialect's GEP arguments have a limited bitwidth, thus
  // this additional check is necessary.
  auto isOutOfBoundsGEPIndex = [](uint64_t index) {
    return index >= (1 << LLVM::kGEPConstantBitWidth);
  };

  Type type = slot.elemType;
  if (*offset >= dataLayout.getTypeSize(type))
    return {};
  return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
      .Case([&](LLVM::LLVMArrayType arrayType)
                -> std::optional<SubslotAccessInfo> {
        // Find which element of the array contains the offset.
        uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
        uint64_t index = *offset / elemSize;
        if (isOutOfBoundsGEPIndex(index))
          return {};
        return SubslotAccessInfo{static_cast<uint32_t>(index),
                                 *offset - (index * elemSize)};
      })
      .Case([&](LLVM::LLVMStructType structType)
                -> std::optional<SubslotAccessInfo> {
        uint64_t distanceToStart = 0;
        // Walk over the elements of the struct to find in which of
        // them the offset is.
        for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
          uint64_t elemSize = dataLayout.getTypeSize(elem);
          if (!structType.isPacked()) {
            distanceToStart = llvm::alignTo(
                distanceToStart, dataLayout.getTypeABIAlignment(elem));
            // If the offset is in padding, cancel the rewrite.
            if (offset < distanceToStart)
              return {};
          }

          if (offset < distanceToStart + elemSize) {
            if (isOutOfBoundsGEPIndex(index))
              return {};
            // The offset is within this element, stop iterating the
            // struct and return the index.
            return SubslotAccessInfo{static_cast<uint32_t>(index),
                                     *offset - distanceToStart};
          }

          // The offset is not within this element, continue walking
          // over the struct.
          distanceToStart += elemSize;
        }

        return {};
      });
}

/// Constructs a byte array type of the given size.
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
                                            unsigned size) {
  auto byteType = IntegerType::get(context, 8);
  return LLVM::LLVMArrayType::get(context, byteType, size);
}

LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  if (getBase() != slot.ptr)
    return success();
  std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this);
  if (!gepOffset)
    return failure();
  uint64_t slotSize = dataLayout.getTypeSize(slot.elemType);
  // Check that the access is strictly inside the slot.
  if (*gepOffset >= slotSize)
    return failure();
  // Every access that remains in bounds of the remaining slot is considered
  // legal.
  mustBeSafelyUsed.emplace_back<MemorySlot>(
      {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)});
  return success();
}

bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
    return false;

  if (getBase() != slot.ptr)
    return false;
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  if (!accessInfo)
    return false;
  auto indexAttr =
      IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
  assert(slot.subelementTypes.contains(indexAttr));
  usedIndices.insert(indexAttr);

  // The remainder of the subslot should be accessed in-bounds. Thus, we create
  // a dummy slot with the size of the remainder.
  Type subslotType = slot.subelementTypes.lookup(indexAttr);
  uint64_t slotSize = dataLayout.getTypeSize(subslotType);
  LLVM::LLVMArrayType remainingSlotType =
      getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
  mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});

  return true;
}

DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                 DenseMap<Attribute, MemorySlot> &subslots,
                                 OpBuilder &builder,
                                 const DataLayout &dataLayout) {
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  assert(accessInfo && "expected access info to be checked before");
  auto indexAttr =
      IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
  const MemorySlot &newSlot = subslots.at(indexAttr);

  auto byteType = IntegerType::get(builder.getContext(), 8);
  auto newPtr = builder.createOrFold<LLVM::GEPOp>(
      getLoc(), getResult().getType(), byteType, newSlot.ptr,
      ArrayRef<GEPArg>(accessInfo->subslotOffset), getNoWrapFlags());
  getResult().replaceAllUsesWith(newPtr);
  return DeletionKind::Delete;
}

//===----------------------------------------------------------------------===//
// Utilities for memory intrinsics
//===----------------------------------------------------------------------===//

namespace {

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
template <class MemIntr>
std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
  APInt memIntrLen;
  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
    return {};
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
/// Because MemcpyInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
  APInt memIntrLen = op.getLen();
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemsetInlineOp op) {
  APInt memIntrLen = op.getLen();
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

/// Returns an integer attribute representing the length of a memset intrinsic.
template <class MemsetIntr>
IntegerAttr createMemsetLenAttr(MemsetIntr op) {
  IntegerAttr memsetLenAttr;
  bool successfulMatch =
      matchPattern(op.getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
  (void)successfulMatch;
  assert(successfulMatch);
  return memsetLenAttr;
}

/// Returns an integer attribute representing the length of a memset intrinsic.
/// Because MemsetInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
IntegerAttr createMemsetLenAttr(LLVM::MemsetInlineOp op) {
  return op.getLenAttr();
}

/// Creates a memset intrinsic that matches the `toReplace` intrinsic
/// using the provided parameters. There are template specializations for
/// MemsetOp and MemsetInlineOp.
template <class MemsetIntr>
void createMemsetIntr(OpBuilder &builder, MemsetIntr toReplace,
                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
                      DenseMap<Attribute, MemorySlot> &subslots,
                      Attribute index);

template <>
void createMemsetIntr(OpBuilder &builder, LLVM::MemsetOp toReplace,
                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
                      DenseMap<Attribute, MemorySlot> &subslots,
                      Attribute index) {
  Value newMemsetSizeValue =
      builder
          .create<LLVM::ConstantOp>(
              toReplace.getLen().getLoc(),
              IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
          .getResult();

  builder.create<LLVM::MemsetOp>(toReplace.getLoc(), subslots.at(index).ptr,
                                 toReplace.getVal(), newMemsetSizeValue,
                                 toReplace.getIsVolatile());
}

template <>
void createMemsetIntr(OpBuilder &builder, LLVM::MemsetInlineOp toReplace,
                      IntegerAttr memsetLenAttr, uint64_t newMemsetSize,
                      DenseMap<Attribute, MemorySlot> &subslots,
                      Attribute index) {
  auto newMemsetSizeValue =
      IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize);

  builder.create<LLVM::MemsetInlineOp>(
      toReplace.getLoc(), subslots.at(index).ptr, toReplace.getVal(),
      newMemsetSizeValue, toReplace.getIsVolatile());
}

} // namespace

/// Returns whether one can be sure the memory intrinsic does not write outside
/// of the bounds of the given slot, on a best-effort basis.
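/// This is only the case when the intrinsic writes to the slot pointer itself
/// and its length is a compile-time constant that does not exceed the size of
/// the slot's element type.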
template <class MemIntr>
static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot,
                                           const DataLayout &dataLayout) {
  if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) ||
      op.getDst() != slot.ptr)
    return false;

  std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op);
  return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(slot.elemType);
}

/// Checks whether all indices are i32. This is used to check that GEPs can
/// index into them.
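/// Destructuring materializes subslot indices and GEP arguments as 32-bit
/// integer constants, so slots indexed with any other attribute type are
/// rejected.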
static bool areAllIndicesI32(const DestructurableMemorySlot &slot) {
  Type i32 = IntegerType::get(slot.ptr.getContext(), 32);
  return llvm::all_of(llvm::make_first_range(slot.subelementTypes),
                      [&](Attribute index) {
                        auto intIndex = dyn_cast<IntegerAttr>(index);
                        return intIndex && intIndex.getType() == i32;
                      });
}

//===----------------------------------------------------------------------===//
// Interfaces for memset and memset.inline
//===----------------------------------------------------------------------===//

template <class MemsetIntr>
static bool memsetCanRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  if (&slot.elemType.getDialect() != op.getOperation()->getDialect())
    return false;

  if (op.getIsVolatile())
    return false;

  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
    return false;

  if (!areAllIndicesI32(slot))
    return false;

  return definitelyWritesOnlyWithinSlot(op, slot, dataLayout);
}

template <class MemsetIntr>
static Value memsetGetStored(MemsetIntr op, const MemorySlot &slot,
                             OpBuilder &builder) {
  /// Returns an integer value that is `width` bits wide representing the value
  /// assigned to the slot by memset.
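  /// For example, a constant memset value of 0xab materialized for an i32
  /// slot yields the constant 0xabababab.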
  auto buildMemsetValue = [&](unsigned width) -> Value {
    assert(width % 8 == 0);
    auto intType = IntegerType::get(op.getContext(), width);

    // If we know the pattern at compile time, we can compute and assign a
    // constant directly.
    IntegerAttr constantPattern;
    if (matchPattern(op.getVal(), m_Constant(&constantPattern))) {
      assert(constantPattern.getValue().getBitWidth() == 8);
      APInt memsetVal(/*numBits=*/width, /*val=*/0);
      for (unsigned loBit = 0; loBit < width; loBit += 8)
        memsetVal.insertBits(constantPattern.getValue(), loBit);
      return builder.create<LLVM::ConstantOp>(
          op.getLoc(), IntegerAttr::get(intType, memsetVal));
    }

    // If the output is a single byte, we can return the pattern directly.
    if (width == 8)
      return op.getVal();

    // Otherwise build the memset integer at runtime by repeatedly shifting the
    // value and or-ing it with the previous value.
    uint64_t coveredBits = 8;
    Value currentValue =
        builder.create<LLVM::ZExtOp>(op.getLoc(), intType, op.getVal());
    while (coveredBits < width) {
      Value shiftBy =
          builder.create<LLVM::ConstantOp>(op.getLoc(), intType, coveredBits);
      Value shifted =
          builder.create<LLVM::ShlOp>(op.getLoc(), currentValue, shiftBy);
      currentValue =
          builder.create<LLVM::OrOp>(op.getLoc(), currentValue, shifted);
      coveredBits *= 2;
    }

    return currentValue;
  };
  return TypeSwitch<Type, Value>(slot.elemType)
      .Case([&](IntegerType type) -> Value {
        return buildMemsetValue(type.getWidth());
      })
      .Case([&](FloatType type) -> Value {
        Value intVal = buildMemsetValue(type.getWidth());
        return builder.create<LLVM::BitcastOp>(op.getLoc(), type, intVal);
      })
      .Default([](Type) -> Value {
        llvm_unreachable(
            "getStored should not be called on memset to unsupported type");
      });
}

template <class MemsetIntr>
static bool
memsetCanUsesBeRemoved(MemsetIntr op, const MemorySlot &slot,
                       const SmallPtrSetImpl<OpOperand *> &blockingUses,
                       SmallVectorImpl<OpOperand *> &newBlockingUses,
                       const DataLayout &dataLayout) {
  bool canConvertType =
      TypeSwitch<Type, bool>(slot.elemType)
          .Case<IntegerType, FloatType>([](auto type) {
            return type.getWidth() % 8 == 0 && type.getWidth() > 0;
          })
          .Default([](Type) { return false; });
  if (!canConvertType)
    return false;

  if (op.getIsVolatile())
    return false;

  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}

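/// Rewires a memset-like intrinsic onto the destructured subslots: the
/// subelements of the slot are visited in order, each one's offset within the
/// original memset range is computed, and a smaller memset is emitted for
/// every used subslot that is still covered by the original length.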
template <class MemsetIntr>
static DeletionKind
memsetRewire(MemsetIntr op, const DestructurableMemorySlot &slot,
             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
             const DataLayout &dataLayout) {

  std::optional<DenseMap<Attribute, Type>> types =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();

  IntegerAttr memsetLenAttr = createMemsetLenAttr(op);

  bool packed = false;
  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
    packed = structType.isPacked();

  Type i32 = IntegerType::get(op.getContext(), 32);
  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
  uint64_t covered = 0;
  for (size_t i = 0; i < types->size(); i++) {
    // Create indices on the fly to get elements in the right order.
    Attribute index = IntegerAttr::get(i32, i);
    Type elemType = types->at(index);
    uint64_t typeSize = dataLayout.getTypeSize(elemType);

    if (!packed)
      covered =
          llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));

    if (covered >= memsetLen)
      break;

    // If this subslot is used, apply a new memset to it.
    // Otherwise, only compute its offset within the original memset.
    if (subslots.contains(index)) {
      uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);
      createMemsetIntr(builder, op, memsetLenAttr, newMemsetSize, subslots,
                       index);
    }

    covered += typeSize;
  }

  return DeletionKind::Delete;
}

bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}

Value LLVM::MemsetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memsetGetStored(*this, slot, builder);
}

bool LLVM::MemsetOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}

bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  return memsetRewire(*this, slot, subslots, builder, dataLayout);
}

bool LLVM::MemsetInlineOp::loadsFrom(const MemorySlot &slot) { return false; }

bool LLVM::MemsetInlineOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}

Value LLVM::MemsetInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memsetGetStored(*this, slot, builder);
}

bool LLVM::MemsetInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memsetCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemsetInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}

LogicalResult LLVM::MemsetInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout));
}

bool LLVM::MemsetInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memsetCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind
LLVM::MemsetInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             OpBuilder &builder, const DataLayout &dataLayout) {
  return memsetRewire(*this, slot, subslots, builder, dataLayout);
}

//===----------------------------------------------------------------------===//
// Interfaces for memcpy/memmove
//===----------------------------------------------------------------------===//

template <class MemcpyLike>
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
  return op.getSrc() == slot.ptr;
}

template <class MemcpyLike>
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
  return op.getDst() == slot.ptr;
}

template <class MemcpyLike>
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
                             OpBuilder &builder) {
  return builder.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
}

template <class MemcpyLike>
static bool
memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot,
                       const SmallPtrSetImpl<OpOperand *> &blockingUses,
                       SmallVectorImpl<OpOperand *> &newBlockingUses,
                       const DataLayout &dataLayout) {
  // If source and destination are the same, memcpy behavior is undefined and
  // memmove is a no-op. Because there is no memory change happening here,
  // simplifying such operations is left to canonicalization.
  if (op.getDst() == op.getSrc())
    return false;

  if (op.getIsVolatile())
    return false;

  return getStaticMemIntrLen(op) == dataLayout.getTypeSize(slot.elemType);
}

template <class MemcpyLike>
static DeletionKind
memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
                         const SmallPtrSetImpl<OpOperand *> &blockingUses,
                         OpBuilder &builder, Value reachingDefinition) {
  if (op.loadsFrom(slot))
    builder.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition, op.getDst());
  return DeletionKind::Delete;
}

template <class MemcpyLike>
static LogicalResult
memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
  DataLayout dataLayout = DataLayout::closest(op);
  // While rewiring memcpy-like intrinsics only supports full copies, partial
  // copies are still safe accesses so it is enough to only check for writes
  // within bounds.
  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
}

template <class MemcpyLike>
static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  if (op.getIsVolatile())
    return false;

  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
    return false;

  if (!areAllIndicesI32(slot))
    return false;

  // Only full copies are supported.
  if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(slot.elemType))
    return false;

  if (op.getSrc() == slot.ptr)
    usedIndices.insert_range(llvm::make_first_range(slot.subelementTypes));

  return true;
}

namespace {

template <class MemcpyLike>
void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                               MemcpyLike toReplace, Value dst, Value src,
                               Type toCpy, bool isVolatile) {
  Value memcpySize = builder.create<LLVM::ConstantOp>(
      toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
                                           layout.getTypeSize(toCpy)));
  builder.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
                             isVolatile);
}

template <>
void createMemcpyLikeToReplace(OpBuilder &builder, const DataLayout &layout,
                               LLVM::MemcpyInlineOp toReplace, Value dst,
                               Value src, Type toCpy, bool isVolatile) {
  Type lenType = IntegerType::get(toReplace->getContext(),
                                  toReplace.getLen().getBitWidth());
  builder.create<LLVM::MemcpyInlineOp>(
      toReplace.getLoc(), dst, src,
      IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
}

} // namespace

/// Rewires a memcpy-like operation. Only copies to or from the full slot are
/// supported.
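/// Each used subslot gets its own copy: a GEP into the non-slot side of the
/// original operation computes the matching source or destination pointer,
/// and a new memcpy-like operation of the subslot's size is emitted.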
template <class MemcpyLike>
static DeletionKind
memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot,
             DenseMap<Attribute, MemorySlot> &subslots, OpBuilder &builder,
             const DataLayout &dataLayout) {
  if (subslots.empty())
    return DeletionKind::Delete;

  assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc()));
  bool isDst = slot.ptr == op.getDst();

#ifndef NDEBUG
  size_t slotsTreated = 0;
#endif

  // It was previously checked that index types are consistent, so this type
  // can be fetched now.
  Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType();
  for (size_t i = 0, e = slot.subelementTypes.size(); i != e; i++) {
    Attribute index = IntegerAttr::get(indexType, i);
    if (!subslots.contains(index))
      continue;
    const MemorySlot &subslot = subslots.at(index);

#ifndef NDEBUG
    slotsTreated++;
#endif

    // First get a pointer to the equivalent of this subslot from the source
    // pointer.
    SmallVector<LLVM::GEPArg> gepIndices{
        0, static_cast<int32_t>(
               cast<IntegerAttr>(index).getValue().getZExtValue())};
    Value subslotPtrInOther = builder.create<LLVM::GEPOp>(
        op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType,
        isDst ? op.getSrc() : op.getDst(), gepIndices);

    // Then create a new memcpy out of this source pointer.
    createMemcpyLikeToReplace(builder, dataLayout, op,
                              isDst ? subslot.ptr : subslotPtrInOther,
                              isDst ? subslotPtrInOther : subslot.ptr,
                              subslot.elemType, op.getIsVolatile());
  }

  assert(subslots.size() == slotsTreated);

  return DeletionKind::Delete;
}

bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}

bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}

Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}

bool LLVM::MemcpyOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}

LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}

bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    OpBuilder &builder,
                                    const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}

bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}

bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}

Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      OpBuilder &builder, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}

bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}

LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}

bool LLVM::MemcpyInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind
LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             OpBuilder &builder, const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}

bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}

bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}

Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, builder);
}

bool LLVM::MemmoveOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}

DeletionKind LLVM::MemmoveOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    OpBuilder &builder, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, builder,
                                  reachingDefinition);
}

LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}

bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                                const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}

DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     OpBuilder &builder,
                                     const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, builder, dataLayout);
}

//===----------------------------------------------------------------------===//
// Interfaces for destructurable types
//===----------------------------------------------------------------------===//

std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMStructType::getSubelementIndexMap() const {
  Type i32 = IntegerType::get(getContext(), 32);
  DenseMap<Attribute, Type> destructured;
  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
    destructured.insert({IntegerAttr::get(i32, index), elemType});
  return destructured;
}

Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) const {
  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
  if (!indexAttr || !indexAttr.getType().isInteger(32))
    return {};
  int32_t indexInt = indexAttr.getInt();
  ArrayRef<Type> body = getBody();
  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
    return {};
  return body[indexInt];
}

std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMArrayType::getSubelementIndexMap() const {
  constexpr size_t maxArraySizeForDestructuring = 16;
  if (getNumElements() > maxArraySizeForDestructuring)
    return {};
  int32_t numElements = getNumElements();

  Type i32 = IntegerType::get(getContext(), 32);
  DenseMap<Attribute, Type> destructured;
  for (int32_t index = 0; index < numElements; ++index)
    destructured.insert({IntegerAttr::get(i32, index), getElementType()});
  return destructured;
}

Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
  if (!indexAttr || !indexAttr.getType().isInteger(32))
    return {};
  int32_t indexInt = indexAttr.getInt();
  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
    return {};
  return getElementType();
}