1 | //===- LLVMMemorySlot.cpp - MemorySlot interfaces ---------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements MemorySlot-related interfaces for LLVM dialect |
10 | // operations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
16 | #include "mlir/IR/Matchers.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | #include "mlir/Interfaces/DataLayoutInterfaces.h" |
19 | #include "mlir/Interfaces/MemorySlotInterfaces.h" |
20 | #include "llvm/ADT/STLExtras.h" |
21 | #include "llvm/ADT/TypeSwitch.h" |
22 | |
23 | #define DEBUG_TYPE "sroa" |
24 | |
25 | using namespace mlir; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // Interfaces for AllocaOp |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() { |
32 | if (!getOperation()->getBlock()->isEntryBlock()) |
33 | return {}; |
34 | |
35 | return {MemorySlot{getResult(), getElemType()}}; |
36 | } |
37 | |
38 | Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, |
39 | RewriterBase &rewriter) { |
40 | return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType); |
41 | } |
42 | |
43 | void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot, |
44 | BlockArgument argument, |
45 | RewriterBase &rewriter) { |
46 | for (Operation *user : getOperation()->getUsers()) |
47 | if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user)) |
48 | rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument, |
49 | declareOp.getVarInfo(), |
50 | declareOp.getLocationExpr()); |
51 | } |
52 | |
53 | void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot, |
54 | Value defaultValue, |
55 | RewriterBase &rewriter) { |
56 | if (defaultValue && defaultValue.use_empty()) |
57 | rewriter.eraseOp(defaultValue.getDefiningOp()); |
58 | rewriter.eraseOp(*this); |
59 | } |
60 | |
61 | SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() { |
62 | if (!mlir::matchPattern(getArraySize(), m_One())) |
63 | return {}; |
64 | |
65 | auto destructurable = dyn_cast<DestructurableTypeInterface>(getElemType()); |
66 | if (!destructurable) |
67 | return {}; |
68 | |
69 | std::optional<DenseMap<Attribute, Type>> destructuredType = |
70 | destructurable.getSubelementIndexMap(); |
71 | if (!destructuredType) |
72 | return {}; |
73 | |
74 | return {DestructurableMemorySlot{{getResult(), getElemType()}, |
75 | *destructuredType}}; |
76 | } |
77 | |
78 | DenseMap<Attribute, MemorySlot> |
79 | LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot, |
80 | const SmallPtrSetImpl<Attribute> &usedIndices, |
81 | RewriterBase &rewriter) { |
82 | assert(slot.ptr == getResult()); |
83 | rewriter.setInsertionPointAfter(*this); |
84 | |
85 | auto destructurableType = cast<DestructurableTypeInterface>(getElemType()); |
86 | DenseMap<Attribute, MemorySlot> slotMap; |
87 | for (Attribute index : usedIndices) { |
88 | Type elemType = destructurableType.getTypeAtIndex(index); |
89 | assert(elemType && "used index must exist" ); |
90 | auto subAlloca = rewriter.create<LLVM::AllocaOp>( |
91 | getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType, |
92 | getArraySize()); |
93 | slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType}); |
94 | } |
95 | |
96 | return slotMap; |
97 | } |
98 | |
// Once all accesses have been rewired to the subslots, the original alloca
// is dead and can simply be erased.
void LLVM::AllocaOp::handleDestructuringComplete(
    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
  assert(slot.ptr == getResult());
  rewriter.eraseOp(*this);
}
104 | |
105 | //===----------------------------------------------------------------------===// |
106 | // Interfaces for LoadOp/StoreOp |
107 | //===----------------------------------------------------------------------===// |
108 | |
109 | bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) { |
110 | return getAddr() == slot.ptr; |
111 | } |
112 | |
// A load never writes to memory.
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114 | |
// Since `storesTo` always returns false for loads, the promotion driver must
// never query a load for its stored value.
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                              Value reachingDef, const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}
119 | |
// A store never reads from memory.
bool LLVM::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
121 | |
122 | bool LLVM::StoreOp::storesTo(const MemorySlot &slot) { |
123 | return getAddr() == slot.ptr; |
124 | } |
125 | |
126 | /// Checks if `type` can be used in any kind of conversion sequences. |
127 | static bool isSupportedTypeForConversion(Type type) { |
128 | // Aggregate types are not bitcastable. |
129 | if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type)) |
130 | return false; |
131 | |
132 | // LLVM vector types are only used for either pointers or target specific |
133 | // types. These types cannot be casted in the general case, thus the memory |
134 | // optimizations do not support them. |
135 | if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type)) |
136 | return false; |
137 | |
138 | // Scalable types are not supported. |
139 | if (auto vectorType = dyn_cast<VectorType>(type)) |
140 | return !vectorType.isScalable(); |
141 | return true; |
142 | } |
143 | |
144 | /// Checks that `rhs` can be converted to `lhs` by a sequence of casts and |
145 | /// truncations. Checks for narrowing or widening conversion compatibility |
146 | /// depending on `narrowingConversion`. |
147 | static bool areConversionCompatible(const DataLayout &layout, Type targetType, |
148 | Type srcType, bool narrowingConversion) { |
149 | if (targetType == srcType) |
150 | return true; |
151 | |
152 | if (!isSupportedTypeForConversion(type: targetType) || |
153 | !isSupportedTypeForConversion(type: srcType)) |
154 | return false; |
155 | |
156 | uint64_t targetSize = layout.getTypeSize(t: targetType); |
157 | uint64_t srcSize = layout.getTypeSize(t: srcType); |
158 | |
159 | // Pointer casts will only be sane when the bitsize of both pointer types is |
160 | // the same. |
161 | if (isa<LLVM::LLVMPointerType>(targetType) && |
162 | isa<LLVM::LLVMPointerType>(srcType)) |
163 | return targetSize == srcSize; |
164 | |
165 | if (narrowingConversion) |
166 | return targetSize <= srcSize; |
167 | return targetSize >= srcSize; |
168 | } |
169 | |
/// Checks if `dataLayout` describes a big endian layout.
static bool isBigEndian(const DataLayout &dataLayout) {
  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
  return endiannessStr && endiannessStr == "big";
}
175 | |
176 | /// Converts a value to an integer type of the same size. |
177 | /// Assumes that the type can be converted. |
178 | static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val, |
179 | const DataLayout &dataLayout) { |
180 | Type type = val.getType(); |
181 | assert(isSupportedTypeForConversion(type) && |
182 | "expected value to have a convertible type" ); |
183 | |
184 | if (isa<IntegerType>(Val: type)) |
185 | return val; |
186 | |
187 | uint64_t typeBitSize = dataLayout.getTypeSizeInBits(t: type); |
188 | IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize); |
189 | |
190 | if (isa<LLVM::LLVMPointerType>(type)) |
191 | return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val); |
192 | return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val); |
193 | } |
194 | |
195 | /// Converts a value with an integer type to `targetType`. |
196 | static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc, |
197 | Value val, Type targetType) { |
198 | assert(isa<IntegerType>(val.getType()) && |
199 | "expected value to have an integer type" ); |
200 | assert(isSupportedTypeForConversion(targetType) && |
201 | "expected the target type to be supported for conversions" ); |
202 | if (val.getType() == targetType) |
203 | return val; |
204 | if (isa<LLVM::LLVMPointerType>(targetType)) |
205 | return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val); |
206 | return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val); |
207 | } |
208 | |
209 | /// Constructs operations that convert `srcValue` into a new value of type |
210 | /// `targetType`. Assumes the types have the same bitsize. |
211 | static Value castSameSizedTypes(RewriterBase &rewriter, Location loc, |
212 | Value srcValue, Type targetType, |
213 | const DataLayout &dataLayout) { |
214 | Type srcType = srcValue.getType(); |
215 | assert(areConversionCompatible(dataLayout, targetType, srcType, |
216 | /*narrowingConversion=*/true) && |
217 | "expected that the compatibility was checked before" ); |
218 | |
219 | // Nothing has to be done if the types are already the same. |
220 | if (srcType == targetType) |
221 | return srcValue; |
222 | |
223 | // In the special case of casting one pointer to another, we want to generate |
224 | // an address space cast. Bitcasts of pointers are not allowed and using |
225 | // pointer to integer conversions are not equivalent due to the loss of |
226 | // provenance. |
227 | if (isa<LLVM::LLVMPointerType>(targetType) && |
228 | isa<LLVM::LLVMPointerType>(srcType)) |
229 | return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType, |
230 | srcValue); |
231 | |
232 | // For all other castable types, casting through integers is necessary. |
233 | Value replacement = castToSameSizedInt(rewriter, loc, val: srcValue, dataLayout); |
234 | return castIntValueToSameSizedType(rewriter, loc, val: replacement, targetType); |
235 | } |
236 | |
237 | /// Constructs operations that convert `srcValue` into a new value of type |
238 | /// `targetType`. Performs bit-level extraction if the source type is larger |
239 | /// than the target type. Assumes that this conversion is possible. |
240 | static Value createExtractAndCast(RewriterBase &rewriter, Location loc, |
241 | Value srcValue, Type targetType, |
242 | const DataLayout &dataLayout) { |
243 | // Get the types of the source and target values. |
244 | Type srcType = srcValue.getType(); |
245 | assert(areConversionCompatible(dataLayout, targetType, srcType, |
246 | /*narrowingConversion=*/true) && |
247 | "expected that the compatibility was checked before" ); |
248 | |
249 | uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(t: srcType); |
250 | uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(t: targetType); |
251 | if (srcTypeSize == targetTypeSize) |
252 | return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout); |
253 | |
254 | // First, cast the value to a same-sized integer type. |
255 | Value replacement = castToSameSizedInt(rewriter, loc, val: srcValue, dataLayout); |
256 | |
257 | // Truncate the integer if the size of the target is less than the value. |
258 | if (isBigEndian(dataLayout)) { |
259 | uint64_t shiftAmount = srcTypeSize - targetTypeSize; |
260 | auto shiftConstant = rewriter.create<LLVM::ConstantOp>( |
261 | loc, rewriter.getIntegerAttr(srcType, shiftAmount)); |
262 | replacement = |
263 | rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant); |
264 | } |
265 | |
266 | replacement = rewriter.create<LLVM::TruncOp>( |
267 | loc, rewriter.getIntegerType(targetTypeSize), replacement); |
268 | |
269 | // Now cast the integer to the actual target type if required. |
270 | return castIntValueToSameSizedType(rewriter, loc, val: replacement, targetType); |
271 | } |
272 | |
273 | /// Constructs operations that insert the bits of `srcValue` into the |
274 | /// "beginning" of `reachingDef` (beginning is endianness dependent). |
275 | /// Assumes that this conversion is possible. |
276 | static Value createInsertAndCast(RewriterBase &rewriter, Location loc, |
277 | Value srcValue, Value reachingDef, |
278 | const DataLayout &dataLayout) { |
279 | |
280 | assert(areConversionCompatible(dataLayout, reachingDef.getType(), |
281 | srcValue.getType(), |
282 | /*narrowingConversion=*/false) && |
283 | "expected that the compatibility was checked before" ); |
284 | uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(t: srcValue.getType()); |
285 | uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(t: reachingDef.getType()); |
286 | if (slotTypeSize == valueTypeSize) |
287 | return castSameSizedTypes(rewriter, loc, srcValue, targetType: reachingDef.getType(), |
288 | dataLayout); |
289 | |
290 | // In the case where the store only overwrites parts of the memory, |
291 | // bit fiddling is required to construct the new value. |
292 | |
293 | // First convert both values to integers of the same size. |
294 | Value defAsInt = castToSameSizedInt(rewriter, loc, val: reachingDef, dataLayout); |
295 | Value valueAsInt = castToSameSizedInt(rewriter, loc, val: srcValue, dataLayout); |
296 | // Extend the value to the size of the reaching definition. |
297 | valueAsInt = |
298 | rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt); |
299 | uint64_t sizeDifference = slotTypeSize - valueTypeSize; |
300 | if (isBigEndian(dataLayout)) { |
301 | // On big endian systems, a store to the base pointer overwrites the most |
302 | // significant bits. To accomodate for this, the stored value needs to be |
303 | // shifted into the according position. |
304 | Value bigEndianShift = rewriter.create<LLVM::ConstantOp>( |
305 | loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference)); |
306 | valueAsInt = |
307 | rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift); |
308 | } |
309 | |
310 | // Construct the mask that is used to erase the bits that are overwritten by |
311 | // the store. |
312 | APInt maskValue; |
313 | if (isBigEndian(dataLayout)) { |
314 | // Build a mask that has the most significant bits set to zero. |
315 | // Note: This is the same as 2^sizeDifference - 1 |
316 | maskValue = APInt::getAllOnes(numBits: sizeDifference).zext(width: slotTypeSize); |
317 | } else { |
318 | // Build a mask that has the least significant bits set to zero. |
319 | // Note: This is the same as -(2^valueTypeSize) |
320 | maskValue = APInt::getAllOnes(numBits: valueTypeSize).zext(width: slotTypeSize); |
321 | maskValue.flipAllBits(); |
322 | } |
323 | |
324 | // Mask out the affected bits ... |
325 | Value mask = rewriter.create<LLVM::ConstantOp>( |
326 | loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue)); |
327 | Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask); |
328 | |
329 | // ... and combine the result with the new value. |
330 | Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt); |
331 | |
332 | return castIntValueToSameSizedType(rewriter, loc, val: combined, |
333 | targetType: reachingDef.getType()); |
334 | } |
335 | |
336 | Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter, |
337 | Value reachingDef, |
338 | const DataLayout &dataLayout) { |
339 | assert(reachingDef && reachingDef.getType() == slot.elemType && |
340 | "expected the reaching definition's type to match the slot's type" ); |
341 | return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef, |
342 | dataLayout); |
343 | } |
344 | |
345 | bool LLVM::LoadOp::canUsesBeRemoved( |
346 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
347 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
348 | const DataLayout &dataLayout) { |
349 | if (blockingUses.size() != 1) |
350 | return false; |
351 | Value blockingUse = (*blockingUses.begin())->get(); |
352 | // If the blocking use is the slot ptr itself, there will be enough |
353 | // context to reconstruct the result of the load at removal time, so it can |
354 | // be removed (provided it is not volatile). |
355 | return blockingUse == slot.ptr && getAddr() == slot.ptr && |
356 | areConversionCompatible(dataLayout, getResult().getType(), |
357 | slot.elemType, /*narrowingConversion=*/true) && |
358 | !getVolatile_(); |
359 | } |
360 | |
361 | DeletionKind LLVM::LoadOp::removeBlockingUses( |
362 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
363 | RewriterBase &rewriter, Value reachingDefinition, |
364 | const DataLayout &dataLayout) { |
365 | // `canUsesBeRemoved` checked this blocking use must be the loaded slot |
366 | // pointer. |
367 | Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition, |
368 | getResult().getType(), dataLayout); |
369 | rewriter.replaceAllUsesWith(getResult(), newResult); |
370 | return DeletionKind::Delete; |
371 | } |
372 | |
373 | bool LLVM::StoreOp::canUsesBeRemoved( |
374 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
375 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
376 | const DataLayout &dataLayout) { |
377 | if (blockingUses.size() != 1) |
378 | return false; |
379 | Value blockingUse = (*blockingUses.begin())->get(); |
380 | // If the blocking use is the slot ptr itself, dropping the store is |
381 | // fine, provided we are currently promoting its target value. Don't allow a |
382 | // store OF the slot pointer, only INTO the slot pointer. |
383 | return blockingUse == slot.ptr && getAddr() == slot.ptr && |
384 | getValue() != slot.ptr && |
385 | areConversionCompatible(dataLayout, slot.elemType, |
386 | getValue().getType(), |
387 | /*narrowingConversion=*/false) && |
388 | !getVolatile_(); |
389 | } |
390 | |
// The store's effect is already folded into the reaching definition via
// `getStored`, so the operation itself can simply be deleted.
DeletionKind LLVM::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
397 | |
398 | /// Checks if `slot` can be accessed through the provided access type. |
399 | static bool isValidAccessType(const MemorySlot &slot, Type accessType, |
400 | const DataLayout &dataLayout) { |
401 | return dataLayout.getTypeSize(t: accessType) <= |
402 | dataLayout.getTypeSize(t: slot.elemType); |
403 | } |
404 | |
405 | LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses( |
406 | const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
407 | const DataLayout &dataLayout) { |
408 | return success(getAddr() != slot.ptr || |
409 | isValidAccessType(slot, getType(), dataLayout)); |
410 | } |
411 | |
412 | LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses( |
413 | const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
414 | const DataLayout &dataLayout) { |
415 | return success(getAddr() != slot.ptr || |
416 | isValidAccessType(slot, getValue().getType(), dataLayout)); |
417 | } |
418 | |
419 | /// Returns the subslot's type at the requested index. |
420 | static Type getTypeAtIndex(const DestructurableMemorySlot &slot, |
421 | Attribute index) { |
422 | auto subelementIndexMap = |
423 | cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap(); |
424 | if (!subelementIndexMap) |
425 | return {}; |
426 | assert(!subelementIndexMap->empty()); |
427 | |
428 | // Note: Returns a null-type when no entry was found. |
429 | return subelementIndexMap->lookup(index); |
430 | } |
431 | |
432 | bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot, |
433 | SmallPtrSetImpl<Attribute> &usedIndices, |
434 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
435 | const DataLayout &dataLayout) { |
436 | if (getVolatile_()) |
437 | return false; |
438 | |
439 | // A load always accesses the first element of the destructured slot. |
440 | auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
441 | Type subslotType = getTypeAtIndex(slot, index); |
442 | if (!subslotType) |
443 | return false; |
444 | |
445 | // The access can only be replaced when the subslot is read within its bounds. |
446 | if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType)) |
447 | return false; |
448 | |
449 | usedIndices.insert(index); |
450 | return true; |
451 | } |
452 | |
453 | DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot, |
454 | DenseMap<Attribute, MemorySlot> &subslots, |
455 | RewriterBase &rewriter, |
456 | const DataLayout &dataLayout) { |
457 | auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
458 | auto it = subslots.find(index); |
459 | assert(it != subslots.end()); |
460 | |
461 | rewriter.modifyOpInPlace( |
462 | *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); |
463 | return DeletionKind::Keep; |
464 | } |
465 | |
466 | bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot, |
467 | SmallPtrSetImpl<Attribute> &usedIndices, |
468 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
469 | const DataLayout &dataLayout) { |
470 | if (getVolatile_()) |
471 | return false; |
472 | |
473 | // Storing the pointer to memory cannot be dealt with. |
474 | if (getValue() == slot.ptr) |
475 | return false; |
476 | |
477 | // A store always accesses the first element of the destructured slot. |
478 | auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
479 | Type subslotType = getTypeAtIndex(slot, index); |
480 | if (!subslotType) |
481 | return false; |
482 | |
483 | // The access can only be replaced when the subslot is read within its bounds. |
484 | if (dataLayout.getTypeSize(getValue().getType()) > |
485 | dataLayout.getTypeSize(subslotType)) |
486 | return false; |
487 | |
488 | usedIndices.insert(index); |
489 | return true; |
490 | } |
491 | |
492 | DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot, |
493 | DenseMap<Attribute, MemorySlot> &subslots, |
494 | RewriterBase &rewriter, |
495 | const DataLayout &dataLayout) { |
496 | auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
497 | auto it = subslots.find(index); |
498 | assert(it != subslots.end()); |
499 | |
500 | rewriter.modifyOpInPlace( |
501 | *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); |
502 | return DeletionKind::Keep; |
503 | } |
504 | |
505 | //===----------------------------------------------------------------------===// |
506 | // Interfaces for discardable OPs |
507 | //===----------------------------------------------------------------------===// |
508 | |
509 | /// Conditions the deletion of the operation to the removal of all its uses. |
510 | static bool forwardToUsers(Operation *op, |
511 | SmallVectorImpl<OpOperand *> &newBlockingUses) { |
512 | for (Value result : op->getResults()) |
513 | for (OpOperand &use : result.getUses()) |
514 | newBlockingUses.push_back(Elt: &use); |
515 | return true; |
516 | } |
517 | |
// A bitcast is removable whenever all transitive uses of its result are.
bool LLVM::BitcastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}
524 | |
// Once its uses are gone, the cast carries no effect and can be deleted.
DeletionKind LLVM::BitcastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
529 | |
// An addrspacecast is removable whenever all transitive uses of its result
// are.
bool LLVM::AddrSpaceCastOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return forwardToUsers(*this, newBlockingUses);
}
536 | |
// Once its uses are gone, the cast carries no effect and can be deleted.
DeletionKind LLVM::AddrSpaceCastOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
541 | |
// Lifetime markers never block promotion; their pointer use can always be
// dropped.
bool LLVM::LifetimeStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
548 | |
// The marker is meaningless once the slot is promoted; delete it.
DeletionKind LLVM::LifetimeStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
553 | |
// Lifetime markers never block promotion; their pointer use can always be
// dropped.
bool LLVM::LifetimeEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
560 | |
// The marker is meaningless once the slot is promoted; delete it.
DeletionKind LLVM::LifetimeEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
565 | |
// Invariant markers never block promotion; their pointer use can always be
// dropped.
bool LLVM::InvariantStartOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
572 | |
// The marker is meaningless once the slot is promoted; delete it.
DeletionKind LLVM::InvariantStartOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
577 | |
// Invariant markers never block promotion; their pointer use can always be
// dropped.
bool LLVM::InvariantEndOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
584 | |
// The marker is meaningless once the slot is promoted; delete it.
DeletionKind LLVM::InvariantEndOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
589 | |
// Debug declarations never block promotion; they are handled separately via
// `visitReplacedValues`.
bool LLVM::DbgDeclareOp::canUsesBeRemoved(
    const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return true;
}
596 | |
// The declaration of the promoted slot is deleted; debug info is preserved
// through the DbgValueOps emitted in `visitReplacedValues`.
DeletionKind LLVM::DbgDeclareOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
601 | |
602 | bool LLVM::DbgValueOp::canUsesBeRemoved( |
603 | const SmallPtrSetImpl<OpOperand *> &blockingUses, |
604 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
605 | const DataLayout &dataLayout) { |
606 | // There is only one operand that we can remove the use of. |
607 | if (blockingUses.size() != 1) |
608 | return false; |
609 | |
610 | return (*blockingUses.begin())->get() == getValue(); |
611 | } |
612 | |
// Keeps the debug value alive but detaches it from the promoted value.
DeletionKind LLVM::DbgValueOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  // Rewriter by default is after '*this', but we need it before '*this'.
  rewriter.setInsertionPoint(*this);

  // Rather than dropping the debug value, replace it with undef to preserve the
  // debug local variable info. This allows the debugger to inform the user that
  // the variable has been optimized out.
  auto undef =
      rewriter.create<UndefOp>(getValue().getLoc(), getValue().getType());
  rewriter.modifyOpInPlace(*this, [&] { getValueMutable().assign(undef); });
  return DeletionKind::Keep;
}
626 | |
// Debug declarations need to observe the values replacing the promoted slot
// so equivalent DbgValueOps can be emitted (see `visitReplacedValues`).
bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
628 | |
629 | void LLVM::DbgDeclareOp::visitReplacedValues( |
630 | ArrayRef<std::pair<Operation *, Value>> definitions, |
631 | RewriterBase &rewriter) { |
632 | for (auto [op, value] : definitions) { |
633 | rewriter.setInsertionPointAfter(op); |
634 | rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(), |
635 | getLocationExpr()); |
636 | } |
637 | } |
638 | |
639 | //===----------------------------------------------------------------------===// |
640 | // Interfaces for GEPOp |
641 | //===----------------------------------------------------------------------===// |
642 | |
643 | static bool hasAllZeroIndices(LLVM::GEPOp gepOp) { |
644 | return llvm::all_of(gepOp.getIndices(), [](auto index) { |
645 | auto indexAttr = llvm::dyn_cast_if_present<IntegerAttr>(index); |
646 | return indexAttr && indexAttr.getValue() == 0; |
647 | }); |
648 | } |
649 | |
650 | bool LLVM::GEPOp::canUsesBeRemoved( |
651 | const SmallPtrSetImpl<OpOperand *> &blockingUses, |
652 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
653 | const DataLayout &dataLayout) { |
654 | // GEP can be removed as long as it is a no-op and its users can be removed. |
655 | if (!hasAllZeroIndices(*this)) |
656 | return false; |
657 | return forwardToUsers(*this, newBlockingUses); |
658 | } |
659 | |
// Once its uses are gone, the no-op GEP can simply be deleted.
DeletionKind LLVM::GEPOp::removeBlockingUses(
    const SmallPtrSetImpl<OpOperand *> &blockingUses, RewriterBase &rewriter) {
  return DeletionKind::Delete;
}
664 | |
665 | /// Returns the amount of bytes the provided GEP elements will offset the |
666 | /// pointer by. Returns nullopt if no constant offset could be computed. |
667 | static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout, |
668 | LLVM::GEPOp gep) { |
669 | // Collects all indices. |
670 | SmallVector<uint64_t> indices; |
671 | for (auto index : gep.getIndices()) { |
672 | auto constIndex = dyn_cast<IntegerAttr>(index); |
673 | if (!constIndex) |
674 | return {}; |
675 | int64_t gepIndex = constIndex.getInt(); |
676 | // Negative indices are not supported. |
677 | if (gepIndex < 0) |
678 | return {}; |
679 | indices.push_back(gepIndex); |
680 | } |
681 | |
682 | Type currentType = gep.getElemType(); |
683 | uint64_t offset = indices[0] * dataLayout.getTypeSize(t: currentType); |
684 | |
685 | for (uint64_t index : llvm::drop_begin(RangeOrContainer&: indices)) { |
686 | bool shouldCancel = |
687 | TypeSwitch<Type, bool>(currentType) |
688 | .Case(caseFn: [&](LLVM::LLVMArrayType arrayType) { |
689 | offset += |
690 | index * dataLayout.getTypeSize(t: arrayType.getElementType()); |
691 | currentType = arrayType.getElementType(); |
692 | return false; |
693 | }) |
694 | .Case(caseFn: [&](LLVM::LLVMStructType structType) { |
695 | ArrayRef<Type> body = structType.getBody(); |
696 | assert(index < body.size() && "expected valid struct indexing" ); |
697 | for (uint32_t i : llvm::seq(Size: index)) { |
698 | if (!structType.isPacked()) |
699 | offset = llvm::alignTo( |
700 | Value: offset, Align: dataLayout.getTypeABIAlignment(t: body[i])); |
701 | offset += dataLayout.getTypeSize(t: body[i]); |
702 | } |
703 | |
704 | // Align for the current type as well. |
705 | if (!structType.isPacked()) |
706 | offset = llvm::alignTo( |
707 | Value: offset, Align: dataLayout.getTypeABIAlignment(t: body[index])); |
708 | currentType = body[index]; |
709 | return false; |
710 | }) |
711 | .Default(defaultFn: [&](Type type) { |
712 | LLVM_DEBUG(llvm::dbgs() |
713 | << "[sroa] Unsupported type for offset computations" |
714 | << type << "\n" ); |
715 | return true; |
716 | }); |
717 | |
718 | if (shouldCancel) |
719 | return std::nullopt; |
720 | } |
721 | |
722 | return offset; |
723 | } |
724 | |
namespace {
/// A struct that stores both the index into the aggregate type of the slot as
/// well as the corresponding byte offset in memory.
struct SubslotAccessInfo {
  /// The parent slot's index that the access falls into.
  uint32_t index;
  /// The offset into the subslot of the access.
  uint64_t subslotOffset;
};
} // namespace
735 | |
736 | /// Computes subslot access information for an access into `slot` with the given |
737 | /// offset. |
738 | /// Returns nullopt when the offset is out-of-bounds or when the access is into |
739 | /// the padding of `slot`. |
740 | static std::optional<SubslotAccessInfo> |
741 | getSubslotAccessInfo(const DestructurableMemorySlot &slot, |
742 | const DataLayout &dataLayout, LLVM::GEPOp gep) { |
743 | std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep); |
744 | if (!offset) |
745 | return {}; |
746 | |
747 | // Helper to check that a constant index is in the bounds of the GEP index |
748 | // representation. LLVM dialects's GEP arguments have a limited bitwidth, thus |
749 | // this additional check is necessary. |
750 | auto isOutOfBoundsGEPIndex = [](uint64_t index) { |
751 | return index >= (1 << LLVM::kGEPConstantBitWidth); |
752 | }; |
753 | |
754 | Type type = slot.elemType; |
755 | if (*offset >= dataLayout.getTypeSize(t: type)) |
756 | return {}; |
757 | return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type) |
758 | .Case(caseFn: [&](LLVM::LLVMArrayType arrayType) |
759 | -> std::optional<SubslotAccessInfo> { |
760 | // Find which element of the array contains the offset. |
761 | uint64_t elemSize = dataLayout.getTypeSize(t: arrayType.getElementType()); |
762 | uint64_t index = *offset / elemSize; |
763 | if (isOutOfBoundsGEPIndex(index)) |
764 | return {}; |
765 | return SubslotAccessInfo{.index: static_cast<uint32_t>(index), |
766 | .subslotOffset: *offset - (index * elemSize)}; |
767 | }) |
768 | .Case(caseFn: [&](LLVM::LLVMStructType structType) |
769 | -> std::optional<SubslotAccessInfo> { |
770 | uint64_t distanceToStart = 0; |
771 | // Walk over the elements of the struct to find in which of |
772 | // them the offset is. |
773 | for (auto [index, elem] : llvm::enumerate(First: structType.getBody())) { |
774 | uint64_t elemSize = dataLayout.getTypeSize(t: elem); |
775 | if (!structType.isPacked()) { |
776 | distanceToStart = llvm::alignTo( |
777 | Value: distanceToStart, Align: dataLayout.getTypeABIAlignment(t: elem)); |
778 | // If the offset is in padding, cancel the rewrite. |
779 | if (offset < distanceToStart) |
780 | return {}; |
781 | } |
782 | |
783 | if (offset < distanceToStart + elemSize) { |
784 | if (isOutOfBoundsGEPIndex(index)) |
785 | return {}; |
786 | // The offset is within this element, stop iterating the |
787 | // struct and return the index. |
788 | return SubslotAccessInfo{.index: static_cast<uint32_t>(index), |
789 | .subslotOffset: *offset - distanceToStart}; |
790 | } |
791 | |
792 | // The offset is not within this element, continue walking |
793 | // over the struct. |
794 | distanceToStart += elemSize; |
795 | } |
796 | |
797 | return {}; |
798 | }); |
799 | } |
800 | |
801 | /// Constructs a byte array type of the given size. |
802 | static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context, |
803 | unsigned size) { |
804 | auto byteType = IntegerType::get(context, 8); |
805 | return LLVM::LLVMArrayType::get(context, byteType, size); |
806 | } |
807 | |
808 | LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses( |
809 | const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
810 | const DataLayout &dataLayout) { |
811 | if (getBase() != slot.ptr) |
812 | return success(); |
813 | std::optional<uint64_t> gepOffset = gepToByteOffset(dataLayout, *this); |
814 | if (!gepOffset) |
815 | return failure(); |
816 | uint64_t slotSize = dataLayout.getTypeSize(slot.elemType); |
817 | // Check that the access is strictly inside the slot. |
818 | if (*gepOffset >= slotSize) |
819 | return failure(); |
820 | // Every access that remains in bounds of the remaining slot is considered |
821 | // legal. |
822 | mustBeSafelyUsed.emplace_back<MemorySlot>( |
823 | {getRes(), getByteArrayType(getContext(), slotSize - *gepOffset)}); |
824 | return success(); |
825 | } |
826 | |
bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
                            SmallPtrSetImpl<Attribute> &usedIndices,
                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                            const DataLayout &dataLayout) {
  // Only GEPs operating on plain LLVM pointers are handled here.
  if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
    return false;

  // GEPs based on other pointers do not touch the slot.
  if (getBase() != slot.ptr)
    return false;
  // The access must resolve to exactly one subslot, neither out-of-bounds nor
  // in padding.
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  if (!accessInfo)
    return false;
  auto indexAttr =
      IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
  assert(slot.elementPtrs.contains(indexAttr));
  usedIndices.insert(indexAttr);

  // The remainder of the subslot should be accesses in-bounds. Thus, we create
  // a dummy slot with the size of the remainder.
  Type subslotType = slot.elementPtrs.lookup(indexAttr);
  uint64_t slotSize = dataLayout.getTypeSize(subslotType);
  LLVM::LLVMArrayType remainingSlotType =
      getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
  mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});

  return true;
}
855 | |
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                 DenseMap<Attribute, MemorySlot> &subslots,
                                 RewriterBase &rewriter,
                                 const DataLayout &dataLayout) {
  // Recompute which subslot this GEP lands in; canRewire already verified the
  // computation succeeds.
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  assert(accessInfo && "expected access info to be checked before" );
  auto indexAttr =
      IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
  const MemorySlot &newSlot = subslots.at(indexAttr);

  // Replace this GEP with a byte-wise GEP into the subslot, using the access
  // offset relative to the subslot's start.
  auto byteType = IntegerType::get(rewriter.getContext(), 8);
  auto newPtr = rewriter.createOrFold<LLVM::GEPOp>(
      getLoc(), getResult().getType(), byteType, newSlot.ptr,
      ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
  rewriter.replaceAllUsesWith(getResult(), newPtr);
  return DeletionKind::Delete;
}
874 | |
875 | //===----------------------------------------------------------------------===// |
876 | // Utilities for memory intrinsics |
877 | //===----------------------------------------------------------------------===// |
878 | |
namespace {

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
template <class MemIntr>
std::optional<uint64_t> getStaticMemIntrLen(MemIntr op) {
  APInt memIntrLen;
  // The length operand must fold to a constant integer.
  if (!matchPattern(op.getLen(), m_ConstantInt(&memIntrLen)))
    return {};
  // Wider-than-64-bit lengths cannot be represented in the result type.
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

/// Returns the length of the given memory intrinsic in bytes if it can be known
/// at compile-time on a best-effort basis, nothing otherwise.
/// Because MemcpyInlineOp has its length encoded as an attribute, this requires
/// specialized handling.
template <>
std::optional<uint64_t> getStaticMemIntrLen(LLVM::MemcpyInlineOp op) {
  // The length is always statically known here; only its width needs checking.
  APInt memIntrLen = op.getLen();
  if (memIntrLen.getBitWidth() > 64)
    return {};
  return memIntrLen.getZExtValue();
}

} // namespace
906 | |
907 | /// Returns whether one can be sure the memory intrinsic does not write outside |
908 | /// of the bounds of the given slot, on a best-effort basis. |
909 | template <class MemIntr> |
910 | static bool definitelyWritesOnlyWithinSlot(MemIntr op, const MemorySlot &slot, |
911 | const DataLayout &dataLayout) { |
912 | if (!isa<LLVM::LLVMPointerType>(slot.ptr.getType()) || |
913 | op.getDst() != slot.ptr) |
914 | return false; |
915 | |
916 | std::optional<uint64_t> memIntrLen = getStaticMemIntrLen(op); |
917 | return memIntrLen && *memIntrLen <= dataLayout.getTypeSize(t: slot.elemType); |
918 | } |
919 | |
920 | /// Checks whether all indices are i32. This is used to check GEPs can index |
921 | /// into them. |
922 | static bool areAllIndicesI32(const DestructurableMemorySlot &slot) { |
923 | Type i32 = IntegerType::get(slot.ptr.getContext(), 32); |
924 | return llvm::all_of(Range: llvm::make_first_range(c: slot.elementPtrs), |
925 | P: [&](Attribute index) { |
926 | auto intIndex = dyn_cast<IntegerAttr>(index); |
927 | return intIndex && intIndex.getType() == i32; |
928 | }); |
929 | } |
930 | |
931 | //===----------------------------------------------------------------------===// |
932 | // Interfaces for memset |
933 | //===----------------------------------------------------------------------===// |
934 | |
// A memset never reads from the slot.
bool LLVM::MemsetOp::loadsFrom(const MemorySlot &slot) { return false; }
936 | |
// A memset writes the slot iff the slot pointer is its destination.
bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
  return getDst() == slot.ptr;
}
940 | |
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  // Materializes the value the slot holds after this memset: the memset byte
  // splatted across the slot's integer type.
  // TODO: Support non-integer types.
  return TypeSwitch<Type, Value>(slot.elemType)
      .Case([&](IntegerType intType) -> Value {
        // A single byte is the memset value itself.
        if (intType.getWidth() == 8)
          return getVal();

        // canUsesBeRemoved only lets through integer types that are a whole
        // number of bytes.
        assert(intType.getWidth() % 8 == 0);

        // Build the memset integer by repeatedly shifting the value and
        // or-ing it with the previous value.
        uint64_t coveredBits = 8;
        Value currentValue =
            rewriter.create<LLVM::ZExtOp>(getLoc(), intType, getVal());
        while (coveredBits < intType.getWidth()) {
          Value shiftBy =
              rewriter.create<LLVM::ConstantOp>(getLoc(), intType, coveredBits);
          Value shifted =
              rewriter.create<LLVM::ShlOp>(getLoc(), currentValue, shiftBy);
          currentValue =
              rewriter.create<LLVM::OrOp>(getLoc(), currentValue, shifted);
          // Each round doubles the number of bytes carrying the splat.
          coveredBits *= 2;
        }

        return currentValue;
      })
      .Default([](Type) -> Value {
        llvm_unreachable(
            "getStored should not be called on memset to unsupported type" );
      });
}
974 | |
975 | bool LLVM::MemsetOp::canUsesBeRemoved( |
976 | const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, |
977 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
978 | const DataLayout &dataLayout) { |
979 | // TODO: Support non-integer types. |
980 | bool canConvertType = |
981 | TypeSwitch<Type, bool>(slot.elemType) |
982 | .Case([](IntegerType intType) { |
983 | return intType.getWidth() % 8 == 0 && intType.getWidth() > 0; |
984 | }) |
985 | .Default([](Type) { return false; }); |
986 | if (!canConvertType) |
987 | return false; |
988 | |
989 | if (getIsVolatile()) |
990 | return false; |
991 | |
992 | return getStaticMemIntrLen(*this) == dataLayout.getTypeSize(slot.elemType); |
993 | } |
994 | |
DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // Nothing needs rewriting here; the memset can simply be erased.
  return DeletionKind::Delete;
}
1001 | |
1002 | LogicalResult LLVM::MemsetOp::ensureOnlySafeAccesses( |
1003 | const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
1004 | const DataLayout &dataLayout) { |
1005 | return success(definitelyWritesOnlyWithinSlot(*this, slot, dataLayout)); |
1006 | } |
1007 | |
bool LLVM::MemsetOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  // Only rewire memsets on types from this op's own (LLVM) dialect.
  if (&slot.elemType.getDialect() != getOperation()->getDialect())
    return false;

  // Volatile memsets must be preserved as-is.
  if (getIsVolatile())
    return false;

  // The slot type must be destructurable into subelements.
  if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap())
    return false;

  // Rewiring indexes subslots with i32 attributes.
  if (!areAllIndicesI32(slot))
    return false;

  return definitelyWritesOnlyWithinSlot(*this, slot, dataLayout);
}
1026 | |
DeletionKind LLVM::MemsetOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    RewriterBase &rewriter,
                                    const DataLayout &dataLayout) {
  // Splits this memset into one memset per destructured subslot it covers.
  std::optional<DenseMap<Attribute, Type>> types =
      cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap();

  // The length must fold to a constant; canRewire only accepts statically
  // known lengths.
  IntegerAttr memsetLenAttr;
  bool successfulMatch =
      matchPattern(getLen(), m_Constant<IntegerAttr>(&memsetLenAttr));
  (void)successfulMatch;
  assert(successfulMatch);

  // Padding only exists between elements of non-packed structs.
  bool packed = false;
  if (auto structType = dyn_cast<LLVM::LLVMStructType>(slot.elemType))
    packed = structType.isPacked();

  Type i32 = IntegerType::get(getContext(), 32);
  uint64_t memsetLen = memsetLenAttr.getValue().getZExtValue();
  // Number of bytes of the original memset accounted for so far.
  uint64_t covered = 0;
  for (size_t i = 0; i < types->size(); i++) {
    // Create indices on the fly to get elements in the right order.
    Attribute index = IntegerAttr::get(i32, i);
    Type elemType = types->at(index);
    uint64_t typeSize = dataLayout.getTypeSize(elemType);

    // Skip over inter-element padding.
    if (!packed)
      covered =
          llvm::alignTo(covered, dataLayout.getTypeABIAlignment(elemType));

    // Stop once the original memset's bytes are exhausted.
    if (covered >= memsetLen)
      break;

    // If this subslot is used, apply a new memset to it.
    // Otherwise, only compute its offset within the original memset.
    if (subslots.contains(index)) {
      // Clamp the new memset to what remains of the original length.
      uint64_t newMemsetSize = std::min(memsetLen - covered, typeSize);

      Value newMemsetSizeValue =
          rewriter
              .create<LLVM::ConstantOp>(
                  getLen().getLoc(),
                  IntegerAttr::get(memsetLenAttr.getType(), newMemsetSize))
              .getResult();

      rewriter.create<LLVM::MemsetOp>(getLoc(), subslots.at(index).ptr,
                                      getVal(), newMemsetSizeValue,
                                      getIsVolatile());
    }

    covered += typeSize;
  }

  return DeletionKind::Delete;
}
1082 | |
1083 | //===----------------------------------------------------------------------===// |
1084 | // Interfaces for memcpy/memmove |
1085 | //===----------------------------------------------------------------------===// |
1086 | |
/// A memcpy-like op reads the slot iff the slot is the copy source.
template <class MemcpyLike>
static bool memcpyLoadsFrom(MemcpyLike op, const MemorySlot &slot) {
  return op.getSrc() == slot.ptr;
}
1091 | |
/// A memcpy-like op writes the slot iff the slot is the copy destination.
template <class MemcpyLike>
static bool memcpyStoresTo(MemcpyLike op, const MemorySlot &slot) {
  return op.getDst() == slot.ptr;
}
1096 | |
/// Materializes the value the copy stores into the slot by loading the slot's
/// element type from the copy source.
template <class MemcpyLike>
static Value memcpyGetStored(MemcpyLike op, const MemorySlot &slot,
                             RewriterBase &rewriter) {
  return rewriter.create<LLVM::LoadOp>(op.getLoc(), slot.elemType, op.getSrc());
}
1102 | |
1103 | template <class MemcpyLike> |
1104 | static bool |
1105 | memcpyCanUsesBeRemoved(MemcpyLike op, const MemorySlot &slot, |
1106 | const SmallPtrSetImpl<OpOperand *> &blockingUses, |
1107 | SmallVectorImpl<OpOperand *> &newBlockingUses, |
1108 | const DataLayout &dataLayout) { |
1109 | // If source and destination are the same, memcpy behavior is undefined and |
1110 | // memmove is a no-op. Because there is no memory change happening here, |
1111 | // simplifying such operations is left to canonicalization. |
1112 | if (op.getDst() == op.getSrc()) |
1113 | return false; |
1114 | |
1115 | if (op.getIsVolatile()) |
1116 | return false; |
1117 | |
1118 | return getStaticMemIntrLen(op) == dataLayout.getTypeSize(t: slot.elemType); |
1119 | } |
1120 | |
/// Removes the memcpy-like op. If the slot is the copy source, the copied
/// value (the reaching definition) is first stored to the destination pointer
/// so the copy's effect is preserved.
template <class MemcpyLike>
static DeletionKind
memcpyRemoveBlockingUses(MemcpyLike op, const MemorySlot &slot,
                         const SmallPtrSetImpl<OpOperand *> &blockingUses,
                         RewriterBase &rewriter, Value reachingDefinition) {
  if (op.loadsFrom(slot))
    rewriter.create<LLVM::StoreOp>(op.getLoc(), reachingDefinition,
                                   op.getDst());
  return DeletionKind::Delete;
}
1131 | |
/// Checks that the memcpy-like op only writes within the slot's bounds.
/// Note: the data layout is derived from the op itself here, not passed in.
template <class MemcpyLike>
static LogicalResult
memcpyEnsureOnlySafeAccesses(MemcpyLike op, const MemorySlot &slot,
                             SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
  DataLayout dataLayout = DataLayout::closest(op);
  // While rewiring memcpy-like intrinsics only supports full copies, partial
  // copies are still safe accesses so it is enough to only check for writes
  // within bounds.
  return success(definitelyWritesOnlyWithinSlot(op, slot, dataLayout));
}
1142 | |
1143 | template <class MemcpyLike> |
1144 | static bool memcpyCanRewire(MemcpyLike op, const DestructurableMemorySlot &slot, |
1145 | SmallPtrSetImpl<Attribute> &usedIndices, |
1146 | SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
1147 | const DataLayout &dataLayout) { |
1148 | if (op.getIsVolatile()) |
1149 | return false; |
1150 | |
1151 | if (!cast<DestructurableTypeInterface>(slot.elemType).getSubelementIndexMap()) |
1152 | return false; |
1153 | |
1154 | if (!areAllIndicesI32(slot)) |
1155 | return false; |
1156 | |
1157 | // Only full copies are supported. |
1158 | if (getStaticMemIntrLen(op) != dataLayout.getTypeSize(t: slot.elemType)) |
1159 | return false; |
1160 | |
1161 | if (op.getSrc() == slot.ptr) |
1162 | for (Attribute index : llvm::make_first_range(c: slot.elementPtrs)) |
1163 | usedIndices.insert(Ptr: index); |
1164 | |
1165 | return true; |
1166 | } |
1167 | |
namespace {

/// Creates a replacement for `toReplace` that copies a whole value of type
/// `toCpy` between `dst` and `src`, materializing the length as a constant op
/// with the same integer type as the original length operand.
template <class MemcpyLike>
void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
                               MemcpyLike toReplace, Value dst, Value src,
                               Type toCpy, bool isVolatile) {
  Value memcpySize = rewriter.create<LLVM::ConstantOp>(
      toReplace.getLoc(), IntegerAttr::get(toReplace.getLen().getType(),
                                           layout.getTypeSize(toCpy)));
  rewriter.create<MemcpyLike>(toReplace.getLoc(), dst, src, memcpySize,
                              isVolatile);
}

/// Specialization for MemcpyInlineOp, whose length is an attribute rather than
/// an operand, so no constant op is needed.
template <>
void createMemcpyLikeToReplace(RewriterBase &rewriter, const DataLayout &layout,
                               LLVM::MemcpyInlineOp toReplace, Value dst,
                               Value src, Type toCpy, bool isVolatile) {
  // Preserve the bitwidth of the original length attribute.
  Type lenType = IntegerType::get(toReplace->getContext(),
                                  toReplace.getLen().getBitWidth());
  rewriter.create<LLVM::MemcpyInlineOp>(
      toReplace.getLoc(), dst, src,
      IntegerAttr::get(lenType, layout.getTypeSize(toCpy)), isVolatile);
}

} // namespace
1193 | |
1194 | /// Rewires a memcpy-like operation. Only copies to or from the full slot are |
1195 | /// supported. |
1196 | template <class MemcpyLike> |
1197 | static DeletionKind |
1198 | memcpyRewire(MemcpyLike op, const DestructurableMemorySlot &slot, |
1199 | DenseMap<Attribute, MemorySlot> &subslots, RewriterBase &rewriter, |
1200 | const DataLayout &dataLayout) { |
1201 | if (subslots.empty()) |
1202 | return DeletionKind::Delete; |
1203 | |
1204 | assert((slot.ptr == op.getDst()) != (slot.ptr == op.getSrc())); |
1205 | bool isDst = slot.ptr == op.getDst(); |
1206 | |
1207 | #ifndef NDEBUG |
1208 | size_t slotsTreated = 0; |
1209 | #endif |
1210 | |
1211 | // It was previously checked that index types are consistent, so this type can |
1212 | // be fetched now. |
1213 | Type indexType = cast<IntegerAttr>(subslots.begin()->first).getType(); |
1214 | for (size_t i = 0, e = slot.elementPtrs.size(); i != e; i++) { |
1215 | Attribute index = IntegerAttr::get(indexType, i); |
1216 | if (!subslots.contains(Val: index)) |
1217 | continue; |
1218 | const MemorySlot &subslot = subslots.at(Val: index); |
1219 | |
1220 | #ifndef NDEBUG |
1221 | slotsTreated++; |
1222 | #endif |
1223 | |
1224 | // First get a pointer to the equivalent of this subslot from the source |
1225 | // pointer. |
1226 | SmallVector<LLVM::GEPArg> gepIndices{ |
1227 | 0, static_cast<int32_t>( |
1228 | cast<IntegerAttr>(index).getValue().getZExtValue())}; |
1229 | Value subslotPtrInOther = rewriter.create<LLVM::GEPOp>( |
1230 | op.getLoc(), LLVM::LLVMPointerType::get(op.getContext()), slot.elemType, |
1231 | isDst ? op.getSrc() : op.getDst(), gepIndices); |
1232 | |
1233 | // Then create a new memcpy out of this source pointer. |
1234 | createMemcpyLikeToReplace(rewriter, dataLayout, op, |
1235 | isDst ? subslot.ptr : subslotPtrInOther, |
1236 | isDst ? subslotPtrInOther : subslot.ptr, |
1237 | subslot.elemType, op.getIsVolatile()); |
1238 | } |
1239 | |
1240 | assert(subslots.size() == slotsTreated); |
1241 | |
1242 | return DeletionKind::Delete; |
1243 | } |
1244 | |
// Delegates to the shared memcpy-like helper (slot is read iff it is the src).
bool LLVM::MemcpyOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1248 | |
// Delegates to the shared memcpy-like helper (slot is written iff it is dst).
bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1252 | |
// Delegates to the shared memcpy-like helper.
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
1258 | |
// Delegates to the shared memcpy-like helper.
bool LLVM::MemcpyOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1266 | |
// Delegates to the shared memcpy-like helper.
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                  reachingDefinition);
}
1274 | |
// Delegates to the shared memcpy-like helper.
LogicalResult LLVM::MemcpyOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1280 | |
// Delegates to the shared memcpy-like helper.
bool LLVM::MemcpyOp::canRewire(const DestructurableMemorySlot &slot,
                               SmallPtrSetImpl<Attribute> &usedIndices,
                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                               const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1288 | |
// Delegates to the shared memcpy-like helper.
DeletionKind LLVM::MemcpyOp::rewire(const DestructurableMemorySlot &slot,
                                    DenseMap<Attribute, MemorySlot> &subslots,
                                    RewriterBase &rewriter,
                                    const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
1295 | |
// Delegates to the shared memcpy-like helper (slot is read iff it is the src).
bool LLVM::MemcpyInlineOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1299 | |
// Delegates to the shared memcpy-like helper (slot is written iff it is dst).
bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1303 | |
// Delegates to the shared memcpy-like helper.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      RewriterBase &rewriter, Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
1309 | |
// Delegates to the shared memcpy-like helper.
bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1317 | |
// Delegates to the shared memcpy-like helper.
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                  reachingDefinition);
}
1325 | |
// Delegates to the shared memcpy-like helper.
LogicalResult LLVM::MemcpyInlineOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1331 | |
// Delegates to the shared memcpy-like helper.
bool LLVM::MemcpyInlineOp::canRewire(
    const DestructurableMemorySlot &slot,
    SmallPtrSetImpl<Attribute> &usedIndices,
    SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1340 | |
// Delegates to the shared memcpy-like helper.
DeletionKind
LLVM::MemcpyInlineOp::rewire(const DestructurableMemorySlot &slot,
                             DenseMap<Attribute, MemorySlot> &subslots,
                             RewriterBase &rewriter,
                             const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
1348 | |
// Delegates to the shared memcpy-like helper (slot is read iff it is the src).
bool LLVM::MemmoveOp::loadsFrom(const MemorySlot &slot) {
  return memcpyLoadsFrom(*this, slot);
}
1352 | |
// Delegates to the shared memcpy-like helper (slot is written iff it is dst).
bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
  return memcpyStoresTo(*this, slot);
}
1356 | |
// Delegates to the shared memcpy-like helper.
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                 Value reachingDef,
                                 const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
1362 | |
// Delegates to the shared memcpy-like helper.
bool LLVM::MemmoveOp::canUsesBeRemoved(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    SmallVectorImpl<OpOperand *> &newBlockingUses,
    const DataLayout &dataLayout) {
  return memcpyCanUsesBeRemoved(*this, slot, blockingUses, newBlockingUses,
                                dataLayout);
}
1370 | |
// Delegates to the shared memcpy-like helper.
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                  reachingDefinition);
}
1378 | |
// Delegates to the shared memcpy-like helper.
LogicalResult LLVM::MemmoveOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  return memcpyEnsureOnlySafeAccesses(*this, slot, mustBeSafelyUsed);
}
1384 | |
// Delegates to the shared memcpy-like helper.
bool LLVM::MemmoveOp::canRewire(const DestructurableMemorySlot &slot,
                                SmallPtrSetImpl<Attribute> &usedIndices,
                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
                                const DataLayout &dataLayout) {
  return memcpyCanRewire(*this, slot, usedIndices, mustBeSafelyUsed,
                         dataLayout);
}
1392 | |
// Delegates to the shared memcpy-like helper.
DeletionKind LLVM::MemmoveOp::rewire(const DestructurableMemorySlot &slot,
                                     DenseMap<Attribute, MemorySlot> &subslots,
                                     RewriterBase &rewriter,
                                     const DataLayout &dataLayout) {
  return memcpyRewire(*this, slot, subslots, rewriter, dataLayout);
}
1399 | |
1400 | //===----------------------------------------------------------------------===// |
1401 | // Interfaces for destructurable types |
1402 | //===----------------------------------------------------------------------===// |
1403 | |
1404 | std::optional<DenseMap<Attribute, Type>> |
1405 | LLVM::LLVMStructType::getSubelementIndexMap() { |
1406 | Type i32 = IntegerType::get(getContext(), 32); |
1407 | DenseMap<Attribute, Type> destructured; |
1408 | for (const auto &[index, elemType] : llvm::enumerate(getBody())) |
1409 | destructured.insert({IntegerAttr::get(i32, index), elemType}); |
1410 | return destructured; |
1411 | } |
1412 | |
1413 | Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) { |
1414 | auto indexAttr = llvm::dyn_cast<IntegerAttr>(index); |
1415 | if (!indexAttr || !indexAttr.getType().isInteger(32)) |
1416 | return {}; |
1417 | int32_t indexInt = indexAttr.getInt(); |
1418 | ArrayRef<Type> body = getBody(); |
1419 | if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt)) |
1420 | return {}; |
1421 | return body[indexInt]; |
1422 | } |
1423 | |
std::optional<DenseMap<Attribute, Type>>
LLVM::LLVMArrayType::getSubelementIndexMap() const {
  // Destructuring large arrays would create too many subslots; refuse beyond
  // this limit.
  constexpr size_t maxArraySizeForDestructuring = 16;
  if (getNumElements() > maxArraySizeForDestructuring)
    return {};
  int32_t numElements = getNumElements();

  // Map each i32 index attribute to the array's uniform element type.
  Type i32 = IntegerType::get(getContext(), 32);
  DenseMap<Attribute, Type> destructured;
  for (int32_t index = 0; index < numElements; ++index)
    destructured.insert({IntegerAttr::get(i32, index), getElementType()});
  return destructured;
}
1437 | |
1438 | Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const { |
1439 | auto indexAttr = llvm::dyn_cast<IntegerAttr>(index); |
1440 | if (!indexAttr || !indexAttr.getType().isInteger(32)) |
1441 | return {}; |
1442 | int32_t indexInt = indexAttr.getInt(); |
1443 | if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt)) |
1444 | return {}; |
1445 | return getElementType(); |
1446 | } |
1447 | |