//===- TypeConsistency.cpp - Rewrites to improve type consistency ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
namespace LLVM {
#define GEN_PASS_DEF_LLVMTYPECONSISTENCY
#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc"
} // namespace LLVM
} // namespace mlir

using namespace mlir;
using namespace LLVM;

//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

/// Checks that a pointer value has a pointee type hint consistent with the
/// expected type. Returns the type it actually hints to if it differs, or
/// nullptr if the type is consistent or impossible to analyze.
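///
/// As an illustrative sketch (hypothetical IR, not verbatim pass input): if
/// `addr` is defined by
///   %addr = llvm.getelementptr %base[0, 0]
///     : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
/// its pointee type hint is i32. Calling this helper with an expected type
/// of i64 then returns i32, while calling it with i32 returns nullptr.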
static Type isElementTypeInconsistent(Value addr, Type expectedType) {
  auto defOp = dyn_cast_or_null<GetResultPtrElementType>(addr.getDefiningOp());
  if (!defOp)
    return nullptr;

  Type elemType = defOp.getResultPtrElementType();
  if (!elemType)
    return nullptr;

  if (elemType == expectedType)
    return nullptr;

  return elemType;
}

//===----------------------------------------------------------------------===//
// CanonicalizeAlignedGep
//===----------------------------------------------------------------------===//

/// Returns the number of bytes the provided GEP indices will offset the
/// pointer by. Returns nullopt if the offset could not be computed.
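///
/// Worked example (an illustrative sketch; the struct type is hypothetical):
/// for a GEP with element type !llvm.struct<(i32, i32)> and constant indices
/// [1, 1], the offset is 1 * 8 bytes to skip one whole struct, plus 4 bytes
/// to reach the second i32 field, i.e. 12 bytes in total.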
static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, GEPOp gep) {
  SmallVector<uint32_t> indices;
  // Ensures all indices are static and fetches them.
  for (auto index : gep.getIndices()) {
    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
    if (!indexInt)
      return std::nullopt;
    int32_t gepIndex = indexInt.getInt();
    if (gepIndex < 0)
      return std::nullopt;
    indices.push_back(static_cast<uint32_t>(gepIndex));
  }

  uint64_t offset = indices[0] * layout.getTypeSize(gep.getElemType());

  Type currentType = gep.getElemType();
  for (uint32_t index : llvm::drop_begin(indices)) {
    bool shouldCancel =
        TypeSwitch<Type, bool>(currentType)
            .Case([&](LLVMArrayType arrayType) {
              if (arrayType.getNumElements() <= index)
                return true;
              offset += index * layout.getTypeSize(arrayType.getElementType());
              currentType = arrayType.getElementType();
              return false;
            })
            .Case([&](LLVMStructType structType) {
              ArrayRef<Type> body = structType.getBody();
              if (body.size() <= index)
                return true;
              for (uint32_t i = 0; i < index; i++) {
                if (!structType.isPacked())
                  offset = llvm::alignTo(offset,
                                         layout.getTypeABIAlignment(body[i]));
                offset += layout.getTypeSize(body[i]);
              }
              currentType = body[index];
              return false;
            })
            .Default([](Type) { return true; });

    if (shouldCancel)
      return std::nullopt;
  }

  return offset;
}

/// Fills in `equivalentIndicesOut` with GEP indices that would be equivalent
/// to offsetting a pointer by `offset` bytes, assuming the GEP has `base` as
/// base type.
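///
/// Worked example (illustrative, with a hypothetical base type): for base
/// type !llvm.struct<(i32, i32)> and an offset of 12 bytes, this produces
/// the indices [1, 1]: a root index of 12 / 8 = 1 skips one whole struct,
/// and the remaining 4 bytes land exactly on the second i32 field.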
static LogicalResult
findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset,
                     SmallVectorImpl<GEPArg> &equivalentIndicesOut) {
  uint64_t baseSize = layout.getTypeSize(base);
  uint64_t rootIndex = offset / baseSize;
  if (rootIndex > std::numeric_limits<uint32_t>::max())
    return failure();
  equivalentIndicesOut.push_back(rootIndex);

  uint64_t distanceToStart = rootIndex * baseSize;

#ifndef NDEBUG
  auto isWithinCurrentType = [&](Type currentType) {
    return offset < distanceToStart + layout.getTypeSize(currentType);
  };
#endif

  Type currentType = base;
  while (distanceToStart < offset) {
    // Keep stepping into sub-types until the offset falls exactly on the
    // start of the current type.

    assert(isWithinCurrentType(currentType));

    bool shouldCancel =
        TypeSwitch<Type, bool>(currentType)
            .Case([&](LLVMArrayType arrayType) {
              // Find which element of the array contains the offset.
              uint64_t elemSize =
                  layout.getTypeSize(arrayType.getElementType());
              uint64_t index = (offset - distanceToStart) / elemSize;
              equivalentIndicesOut.push_back(index);
              distanceToStart += index * elemSize;

              // Then, try to find where in the element the offset is. If the
              // offset is exactly the beginning of the element, the loop is
              // complete.
              currentType = arrayType.getElementType();

              // Only continue if the element in question can be indexed using
              // an i32.
              return index > std::numeric_limits<uint32_t>::max();
            })
            .Case([&](LLVMStructType structType) {
              ArrayRef<Type> body = structType.getBody();
              uint32_t index = 0;

              // Walk over the elements of the struct to find in which of
              // them the offset is.
              for (Type elem : body) {
                uint64_t elemSize = layout.getTypeSize(elem);
                if (!structType.isPacked()) {
                  distanceToStart = llvm::alignTo(
                      distanceToStart, layout.getTypeABIAlignment(elem));
                  // If the offset is in padding, cancel the rewrite.
                  if (offset < distanceToStart)
                    return true;
                }

                if (offset < distanceToStart + elemSize) {
                  // The offset is within this element, stop iterating the
                  // struct and look within the current element.
                  equivalentIndicesOut.push_back(index);
                  currentType = elem;
                  return false;
                }

                // The offset is not within this element, continue walking
                // over the struct.
                distanceToStart += elemSize;
                index++;
              }

              // The offset was supposed to be within this struct but is not.
              // This can happen if the offset points into final padding.
              // Either way, nothing can be done.
              return true;
            })
            .Default([](Type) {
              // If the offset is within a type that cannot be split, no
              // indices will yield this offset. This can happen if the
              // offset is not perfectly aligned with a leaf type.
              // TODO: support vectors.
              return true;
            });

    if (shouldCancel)
      return failure();
  }

  return success();
}

/// Returns the consistent type for the GEP if the GEP is not type-consistent.
/// Returns failure if the GEP is already consistent.
static FailureOr<Type> getRequiredConsistentGEPType(GEPOp gep) {
  // GEPs on typed pointers are not supported.
  if (!gep.getElemType())
    return failure();

  std::optional<Type> maybeBaseType = gep.getElemType();
  if (!maybeBaseType)
    return failure();
  Type baseType = *maybeBaseType;

  Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType);
  if (!typeHint)
    return failure();
  return typeHint;
}

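/// Rewrites a GEP whose element type disagrees with the pointee type hint of
/// its base pointer into an equivalent GEP expressed in terms of the hinted
/// type. Illustrative sketch (hypothetical IR, not verbatim pass output):
/// assuming %base hints at !llvm.struct<(i32, i32)>,
///
///   %gep = llvm.getelementptr %base[4] : (!llvm.ptr) -> !llvm.ptr, i8
///
/// is rewritten to
///
///   %gep = llvm.getelementptr %base[0, 1]
///     : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>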
LogicalResult
CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
                                        PatternRewriter &rewriter) const {
  FailureOr<Type> typeHint = getRequiredConsistentGEPType(gep);
  if (failed(typeHint)) {
    // GEP is already canonical, nothing to do here.
    return failure();
  }

  DataLayout layout = DataLayout::closest(gep);
  std::optional<uint64_t> desiredOffset = gepToByteOffset(layout, gep);
  if (!desiredOffset)
    return failure();

  SmallVector<GEPArg> newIndices;
  if (failed(
          findIndicesForOffset(layout, *typeHint, *desiredOffset, newIndices)))
    return failure();

  rewriter.replaceOpWithNewOp<GEPOp>(
      gep, LLVM::LLVMPointerType::get(getContext()), *typeHint, gep.getBase(),
      newIndices, gep.getInbounds());

  return success();
}

namespace {
/// Class abstracting over both array and struct types, turning each into
/// ranges of their sub-types.
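///
/// For instance (an illustrative sketch), !llvm.struct<(i32, f32)> becomes
/// the range {i32, f32}, and !llvm.array<4 x i8> becomes a range of four i8
/// elements.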
class DestructurableTypeRange
    : public llvm::indexed_accessor_range<DestructurableTypeRange,
                                          DestructurableTypeInterface, Type,
                                          Type *, Type> {

  using Base = llvm::indexed_accessor_range<
      DestructurableTypeRange, DestructurableTypeInterface, Type, Type *, Type>;

public:
  using Base::Base;

  /// Constructs a DestructurableTypeRange from either an LLVMStructType or
  /// an LLVMArrayType.
  explicit DestructurableTypeRange(DestructurableTypeInterface base)
      : Base(base, 0, [&]() -> ptrdiff_t {
          return TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
              .Case([](LLVMStructType structType) {
                return structType.getBody().size();
              })
              .Case([](LLVMArrayType arrayType) {
                return arrayType.getNumElements();
              })
              .Default([](auto) -> ptrdiff_t {
                llvm_unreachable(
                    "Only LLVMStructType or LLVMArrayType supported");
              });
        }()) {}

  /// Returns true if this is a range over a packed struct.
  bool isPacked() const {
    if (auto structType = dyn_cast<LLVMStructType>(getBase()))
      return structType.isPacked();
    return false;
  }

private:
  static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
    // i32 chosen because the implementations of ArrayType and StructType
    // specifically expect it to be 32 bit. They will fail otherwise.
    Type result = base.getTypeAtIndex(
        IntegerAttr::get(IntegerType::get(base.getContext(), 32), index));
    assert(result && "Should always succeed");
    return result;
  }

  friend Base;
};
} // namespace

/// Returns the list of elements of `destructurableType` that are written to
/// by a store operation writing `storeSize` bytes at `storeOffset`.
/// `storeOffset` is required to cleanly point to an immediate element within
/// the type. If the write operation were to write to any padding, write
/// beyond the aggregate, or partially write to a non-aggregate, failure is
/// returned.
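///
/// Worked example (illustrative, with hypothetical values): for
/// !llvm.struct<(i32, i32, i32)>, a store of 8 bytes at offset 4 covers
/// exactly the second and third field, so the returned range holds the two
/// trailing i32 elements.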
static FailureOr<DestructurableTypeRange>
getWrittenToFields(const DataLayout &dataLayout,
                   DestructurableTypeInterface destructurableType,
                   unsigned storeSize, unsigned storeOffset) {
  DestructurableTypeRange destructurableTypeRange(destructurableType);

  unsigned currentOffset = 0;
  for (; !destructurableTypeRange.empty();
       destructurableTypeRange = destructurableTypeRange.drop_front()) {
    Type type = destructurableTypeRange.front();
    if (!destructurableTypeRange.isPacked()) {
      unsigned alignment = dataLayout.getTypeABIAlignment(type);
      currentOffset = llvm::alignTo(currentOffset, alignment);
    }

    // currentOffset is guaranteed to be equal to offset since offset is
    // either 0 or stems from a type-consistent GEP indexing into just a
    // single aggregate.
    if (currentOffset == storeOffset)
      break;

    assert(currentOffset < storeOffset &&
           "storeOffset should cleanly point into an immediate field");

    currentOffset += dataLayout.getTypeSize(type);
  }

  size_t exclusiveEnd = 0;
  for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
       exclusiveEnd++) {
    if (!destructurableTypeRange.isPacked()) {
      unsigned alignment =
          dataLayout.getTypeABIAlignment(destructurableTypeRange[exclusiveEnd]);
      // No padding is allowed between fields at this point in time.
      if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
        return failure();
    }

    unsigned fieldSize =
        dataLayout.getTypeSize(destructurableTypeRange[exclusiveEnd]);
    if (fieldSize > storeSize) {
      // Partial writes into an aggregate are okay since subsequent pattern
      // applications can further split these up into writes into the
      // sub-elements.
      auto subAggregate = dyn_cast<DestructurableTypeInterface>(
          destructurableTypeRange[exclusiveEnd]);
      if (!subAggregate)
        return failure();

      // Avoid splitting redundantly by making sure the store into the
      // aggregate can actually be split.
      if (failed(getWrittenToFields(dataLayout, subAggregate, storeSize,
                                    /*storeOffset=*/0)))
        return failure();

      return destructurableTypeRange.take_front(exclusiveEnd + 1);
    }
    currentOffset += fieldSize;
    storeSize -= fieldSize;
  }

  // If the storeSize is not 0 at this point we are writing past the
  // aggregate as a whole. Abort.
  if (storeSize > 0)
    return failure();
  return destructurableTypeRange.take_front(exclusiveEnd);
}

/// Splits a store of the vector `value` into `address` at `storeOffset` into
/// multiple stores of each element, with the goal of each generated store
/// becoming type-consistent through subsequent pattern applications.
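///
/// Illustrative sketch (hypothetical IR, not verbatim pass output): a store
/// of %vec : vector<2xi32> at offset 0 is emitted as per-element extracts
/// and i8-GEP-addressed stores, conceptually:
///
///   %e0 = llvm.extractelement %vec[%c0 : i32] : vector<2xi32>
///   %g0 = llvm.getelementptr %addr[0] : (!llvm.ptr) -> !llvm.ptr, i8
///   llvm.store %e0, %g0 : i32, !llvm.ptr
///
/// and likewise for element 1 at byte offset 4.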
static void splitVectorStore(const DataLayout &dataLayout, Location loc,
                             RewriterBase &rewriter, Value address,
                             TypedValue<VectorType> value,
                             unsigned storeOffset) {
  VectorType vectorType = value.getType();
  unsigned elementSize = dataLayout.getTypeSize(vectorType.getElementType());

  // Extract every element in the vector and store it in the given address.
  for (size_t index : llvm::seq<size_t>(0, vectorType.getNumElements())) {
    auto pos =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(index));
    auto extractOp = rewriter.create<ExtractElementOp>(loc, value, pos);

    // For convenience, we do indexing by calculating the final byte offset.
    // Other patterns will turn this into a type-consistent GEP.
    auto gepOp = rewriter.create<GEPOp>(
        loc, address.getType(), rewriter.getI8Type(), address,
        ArrayRef<GEPArg>{
            static_cast<int32_t>(storeOffset + index * elementSize)});

    rewriter.create<StoreOp>(loc, extractOp, gepOp);
  }
}

/// Splits a store of the integer `value` into `address` at `storeOffset`
/// into multiple stores, one per field in `writtenToFields`, making each
/// store operation type-consistent.
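///
/// Illustrative sketch (hypothetical IR): storing %v : i64 at offset 0 into
/// an allocation hinted as !llvm.struct<(i32, i32)> conceptually becomes
///
///   store (trunc %v to i32)            at byte offset 0
///   store (trunc (lshr %v, 32) to i32) at byte offset 4
///
/// where each store is addressed through an i8 GEP that later patterns
/// canonicalize.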
static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
                              RewriterBase &rewriter, Value address,
                              Value value, unsigned storeSize,
                              unsigned storeOffset,
                              DestructurableTypeRange writtenToFields) {
  unsigned currentOffset = storeOffset;
  for (Type type : writtenToFields) {
    unsigned fieldSize = dataLayout.getTypeSize(type);

    // Extract the data out of the integer by first shifting right and then
    // truncating it.
    auto pos = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(value.getType(),
                                     (currentOffset - storeOffset) * 8));

    auto shrOp = rewriter.create<LShrOp>(loc, value, pos);

    // If we are doing a partial write into a direct field, the remaining
    // `storeSize` will be less than the size of the field. We have to
    // truncate to the `storeSize` to avoid creating a store that wasn't in
    // the original code.
    IntegerType fieldIntType =
        rewriter.getIntegerType(std::min(fieldSize, storeSize) * 8);
    Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);

    // We create an `i8` indexed GEP here as that is the easiest (the offset
    // is already known). Other patterns turn this into a type-consistent GEP.
    auto gepOp = rewriter.create<GEPOp>(
        loc, address.getType(), rewriter.getI8Type(), address,
        ArrayRef<GEPArg>{static_cast<int32_t>(currentOffset)});
    rewriter.create<StoreOp>(loc, valueToStore, gepOp);

    // No need to care about padding here since we already checked previously
    // that no padding exists in this range.
    currentOffset += fieldSize;
    storeSize -= fieldSize;
  }
}

LogicalResult SplitStores::matchAndRewrite(StoreOp store,
                                           PatternRewriter &rewriter) const {
  Type sourceType = store.getValue().getType();
  if (!isa<IntegerType, VectorType>(sourceType)) {
    // We currently only support integer and vector sources.
    return failure();
  }

  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
  if (!typeHint) {
    // Nothing to do, since it is already consistent.
    return failure();
  }

  auto dataLayout = DataLayout::closest(store);

  unsigned storeSize = dataLayout.getTypeSize(sourceType);
  unsigned offset = 0;
  Value address = store.getAddr();
  if (auto gepOp = address.getDefiningOp<GEPOp>()) {
    // We currently only handle canonical GEPs with exactly two indices,
    // indexing a single aggregate deep.
    // If the GEP is not canonical we have to fail, otherwise we would not
    // create type-consistent IR.
    if (gepOp.getIndices().size() != 2 ||
        succeeded(getRequiredConsistentGEPType(gepOp)))
      return failure();

    // If the size of the element indexed by the GEP is smaller than the
    // store size, it is pointing into the middle of an aggregate with the
    // store storing into multiple adjacent elements. Destructure into the
    // base address of the aggregate with a store offset.
    if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
      std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
      if (!byteOffset)
        return failure();

      offset = *byteOffset;
      typeHint = gepOp.getElemType();
      address = gepOp.getBase();
    }
  }

  auto destructurableType = dyn_cast<DestructurableTypeInterface>(typeHint);
  if (!destructurableType)
    return failure();

  FailureOr<DestructurableTypeRange> writtenToElements =
      getWrittenToFields(dataLayout, destructurableType, storeSize, offset);
  if (failed(writtenToElements))
    return failure();

  if (writtenToElements->size() <= 1) {
    // Other patterns should take care of this case, we are only interested
    // in splitting element stores.
    return failure();
  }

  if (isa<IntegerType>(sourceType)) {
    splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
                      store.getValue(), storeSize, offset, *writtenToElements);
    rewriter.eraseOp(store);
    return success();
  }

  // Add a reasonable bound to avoid splitting very large vectors that would
  // end up generating lots of code.
  if (dataLayout.getTypeSizeInBits(sourceType) > maxVectorSplitSize)
    return failure();

  // Vector types are simply split into their elements and new stores are
  // generated with those. Subsequent pattern applications will split these
  // stores further if required.
  splitVectorStore(dataLayout, store.getLoc(), rewriter, address,
                   cast<TypedValue<VectorType>>(store.getValue()), offset);
  rewriter.eraseOp(store);
  return success();
}

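/// Splits GEPs with more than two indices so that, after repeated pattern
/// application, every resulting GEP has exactly two indices and indexes a
/// single aggregate deep. Illustrative sketch (hypothetical IR, not verbatim
/// pass output):
///
///   %gep = llvm.getelementptr %base[0, 1, 2]
///     : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, !llvm.array<4 x i32>)>
///
/// is rewritten to
///
///   %sub = llvm.getelementptr %base[0, 1]
///     : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, !llvm.array<4 x i32>)>
///   %gep = llvm.getelementptr %sub[0, 2]
///     : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>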
LogicalResult SplitGEP::matchAndRewrite(GEPOp gepOp,
                                        PatternRewriter &rewriter) const {
  FailureOr<Type> typeHint = getRequiredConsistentGEPType(gepOp);
  if (succeeded(typeHint) || gepOp.getIndices().size() <= 2) {
    // GEP is not canonical or indexes only a single aggregate deep, nothing
    // to do here.
    return failure();
  }

  auto indexToGEPArg =
      [](GEPIndicesAdaptor<ValueRange>::value_type index) -> GEPArg {
    if (auto integerAttr = dyn_cast<IntegerAttr>(index))
      return integerAttr.getValue().getSExtValue();
    return cast<Value>(index);
  };

  GEPIndicesAdaptor<ValueRange> indices = gepOp.getIndices();

  auto splitIter = std::next(indices.begin(), 2);

  // Split off the first GEP using the first two indices.
  auto subGepOp = rewriter.create<GEPOp>(
      gepOp.getLoc(), gepOp.getType(), gepOp.getElemType(), gepOp.getBase(),
      llvm::map_to_vector(llvm::make_range(indices.begin(), splitIter),
                          indexToGEPArg),
      gepOp.getInbounds());

  // The second GEP indexes on the result pointer element type of the
  // previous GEP with all the remaining indices and a zero upfront. If this
  // GEP has more than two indices remaining, it'll be further split in
  // subsequent pattern applications.
  SmallVector<GEPArg> newIndices = {0};
  llvm::transform(llvm::make_range(splitIter, indices.end()),
                  std::back_inserter(newIndices), indexToGEPArg);
  rewriter.replaceOpWithNewOp<GEPOp>(gepOp, gepOp.getType(),
                                     subGepOp.getResultPtrElementType(),
                                     subGepOp, newIndices, gepOp.getInbounds());
  return success();
}

//===----------------------------------------------------------------------===//
// Type consistency pass
//===----------------------------------------------------------------------===//
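
// To schedule these rewrites, add the pass below to a pass manager, e.g.
// (a minimal usage sketch; `pm` is a hypothetical PassManager nested on the
// appropriate operation):
//   pm.addPass(LLVM::createTypeConsistencyPass());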

namespace {
struct LLVMTypeConsistencyPass
    : public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
  void runOnOperation() override {
    RewritePatternSet rewritePatterns(&getContext());
    rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
    rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
    rewritePatterns.add<SplitGEP>(&getContext());
    FrozenRewritePatternSet frozen(std::move(rewritePatterns));

    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
      signalPassFailure();
  }
};
} // namespace

std::unique_ptr<Pass> LLVM::createTypeConsistencyPass() {
  return std::make_unique<LLVMTypeConsistencyPass>();
}