1//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10#include "MemRefDescriptor.h"
11#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "llvm/ADT/ScopeExit.h"
15#include "llvm/Support/Threading.h"
16#include <memory>
17#include <mutex>
18#include <optional>
19
20using namespace mlir;
21
22SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
23 {
24 // Most of the time, the entry already exists in the map.
25 std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26 std::defer_lock);
27 if (getContext().isMultithreadingEnabled())
28 lock.lock();
29 auto recursiveStack = conversionCallStack.find(Val: llvm::get_threadid());
30 if (recursiveStack != conversionCallStack.end())
31 return *recursiveStack->second;
32 }
33
34 // First time this thread gets here, we have to get an exclusive access to
35 // inset in the map
36 std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37 auto recursiveStackInserted = conversionCallStack.insert(KV: std::make_pair(
38 x: llvm::get_threadid(), y: std::make_unique<SmallVector<Type>>()));
39 return *recursiveStackInserted.first->second;
40}
41
42/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
44 const DataLayoutAnalysis *analysis)
45 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46
47/// Helper function that checks if the given value range is a bare pointer.
48static bool isBarePointer(ValueRange values) {
49 return values.size() == 1 &&
50 isa<LLVM::LLVMPointerType>(Val: values.front().getType());
51}
52
53/// Pack SSA values into an unranked memref descriptor struct.
54static Value packUnrankedMemRefDesc(OpBuilder &builder,
55 UnrankedMemRefType resultType,
56 ValueRange inputs, Location loc,
57 const LLVMTypeConverter &converter) {
58 // Note: Bare pointers are not supported for unranked memrefs because a
59 // memref descriptor cannot be built just from a bare pointer.
60 if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
61 return Value();
62 return UnrankedMemRefDescriptor::pack(builder, loc, converter, type: resultType,
63 values: inputs);
64}
65
66/// Pack SSA values into a ranked memref descriptor struct.
67static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
68 ValueRange inputs, Location loc,
69 const LLVMTypeConverter &converter) {
70 assert(resultType && "expected non-null result type");
71 if (isBarePointer(values: inputs))
72 return MemRefDescriptor::fromStaticShape(builder, loc, converter,
73 resultType, inputs[0]);
74 if (TypeRange(inputs) ==
75 converter.getMemRefDescriptorFields(type: resultType,
76 /*unpackAggregates=*/true))
77 return MemRefDescriptor::pack(builder, loc, converter, type: resultType, values: inputs);
78 // The inputs are neither a bare pointer nor an unpacked memref descriptor.
79 // This materialization function cannot be used.
80 return Value();
81}
82
83/// MemRef descriptor elements -> UnrankedMemRefType
84static Value unrankedMemRefMaterialization(OpBuilder &builder,
85 UnrankedMemRefType resultType,
86 ValueRange inputs, Location loc,
87 const LLVMTypeConverter &converter) {
88 // A source materialization must return a value of type
89 // `resultType`, so insert a cast from the memref descriptor type
90 // (!llvm.struct) to the original memref type.
91 Value packed =
92 packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
93 if (!packed)
94 return Value();
95 return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
96 .getResult(0);
97}
98
99/// MemRef descriptor elements -> MemRefType
100static Value rankedMemRefMaterialization(OpBuilder &builder,
101 MemRefType resultType,
102 ValueRange inputs, Location loc,
103 const LLVMTypeConverter &converter) {
104 // A source materialization must return a value of type `resultType`,
105 // so insert a cast from the memref descriptor type (!llvm.struct) to the
106 // original memref type.
107 Value packed =
108 packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
109 if (!packed)
110 return Value();
111 return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
112 .getResult(0);
113}
114
115/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
116LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
117 const LowerToLLVMOptions &options,
118 const DataLayoutAnalysis *analysis)
119 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
120 dataLayoutAnalysis(analysis) {
121 assert(llvmDialect && "LLVM IR dialect is not registered");
122
123 // Register conversions for the builtin types.
124 addConversion(callback: [&](ComplexType type) { return convertComplexType(type); });
125 addConversion([&](FloatType type) { return convertFloatType(type); });
126 addConversion([&](FunctionType type) { return convertFunctionType(type); });
127 addConversion([&](IndexType type) { return convertIndexType(type); });
128 addConversion([&](IntegerType type) { return convertIntegerType(type); });
129 addConversion([&](MemRefType type) { return convertMemRefType(type); });
130 addConversion(
131 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
132 addConversion(callback: [&](VectorType type) -> std::optional<Type> {
133 FailureOr<Type> llvmType = convertVectorType(type: type);
134 if (failed(Result: llvmType))
135 return std::nullopt;
136 return llvmType;
137 });
138
139 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
140 // before the conversions below since conversions are attempted in reverse
141 // order and those should take priority.
142 addConversion(callback: [](Type type) {
143 return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
144 : std::nullopt;
145 });
146
147 addConversion(callback: [&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
148 -> std::optional<LogicalResult> {
149 // Fastpath for types that won't be converted by this callback anyway.
150 if (LLVM::isCompatibleType(type: type)) {
151 results.push_back(Elt: type);
152 return success();
153 }
154
155 if (type.isIdentified()) {
156 auto convertedType = LLVM::LLVMStructType::getIdentified(
157 type.getContext(), ("_Converted." + type.getName()).str());
158
159 SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
160 if (llvm::count(recursiveStack, type)) {
161 results.push_back(Elt: convertedType);
162 return success();
163 }
164 recursiveStack.push_back(Elt: type);
165 auto popConversionCallStack = llvm::make_scope_exit(
166 F: [&recursiveStack]() { recursiveStack.pop_back(); });
167
168 SmallVector<Type> convertedElemTypes;
169 convertedElemTypes.reserve(N: type.getBody().size());
170 if (failed(convertTypes(types: type.getBody(), results&: convertedElemTypes)))
171 return std::nullopt;
172
173 // If the converted type has not been initialized yet, just set its body
174 // to be the converted arguments and return.
175 if (!convertedType.isInitialized()) {
176 if (failed(
177 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
178 return failure();
179 }
180 results.push_back(Elt: convertedType);
181 return success();
182 }
183
184 // If it has been initialized, has the same body and packed bit, just use
185 // it. This ensures that recursive structs keep being recursive rather
186 // than including a non-updated name.
187 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
188 convertedType.isPacked() == type.isPacked()) {
189 results.push_back(Elt: convertedType);
190 return success();
191 }
192
193 return failure();
194 }
195
196 SmallVector<Type> convertedSubtypes;
197 convertedSubtypes.reserve(N: type.getBody().size());
198 if (failed(convertTypes(types: type.getBody(), results&: convertedSubtypes)))
199 return std::nullopt;
200
201 results.push_back(LLVM::LLVMStructType::getLiteral(
202 type.getContext(), convertedSubtypes, type.isPacked()));
203 return success();
204 });
205 addConversion(callback: [&](LLVM::LLVMArrayType type) -> std::optional<Type> {
206 if (auto element = convertType(type.getElementType()))
207 return LLVM::LLVMArrayType::get(element, type.getNumElements());
208 return std::nullopt;
209 });
210 addConversion(callback: [&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
211 Type convertedResType = convertType(type.getReturnType());
212 if (!convertedResType)
213 return std::nullopt;
214
215 SmallVector<Type> convertedArgTypes;
216 convertedArgTypes.reserve(N: type.getNumParams());
217 if (failed(convertTypes(types: type.getParams(), results&: convertedArgTypes)))
218 return std::nullopt;
219
220 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
221 type.isVarArg());
222 });
223
224 // Add generic source and target materializations to handle cases where
225 // non-LLVM types persist after an LLVM conversion.
226 addSourceMaterialization(callback: [&](OpBuilder &builder, Type resultType,
227 ValueRange inputs, Location loc) {
228 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
229 .getResult(0);
230 });
231 addTargetMaterialization(callback: [&](OpBuilder &builder, Type resultType,
232 ValueRange inputs, Location loc) {
233 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
234 .getResult(0);
235 });
236
237 // Source materializations convert from the new block argument types
238 // (multiple SSA values that make up a memref descriptor) back to the
239 // original block argument type.
240 addSourceMaterialization([&](OpBuilder &builder,
241 UnrankedMemRefType resultType, ValueRange inputs,
242 Location loc) {
243 return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
244 *this);
245 });
246 addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
247 ValueRange inputs, Location loc) {
248 return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
249 });
250
251 // Bare pointer -> Packed MemRef descriptor
252 addTargetMaterialization(callback: [&](OpBuilder &builder, Type resultType,
253 ValueRange inputs, Location loc,
254 Type originalType) -> Value {
255 // The original MemRef type is required to build a MemRef descriptor
256 // because the sizes/strides of the MemRef cannot be inferred from just the
257 // bare pointer.
258 if (!originalType)
259 return Value();
260 if (resultType != convertType(t: originalType))
261 return Value();
262 if (auto memrefType = dyn_cast<MemRefType>(originalType))
263 return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
264 if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
265 return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
266 *this);
267 return Value();
268 });
269
270 // Integer memory spaces map to themselves.
271 addTypeAttributeConversion(
272 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
273}
274
275/// Returns the MLIR context.
276MLIRContext &LLVMTypeConverter::getContext() const {
277 return *getDialect()->getContext();
278}
279
280Type LLVMTypeConverter::getIndexType() const {
281 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
282}
283
284unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
285 return options.dataLayout.getPointerSizeInBits(AS: addressSpace);
286}
287
288Type LLVMTypeConverter::convertIndexType(IndexType type) const {
289 return getIndexType();
290}
291
292Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
293 return IntegerType::get(&getContext(), type.getWidth());
294}
295
296Type LLVMTypeConverter::convertFloatType(FloatType type) const {
297 // Valid LLVM float types are used directly.
298 if (LLVM::isCompatibleType(type: type))
299 return type;
300
301 // F4, F6, F8 types are converted to integer types with the same bit width.
302 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
303 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
304 Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
305 Float8E8M0FNUType>(type))
306 return IntegerType::get(&getContext(), type.getWidth());
307
308 // Other floating-point types: A custom type conversion rule must be
309 // specified by the user.
310 return Type();
311}
312
313// Convert a `ComplexType` to an LLVM type. The result is a complex number
314// struct with entries for the
315// 1. real part and for the
316// 2. imaginary part.
317Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
318 auto elementType = convertType(type.getElementType());
319 return LLVM::LLVMStructType::getLiteral(&getContext(),
320 {elementType, elementType});
321}
322
323// Except for signatures, MLIR function types are converted into LLVM
324// pointer-to-function types.
325Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
326 return LLVM::LLVMPointerType::get(type.getContext());
327}
328
329/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
330/// function arguments. Returns an empty container if none of these attributes
331/// are found in any of the arguments.
332static void
333filterByValRefArgAttrs(FunctionOpInterface funcOp,
334 SmallVectorImpl<std::optional<NamedAttribute>> &result) {
335 assert(result.empty() && "Unexpected non-empty output");
336 result.resize(funcOp.getNumArguments(), std::nullopt);
337 bool foundByValByRefAttrs = false;
338 for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
339 for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
340 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
341 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
342 foundByValByRefAttrs = true;
343 result[argIdx] = namedAttr;
344 break;
345 }
346 }
347 }
348
349 if (!foundByValByRefAttrs)
350 result.clear();
351}
352
353// Function types are converted to LLVM Function types by recursively converting
354// argument and result types. If MLIR Function has zero results, the LLVM
355// Function has one VoidType result. If MLIR Function has more than one result,
356// they are into an LLVM StructType in their order of appearance.
357// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
358// `llvm.byref` function arguments which are not LLVM pointers are overridden
359// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
360// converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
361Type LLVMTypeConverter::convertFunctionSignatureImpl(
362 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
363 LLVMTypeConverter::SignatureConversion &result,
364 SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
365 // Select the argument converter depending on the calling convention.
366 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
367 auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
368 : structFuncArgTypeConverter;
369 // Convert argument types one by one and check for errors.
370 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
371 SmallVector<Type, 8> converted;
372 if (failed(funcArgConverter(*this, type, converted)))
373 return {};
374
375 // Rewrite converted type of `llvm.byval` or `llvm.byref` function
376 // argument that was not converted to an LLVM pointer types.
377 if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
378 converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
379 // If the argument was already converted to an LLVM pointer type, we stop
380 // tracking it as it doesn't need more processing.
381 if (isa<LLVM::LLVMPointerType>(converted[0]))
382 (*byValRefNonPtrAttrs)[idx] = std::nullopt;
383 else
384 converted[0] = LLVM::LLVMPointerType::get(&getContext());
385 }
386
387 result.addInputs(idx, converted);
388 }
389
390 // If function does not return anything, create the void result type,
391 // if it returns on element, convert it, otherwise pack the result types into
392 // a struct.
393 Type resultType =
394 funcTy.getNumResults() == 0
395 ? LLVM::LLVMVoidType::get(ctx: &getContext())
396 : packFunctionResults(types: funcTy.getResults(), useBarePointerCallConv: useBarePtrCallConv);
397 if (!resultType)
398 return {};
399 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
400 isVariadic);
401}
402
403Type LLVMTypeConverter::convertFunctionSignature(
404 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
405 LLVMTypeConverter::SignatureConversion &result) const {
406 return convertFunctionSignatureImpl(funcTy: funcTy, isVariadic, useBarePtrCallConv,
407 result,
408 /*byValRefNonPtrAttrs=*/nullptr);
409}
410
411Type LLVMTypeConverter::convertFunctionSignature(
412 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
413 LLVMTypeConverter::SignatureConversion &result,
414 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
415 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
416 // that were not converted to LLVM pointer types will be returned for further
417 // processing.
418 filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
419 auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
420 return convertFunctionSignatureImpl(funcTy: funcTy, isVariadic, useBarePtrCallConv,
421 result, byValRefNonPtrAttrs: &byValRefNonPtrAttrs);
422}
423
424/// Converts the function type to a C-compatible format, in particular using
425/// pointers to memref descriptors for arguments.
426std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
427LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
428 SmallVector<Type, 4> inputs;
429
430 Type resultType = type.getNumResults() == 0
431 ? LLVM::LLVMVoidType::get(ctx: &getContext())
432 : packFunctionResults(types: type.getResults());
433 if (!resultType)
434 return {};
435
436 auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
437 auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
438 if (structType) {
439 // Struct types cannot be safely returned via C interface. Make this a
440 // pointer argument, instead.
441 inputs.push_back(Elt: ptrType);
442 resultType = LLVM::LLVMVoidType::get(ctx: &getContext());
443 }
444
445 for (Type t : type.getInputs()) {
446 auto converted = convertType(t);
447 if (!converted || !LLVM::isCompatibleType(converted))
448 return {};
449 if (isa<MemRefType, UnrankedMemRefType>(t))
450 converted = ptrType;
451 inputs.push_back(converted);
452 }
453
454 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
455}
456
457/// Convert a memref type into a list of LLVM IR types that will form the
458/// memref descriptor. The result contains the following types:
459/// 1. The pointer to the allocated data buffer, followed by
460/// 2. The pointer to the aligned data buffer, followed by
461/// 3. A lowered `index`-type integer containing the distance between the
462/// beginning of the buffer and the first element to be accessed through the
463/// view, followed by
464/// 4. An array containing as many `index`-type integers as the rank of the
465/// MemRef: the array represents the size, in number of elements, of the memref
466/// along the given dimension. For constant MemRef dimensions, the
467/// corresponding size entry is a constant whose runtime value must match the
468/// static value, followed by
469/// 5. A second array containing as many `index`-type integers as the rank of
470/// the MemRef: the second array represents the "stride" (in tensor abstraction
471/// sense), i.e. the number of consecutive elements of the underlying buffer.
472/// TODO: add assertions for the static cases.
473///
474/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
475/// are expanded into individual index-type elements.
476///
477/// template <typename Elem, typename Index, size_t Rank>
478/// struct {
479/// Elem *allocatedPtr;
480/// Elem *alignedPtr;
481/// Index offset;
482/// Index sizes[Rank]; // omitted when rank == 0
483/// Index strides[Rank]; // omitted when rank == 0
484/// };
485SmallVector<Type, 5>
486LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
487 bool unpackAggregates) const {
488 if (!type.isStrided()) {
489 emitError(
490 UnknownLoc::get(type.getContext()),
491 "conversion to strided form failed either due to non-strided layout "
492 "maps (which should have been normalized away) or other reasons");
493 return {};
494 }
495
496 Type elementType = convertType(type.getElementType());
497 if (!elementType)
498 return {};
499
500 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type: type);
501 if (failed(Result: addressSpace)) {
502 emitError(UnknownLoc::get(type.getContext()),
503 "conversion of memref memory space ")
504 << type.getMemorySpace()
505 << " to integer address space "
506 "failed. Consider adding memory space conversions.";
507 return {};
508 }
509 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
510
511 auto indexTy = getIndexType();
512
513 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
514 auto rank = type.getRank();
515 if (rank == 0)
516 return results;
517
518 if (unpackAggregates)
519 results.insert(results.end(), 2 * rank, indexTy);
520 else
521 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
522 return results;
523}
524
525unsigned
526LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
527 const DataLayout &layout) const {
528 // Compute the descriptor size given that of its components indicated above.
529 unsigned space = *getMemRefAddressSpace(type: type);
530 return 2 * llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8) +
531 (1 + 2 * type.getRank()) * layout.getTypeSize(t: getIndexType());
532}
533
534/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
535/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
536Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
537 // When converting a MemRefType to a struct with descriptor fields, do not
538 // unpack the `sizes` and `strides` arrays.
539 SmallVector<Type, 5> types =
540 getMemRefDescriptorFields(type: type, /*unpackAggregates=*/false);
541 if (types.empty())
542 return {};
543 return LLVM::LLVMStructType::getLiteral(&getContext(), types);
544}
545
546/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
547/// that will form the unranked memref descriptor. In particular, the fields
548/// for an unranked memref descriptor are:
549/// 1. index-typed rank, the dynamic rank of this MemRef
550/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
551/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
552/// be unranked.
553SmallVector<Type, 2>
554LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
555 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
556}
557
558unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
559 UnrankedMemRefType type, const DataLayout &layout) const {
560 // Compute the descriptor size given that of its components indicated above.
561 unsigned space = *getMemRefAddressSpace(type: type);
562 return layout.getTypeSize(t: getIndexType()) +
563 llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8);
564}
565
566Type LLVMTypeConverter::convertUnrankedMemRefType(
567 UnrankedMemRefType type) const {
568 if (!convertType(type.getElementType()))
569 return {};
570 return LLVM::LLVMStructType::getLiteral(&getContext(),
571 getUnrankedMemRefDescriptorFields());
572}
573
574FailureOr<unsigned>
575LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
576 if (!type.getMemorySpace()) // Default memory space -> 0.
577 return 0;
578 std::optional<Attribute> converted =
579 convertTypeAttribute(type, attr: type.getMemorySpace());
580 if (!converted)
581 return failure();
582 if (!(*converted)) // Conversion to default is 0.
583 return 0;
584 if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
585 if (explicitSpace.getType().isIndex() ||
586 explicitSpace.getType().isSignlessInteger())
587 return explicitSpace.getInt();
588 }
589 return failure();
590}
591
592// Check if a memref type can be converted to a bare pointer.
593bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
594 if (isa<UnrankedMemRefType>(Val: type))
595 // Unranked memref is not supported in the bare pointer calling convention.
596 return false;
597
598 // Check that the memref has static shape, strides and offset. Otherwise, it
599 // cannot be lowered to a bare pointer.
600 auto memrefTy = cast<MemRefType>(type);
601 if (!memrefTy.hasStaticShape())
602 return false;
603
604 int64_t offset = 0;
605 SmallVector<int64_t, 4> strides;
606 if (failed(memrefTy.getStridesAndOffset(strides, offset)))
607 return false;
608
609 for (int64_t stride : strides)
610 if (ShapedType::isDynamic(stride))
611 return false;
612
613 return !ShapedType::isDynamic(offset);
614}
615
616/// Convert a memref type to a bare pointer to the memref element type.
617Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
618 if (!canConvertToBarePtr(type))
619 return {};
620 Type elementType = convertType(t: type.getElementType());
621 if (!elementType)
622 return {};
623 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
624 if (failed(Result: addressSpace))
625 return {};
626 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
627}
628
629/// Convert an n-D vector type to an LLVM vector type:
630/// * 0-D `vector<T>` are converted to vector<1xT>
631/// * 1-D `vector<axT>` remains as is while,
632/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
633/// `!llvm.array<ax...array<jxvector<kxT>>>`.
634/// As LLVM supports arrays of scalable vectors, this method will also convert
635/// n-D scalable vectors provided that only the trailing dim is scalable.
636FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
637 auto elementType = convertType(type.getElementType());
638 if (!elementType)
639 return {};
640 if (type.getShape().empty())
641 return VectorType::get({1}, elementType);
642 Type vectorType = VectorType::get(type.getShape().back(), elementType,
643 type.getScalableDims().back());
644 assert(LLVM::isCompatibleVectorType(vectorType) &&
645 "expected vector type compatible with the LLVM dialect");
646 // For n-D vector types for which a _non-trailing_ dim is scalable,
647 // return a failure. Supporting such cases would require LLVM
648 // to support something akin "scalable arrays" of vectors.
649 if (llvm::is_contained(type.getScalableDims().drop_back(), true))
650 return failure();
651 auto shape = type.getShape();
652 for (int i = shape.size() - 2; i >= 0; --i)
653 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
654 return vectorType;
655}
656
657/// Convert a type in the context of the default or bare pointer calling
658/// convention. Calling convention sensitive types, such as MemRefType and
659/// UnrankedMemRefType, are converted following the specific rules for the
660/// calling convention. Calling convention independent types are converted
661/// following the default LLVM type conversions.
662Type LLVMTypeConverter::convertCallingConventionType(
663 Type type, bool useBarePtrCallConv) const {
664 if (useBarePtrCallConv)
665 if (auto memrefTy = dyn_cast<BaseMemRefType>(Val&: type))
666 return convertMemRefToBarePtr(type: memrefTy);
667
668 return convertType(t: type);
669}
670
671/// Promote the bare pointers in 'values' that resulted from memrefs to
672/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
673/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
674void LLVMTypeConverter::promoteBarePtrsToDescriptors(
675 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
676 SmallVectorImpl<Value> &values) const {
677 assert(stdTypes.size() == values.size() &&
678 "The number of types and values doesn't match");
679 for (unsigned i = 0, end = values.size(); i < end; ++i)
680 if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
681 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
682 memrefTy, values[i]);
683}
684
685/// Convert a non-empty list of types of values produced by an operation into an
686/// LLVM-compatible type. In particular, if more than one value is
687/// produced, create a literal structure with elements that correspond to each
688/// of the types converted with `convertType`.
689Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
690 assert(!types.empty() && "expected non-empty list of type");
691 if (types.size() == 1)
692 return convertType(t: types[0]);
693
694 SmallVector<Type> resultTypes;
695 resultTypes.reserve(N: types.size());
696 for (Type type : types) {
697 Type converted = convertType(t: type);
698 if (!converted || !LLVM::isCompatibleType(type: converted))
699 return {};
700 resultTypes.push_back(Elt: converted);
701 }
702
703 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
704}
705
706/// Convert a non-empty list of types to be returned from a function into an
707/// LLVM-compatible type. In particular, if more than one value is returned,
708/// create an LLVM dialect structure type with elements that correspond to each
709/// of the types converted with `convertCallingConventionType`.
710Type LLVMTypeConverter::packFunctionResults(TypeRange types,
711 bool useBarePtrCallConv) const {
712 assert(!types.empty() && "expected non-empty list of type");
713
714 useBarePtrCallConv |= options.useBarePtrCallConv;
715 if (types.size() == 1)
716 return convertCallingConventionType(type: types.front(), useBarePtrCallConv);
717
718 SmallVector<Type> resultTypes;
719 resultTypes.reserve(N: types.size());
720 for (auto t : types) {
721 auto converted = convertCallingConventionType(type: t, useBarePtrCallConv);
722 if (!converted || !LLVM::isCompatibleType(type: converted))
723 return {};
724 resultTypes.push_back(Elt: converted);
725 }
726
727 return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
728}
729
730Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
731 OpBuilder &builder) const {
732 // Alloca with proper alignment. We do not expect optimizations of this
733 // alloca op and so we omit allocating at the entry block.
734 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
735 Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
736 builder.getIndexAttr(1));
737 Value allocated =
738 builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
739 // Store into the alloca'ed descriptor.
740 builder.create<LLVM::StoreOp>(loc, operand, allocated);
741 return allocated;
742}
743
744SmallVector<Value, 4>
745LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
746 ValueRange operands, OpBuilder &builder,
747 bool useBarePtrCallConv) const {
748 SmallVector<Value, 4> promotedOperands;
749 promotedOperands.reserve(N: operands.size());
750 useBarePtrCallConv |= options.useBarePtrCallConv;
751 for (auto it : llvm::zip(t&: opOperands, u&: operands)) {
752 auto operand = std::get<0>(t&: it);
753 auto llvmOperand = std::get<1>(t&: it);
754
755 if (useBarePtrCallConv) {
756 // For the bare-ptr calling convention, we only have to extract the
757 // aligned pointer of a memref.
758 if (isa<MemRefType>(Val: operand.getType())) {
759 MemRefDescriptor desc(llvmOperand);
760 llvmOperand = desc.alignedPtr(builder, loc);
761 } else if (isa<UnrankedMemRefType>(Val: operand.getType())) {
762 llvm_unreachable("Unranked memrefs are not supported");
763 }
764 } else {
765 if (isa<UnrankedMemRefType>(Val: operand.getType())) {
766 UnrankedMemRefDescriptor::unpack(builder, loc, packed: llvmOperand,
767 results&: promotedOperands);
768 continue;
769 }
770 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
771 MemRefDescriptor::unpack(builder, loc, packed: llvmOperand, type: memrefType,
772 results&: promotedOperands);
773 continue;
774 }
775 }
776
777 promotedOperands.push_back(Elt: llvmOperand);
778 }
779 return promotedOperands;
780}
781
782/// Callback to convert function argument types. It converts a MemRef function
783/// argument to a list of non-aggregate types containing descriptor
784/// information, and an UnrankedmemRef function argument to a list containing
785/// the rank and a pointer to a descriptor struct.
786LogicalResult
787mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
788 SmallVectorImpl<Type> &result) {
789 if (auto memref = dyn_cast<MemRefType>(type)) {
790 // In signatures, Memref descriptors are expanded into lists of
791 // non-aggregate values.
792 auto converted =
793 converter.getMemRefDescriptorFields(type: memref, /*unpackAggregates=*/true);
794 if (converted.empty())
795 return failure();
796 result.append(converted.begin(), converted.end());
797 return success();
798 }
799 if (isa<UnrankedMemRefType>(Val: type)) {
800 auto converted = converter.getUnrankedMemRefDescriptorFields();
801 if (converted.empty())
802 return failure();
803 result.append(in_start: converted.begin(), in_end: converted.end());
804 return success();
805 }
806 auto converted = converter.convertType(t: type);
807 if (!converted)
808 return failure();
809 result.push_back(Elt: converted);
810 return success();
811}
812
813/// Callback to convert function argument types. It converts MemRef function
814/// arguments to bare pointers to the MemRef element type.
815LogicalResult
816mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
817 SmallVectorImpl<Type> &result) {
818 auto llvmTy = converter.convertCallingConventionType(
819 type, /*useBarePointerCallConv=*/useBarePtrCallConv: true);
820 if (!llvmTy)
821 return failure();
822
823 result.push_back(Elt: llvmTy);
824 return success();
825}
826

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp