1//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
11#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
13#include "llvm/ADT/ScopeExit.h"
14#include "llvm/Support/Threading.h"
15#include <memory>
16#include <mutex>
17#include <optional>
18
19using namespace mlir;
20
/// Returns the per-thread stack of struct types currently being converted.
/// The stack is used by the LLVMStructType conversion callback (registered in
/// the constructor) to detect and terminate recursion when converting
/// recursive struct bodies.
SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
  {
    // Most of the time, the entry already exists in the map. Take the shared
    // (reader) lock, but only when multithreading is enabled, to avoid
    // paying for synchronization in single-threaded contexts.
    std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
                                                    std::defer_lock);
    if (getContext().isMultithreadingEnabled())
      lock.lock();
    auto recursiveStack = conversionCallStack.find(Val: llvm::get_threadid());
    if (recursiveStack != conversionCallStack.end())
      return *recursiveStack->second;
  }

  // First time this thread gets here, we have to get exclusive access to
  // insert into the map.
  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
  auto recursiveStackInserted = conversionCallStack.insert(KV: std::make_pair(
      x: llvm::get_threadid(), y: std::make_unique<SmallVector<Type>>()));
  return *recursiveStackInserted.first->second;
}
40
/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
/// Delegates to the main constructor with options derived from `ctx`.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
45
46/// Helper function that checks if the given value range is a bare pointer.
47static bool isBarePointer(ValueRange values) {
48 return values.size() == 1 &&
49 isa<LLVM::LLVMPointerType>(Val: values.front().getType());
50}
51
52/// Pack SSA values into an unranked memref descriptor struct.
53static Value packUnrankedMemRefDesc(OpBuilder &builder,
54 UnrankedMemRefType resultType,
55 ValueRange inputs, Location loc,
56 const LLVMTypeConverter &converter) {
57 // Note: Bare pointers are not supported for unranked memrefs because a
58 // memref descriptor cannot be built just from a bare pointer.
59 if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
60 return Value();
61 return UnrankedMemRefDescriptor::pack(builder, loc, converter, type: resultType,
62 values: inputs);
63}
64
65/// Pack SSA values into a ranked memref descriptor struct.
66static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
67 ValueRange inputs, Location loc,
68 const LLVMTypeConverter &converter) {
69 assert(resultType && "expected non-null result type");
70 if (isBarePointer(values: inputs))
71 return MemRefDescriptor::fromStaticShape(builder, loc, typeConverter: converter,
72 type: resultType, memory: inputs[0]);
73 if (TypeRange(inputs) ==
74 converter.getMemRefDescriptorFields(type: resultType,
75 /*unpackAggregates=*/true))
76 return MemRefDescriptor::pack(builder, loc, converter, type: resultType, values: inputs);
77 // The inputs are neither a bare pointer nor an unpacked memref descriptor.
78 // This materialization function cannot be used.
79 return Value();
80}
81
82/// MemRef descriptor elements -> UnrankedMemRefType
83static Value unrankedMemRefMaterialization(OpBuilder &builder,
84 UnrankedMemRefType resultType,
85 ValueRange inputs, Location loc,
86 const LLVMTypeConverter &converter) {
87 // A source materialization must return a value of type
88 // `resultType`, so insert a cast from the memref descriptor type
89 // (!llvm.struct) to the original memref type.
90 Value packed =
91 packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
92 if (!packed)
93 return Value();
94 return builder.create<UnrealizedConversionCastOp>(location: loc, args&: resultType, args&: packed)
95 .getResult(i: 0);
96}
97
98/// MemRef descriptor elements -> MemRefType
99static Value rankedMemRefMaterialization(OpBuilder &builder,
100 MemRefType resultType,
101 ValueRange inputs, Location loc,
102 const LLVMTypeConverter &converter) {
103 // A source materialization must return a value of type `resultType`,
104 // so insert a cast from the memref descriptor type (!llvm.struct) to the
105 // original memref type.
106 Value packed =
107 packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
108 if (!packed)
109 return Value();
110 return builder.create<UnrealizedConversionCastOp>(location: loc, args&: resultType, args&: packed)
111 .getResult(i: 0);
112}
113
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
/// Registers (in order): builtin-type conversions, an LLVM-compatible
/// pass-through, LLVM container-type conversions, and the source/target
/// materializations used to bridge memref descriptors.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const LowerToLLVMOptions &options,
                                     const DataLayoutAnalysis *analysis)
    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
      dataLayoutAnalysis(analysis) {
  assert(llvmDialect && "LLVM IR dialect is not registered");

  // Register conversions for the builtin types.
  addConversion(callback: [&](ComplexType type) { return convertComplexType(type); });
  addConversion(callback: [&](FloatType type) { return convertFloatType(type); });
  addConversion(callback: [&](FunctionType type) { return convertFunctionType(type); });
  addConversion(callback: [&](IndexType type) { return convertIndexType(type); });
  addConversion(callback: [&](IntegerType type) { return convertIntegerType(type); });
  addConversion(callback: [&](MemRefType type) { return convertMemRefType(type); });
  addConversion(
      callback: [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
  addConversion(callback: [&](VectorType type) -> std::optional<Type> {
    FailureOr<Type> llvmType = convertVectorType(type);
    if (failed(Result: llvmType))
      return std::nullopt;
    return llvmType;
  });

  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
  // before the conversions below since conversions are attempted in reverse
  // order and those should take priority.
  addConversion(callback: [](Type type) {
    return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
                                        : std::nullopt;
  });

  // Convert LLVM struct types: identified structs get a "_Converted." name
  // prefix, and a per-thread stack of in-flight types is used to terminate
  // the conversion of recursive struct bodies.
  addConversion(callback: [&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
                    -> std::optional<LogicalResult> {
    // Fastpath for types that won't be converted by this callback anyway.
    if (LLVM::isCompatibleType(type)) {
      results.push_back(Elt: type);
      return success();
    }

    if (type.isIdentified()) {
      auto convertedType = LLVM::LLVMStructType::getIdentified(
          context: type.getContext(), name: ("_Converted." + type.getName()).str());

      // If `type` is already being converted further up the call stack, emit
      // the (possibly not-yet-initialized) converted type to break recursion.
      SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
      if (llvm::count(Range&: recursiveStack, Element: type)) {
        results.push_back(Elt: convertedType);
        return success();
      }
      recursiveStack.push_back(Elt: type);
      // Pop the stack entry on every exit path from this callback.
      auto popConversionCallStack = llvm::make_scope_exit(
          F: [&recursiveStack]() { recursiveStack.pop_back(); });

      SmallVector<Type> convertedElemTypes;
      convertedElemTypes.reserve(N: type.getBody().size());
      if (failed(Result: convertTypes(types: type.getBody(), results&: convertedElemTypes)))
        return std::nullopt;

      // If the converted type has not been initialized yet, just set its body
      // to be the converted arguments and return.
      if (!convertedType.isInitialized()) {
        if (failed(
                Result: convertedType.setBody(types: convertedElemTypes, isPacked: type.isPacked()))) {
          return failure();
        }
        results.push_back(Elt: convertedType);
        return success();
      }

      // If it has been initialized, has the same body and packed bit, just use
      // it. This ensures that recursive structs keep being recursive rather
      // than including a non-updated name.
      if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
          convertedType.isPacked() == type.isPacked()) {
        results.push_back(Elt: convertedType);
        return success();
      }

      // The "_Converted." struct already exists with a different body or
      // packedness: conversion fails.
      return failure();
    }

    // Literal (non-identified) structs are converted element-wise.
    SmallVector<Type> convertedSubtypes;
    convertedSubtypes.reserve(N: type.getBody().size());
    if (failed(Result: convertTypes(types: type.getBody(), results&: convertedSubtypes)))
      return std::nullopt;

    results.push_back(Elt: LLVM::LLVMStructType::getLiteral(
        context: type.getContext(), types: convertedSubtypes, isPacked: type.isPacked()));
    return success();
  });
  // LLVM array and function types are converted element-wise.
  addConversion(callback: [&](LLVM::LLVMArrayType type) -> std::optional<Type> {
    if (auto element = convertType(t: type.getElementType()))
      return LLVM::LLVMArrayType::get(elementType: element, numElements: type.getNumElements());
    return std::nullopt;
  });
  addConversion(callback: [&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
    Type convertedResType = convertType(t: type.getReturnType());
    if (!convertedResType)
      return std::nullopt;

    SmallVector<Type> convertedArgTypes;
    convertedArgTypes.reserve(N: type.getNumParams());
    if (failed(Result: convertTypes(types: type.getParams(), results&: convertedArgTypes)))
      return std::nullopt;

    return LLVM::LLVMFunctionType::get(result: convertedResType, arguments: convertedArgTypes,
                                       isVarArg: type.isVarArg());
  });

  // Add generic source and target materializations to handle cases where
  // non-LLVM types persist after an LLVM conversion.
  addSourceMaterialization(callback: [&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    return builder.create<UnrealizedConversionCastOp>(location: loc, args&: resultType, args&: inputs)
        .getResult(i: 0);
  });
  addTargetMaterialization(callback: [&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    return builder.create<UnrealizedConversionCastOp>(location: loc, args&: resultType, args&: inputs)
        .getResult(i: 0);
  });

  // Source materializations convert from the new block argument types
  // (multiple SSA values that make up a memref descriptor) back to the
  // original block argument type.
  addSourceMaterialization(callback: [&](OpBuilder &builder,
                               UnrankedMemRefType resultType, ValueRange inputs,
                               Location loc) {
    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
                                         converter: *this);
  });
  addSourceMaterialization(callback: [&](OpBuilder &builder, MemRefType resultType,
                               ValueRange inputs, Location loc) {
    return rankedMemRefMaterialization(builder, resultType, inputs, loc, converter: *this);
  });

  // Bare pointer -> Packed MemRef descriptor
  addTargetMaterialization(callback: [&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc,
                               Type originalType) -> Value {
    // The original MemRef type is required to build a MemRef descriptor
    // because the sizes/strides of the MemRef cannot be inferred from just the
    // bare pointer.
    if (!originalType)
      return Value();
    if (resultType != convertType(t: originalType))
      return Value();
    if (auto memrefType = dyn_cast<MemRefType>(Val&: originalType))
      return packRankedMemRefDesc(builder, resultType: memrefType, inputs, loc, converter: *this);
    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(Val&: originalType))
      return packUnrankedMemRefDesc(builder, resultType: unrankedMemrefType, inputs, loc,
                                    converter: *this);
    return Value();
  });

  // Integer memory spaces map to themselves.
  addTypeAttributeConversion(
      callback: [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
}
273
274/// Returns the MLIR context.
275MLIRContext &LLVMTypeConverter::getContext() const {
276 return *getDialect()->getContext();
277}
278
279Type LLVMTypeConverter::getIndexType() const {
280 return IntegerType::get(context: &getContext(), width: getIndexTypeBitwidth());
281}
282
283unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
284 return options.dataLayout.getPointerSizeInBits(AS: addressSpace);
285}
286
287Type LLVMTypeConverter::convertIndexType(IndexType type) const {
288 return getIndexType();
289}
290
291Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
292 return IntegerType::get(context: &getContext(), width: type.getWidth());
293}
294
/// Convert a builtin float type. LLVM-compatible float types pass through
/// unchanged; narrow FP formats without an LLVM equivalent become signless
/// integers of the same bit width; anything else fails.
Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  // Valid LLVM float types are used directly.
  if (LLVM::isCompatibleType(type))
    return type;

  // F4, F6, F8 types are converted to integer types with the same bit width
  // (only the bit pattern is carried over; no FP semantics are attached).
  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
          Float8E8M0FNUType>(Val: type))
    return IntegerType::get(context: &getContext(), width: type.getWidth());

  // Other floating-point types: A custom type conversion rule must be
  // specified by the user.
  return Type();
}
311
312// Convert a `ComplexType` to an LLVM type. The result is a complex number
313// struct with entries for the
314// 1. real part and for the
315// 2. imaginary part.
316Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
317 auto elementType = convertType(t: type.getElementType());
318 return LLVM::LLVMStructType::getLiteral(context: &getContext(),
319 types: {elementType, elementType});
320}
321
322// Except for signatures, MLIR function types are converted into LLVM
323// pointer-to-function types.
324Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
325 return LLVM::LLVMPointerType::get(context: type.getContext());
326}
327
328/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
329/// function arguments. Returns an empty container if none of these attributes
330/// are found in any of the arguments.
331static void
332filterByValRefArgAttrs(FunctionOpInterface funcOp,
333 SmallVectorImpl<std::optional<NamedAttribute>> &result) {
334 assert(result.empty() && "Unexpected non-empty output");
335 result.resize(N: funcOp.getNumArguments(), NV: std::nullopt);
336 bool foundByValByRefAttrs = false;
337 for (int argIdx : llvm::seq(Size: funcOp.getNumArguments())) {
338 for (NamedAttribute namedAttr : funcOp.getArgAttrs(index: argIdx)) {
339 if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
340 namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
341 foundByValByRefAttrs = true;
342 result[argIdx] = namedAttr;
343 break;
344 }
345 }
346 }
347
348 if (!foundByValByRefAttrs)
349 result.clear();
350}
351
// Function types are converted to LLVM Function types by recursively converting
// argument and result types. If MLIR Function has zero results, the LLVM
// Function has one VoidType result. If MLIR Function has more than one result,
// they are packed into an LLVM StructType in their order of appearance.
// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
// `llvm.byref` function arguments which are not LLVM pointers are overridden
// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were already
// converted to LLVM pointer types are removed from 'byValRefNonPtrAttrs`.
Type LLVMTypeConverter::convertFunctionSignatureImpl(
    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result,
    SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
  // Select the argument converter depending on the calling convention. The
  // converter-wide option can force the bare-pointer convention.
  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                             : structFuncArgTypeConverter;
  // Convert argument types one by one and check for errors.
  for (auto [idx, type] : llvm::enumerate(First: funcTy.getInputs())) {
    SmallVector<Type, 8> converted;
    if (failed(Result: funcArgConverter(*this, type, converted)))
      return {};

    // Rewrite converted type of `llvm.byval` or `llvm.byref` function
    // argument that was not converted to an LLVM pointer types.
    // (Only applies when the argument converted to exactly one type.)
    if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
        converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
      // If the argument was already converted to an LLVM pointer type, we stop
      // tracking it as it doesn't need more processing.
      if (isa<LLVM::LLVMPointerType>(Val: converted[0]))
        (*byValRefNonPtrAttrs)[idx] = std::nullopt;
      else
        converted[0] = LLVM::LLVMPointerType::get(context: &getContext());
    }

    result.addInputs(origInputNo: idx, types: converted);
  }

  // If the function does not return anything, create the void result type;
  // if it returns one element, convert it; otherwise pack the result types
  // into a struct.
  Type resultType =
      funcTy.getNumResults() == 0
          ? LLVM::LLVMVoidType::get(ctx: &getContext())
          : packFunctionResults(types: funcTy.getResults(), useBarePointerCallConv: useBarePtrCallConv);
  if (!resultType)
    return {};
  return LLVM::LLVMFunctionType::get(result: resultType, arguments: result.getConvertedTypes(),
                                     isVarArg: isVariadic);
}
401
402Type LLVMTypeConverter::convertFunctionSignature(
403 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
404 LLVMTypeConverter::SignatureConversion &result) const {
405 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
406 result,
407 /*byValRefNonPtrAttrs=*/nullptr);
408}
409
410Type LLVMTypeConverter::convertFunctionSignature(
411 FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
412 LLVMTypeConverter::SignatureConversion &result,
413 SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
414 // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
415 // that were not converted to LLVM pointer types will be returned for further
416 // processing.
417 filterByValRefArgAttrs(funcOp, result&: byValRefNonPtrAttrs);
418 auto funcTy = cast<FunctionType>(Val: funcOp.getFunctionType());
419 return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
420 result, byValRefNonPtrAttrs: &byValRefNonPtrAttrs);
421}
422
/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments. Returns the converted
/// function type together with the struct result type that was demoted to a
/// pointer out-argument (null when no demotion happened).
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
  SmallVector<Type, 4> inputs;

  // Zero results lower to void; otherwise pack the results.
  Type resultType = type.getNumResults() == 0
                        ? LLVM::LLVMVoidType::get(ctx: &getContext())
                        : packFunctionResults(types: type.getResults());
  if (!resultType)
    return {};

  auto ptrType = LLVM::LLVMPointerType::get(context: type.getContext());
  auto structType = dyn_cast<LLVM::LLVMStructType>(Val&: resultType);
  if (structType) {
    // Struct types cannot be safely returned via C interface. Make this a
    // pointer argument, instead. The pointer becomes the first argument.
    inputs.push_back(Elt: ptrType);
    resultType = LLVM::LLVMVoidType::get(ctx: &getContext());
  }

  for (Type t : type.getInputs()) {
    auto converted = convertType(t);
    if (!converted || !LLVM::isCompatibleType(type: converted))
      return {};
    // Memref arguments are passed as a pointer to their descriptor rather
    // than as the descriptor value itself.
    if (isa<MemRefType, UnrankedMemRefType>(Val: t))
      converted = ptrType;
    inputs.push_back(Elt: converted);
  }

  return {LLVM::LLVMFunctionType::get(result: resultType, arguments: inputs), structType};
}
455
/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. The result contains the following types:
///  1. The pointer to the allocated data buffer, followed by
///  2. The pointer to the aligned data buffer, followed by
///  3. A lowered `index`-type integer containing the distance between the
///  beginning of the buffer and the first element to be accessed through the
///  view, followed by
///  4. An array containing as many `index`-type integers as the rank of the
///  MemRef: the array represents the size, in number of elements, of the memref
///  along the given dimension. For constant MemRef dimensions, the
///  corresponding size entry is a constant whose runtime value must match the
///  static value, followed by
///  5. A second array containing as many `index`-type integers as the rank of
///  the MemRef: the second array represents the "stride" (in tensor abstraction
///  sense), i.e. the number of consecutive elements of the underlying buffer.
///  TODO: add assertions for the static cases.
///
///  If `unpackAggregates` is set to true, the arrays described in (4) and (5)
///  are expanded into individual index-type elements.
///
///  template <typename Elem, typename Index, size_t Rank>
///  struct {
///    Elem *allocatedPtr;
///    Elem *alignedPtr;
///    Index offset;
///    Index sizes[Rank]; // omitted when rank == 0
///    Index strides[Rank]; // omitted when rank == 0
///  };
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                             bool unpackAggregates) const {
  // Only strided memrefs have a well-defined descriptor layout.
  if (!type.isStrided()) {
    emitError(
        loc: UnknownLoc::get(context: type.getContext()),
        message: "conversion to strided form failed either due to non-strided layout "
        "maps (which should have been normalized away) or other reasons");
    return {};
  }

  // The element type is converted only to verify it is convertible; the
  // descriptor itself uses opaque pointers.
  Type elementType = convertType(t: type.getElementType());
  if (!elementType)
    return {};

  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(Result: addressSpace)) {
    emitError(loc: UnknownLoc::get(context: type.getContext()),
              message: "conversion of memref memory space ")
        << type.getMemorySpace()
        << " to integer address space "
           "failed. Consider adding memory space conversions.";
    return {};
  }
  auto ptrTy = LLVM::LLVMPointerType::get(context: type.getContext(), addressSpace: *addressSpace);

  auto indexTy = getIndexType();

  // Fields (1)-(3): allocated pointer, aligned pointer, offset.
  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
  auto rank = type.getRank();
  // Rank-0 memrefs omit the sizes/strides arrays entirely.
  if (rank == 0)
    return results;

  // Fields (4)-(5): sizes and strides, either unpacked into 2*rank scalar
  // index fields or kept as two rank-sized arrays.
  if (unpackAggregates)
    results.insert(I: results.end(), NumToInsert: 2 * rank, Elt: indexTy);
  else
    results.insert(I: results.end(), NumToInsert: 2, Elt: LLVM::LLVMArrayType::get(elementType: indexTy, numElements: rank));
  return results;
}
523
524unsigned
525LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
526 const DataLayout &layout) const {
527 // Compute the descriptor size given that of its components indicated above.
528 unsigned space = *getMemRefAddressSpace(type);
529 return 2 * llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8) +
530 (1 + 2 * type.getRank()) * layout.getTypeSize(t: getIndexType());
531}
532
533/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
534/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
535Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
536 // When converting a MemRefType to a struct with descriptor fields, do not
537 // unpack the `sizes` and `strides` arrays.
538 SmallVector<Type, 5> types =
539 getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
540 if (types.empty())
541 return {};
542 return LLVM::LLVMStructType::getLiteral(context: &getContext(), types);
543}
544
545/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
546/// that will form the unranked memref descriptor. In particular, the fields
547/// for an unranked memref descriptor are:
548/// 1. index-typed rank, the dynamic rank of this MemRef
549/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
550/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
551/// be unranked.
552SmallVector<Type, 2>
553LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
554 return {getIndexType(), LLVM::LLVMPointerType::get(context: &getContext())};
555}
556
557unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
558 UnrankedMemRefType type, const DataLayout &layout) const {
559 // Compute the descriptor size given that of its components indicated above.
560 unsigned space = *getMemRefAddressSpace(type);
561 return layout.getTypeSize(t: getIndexType()) +
562 llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8);
563}
564
565Type LLVMTypeConverter::convertUnrankedMemRefType(
566 UnrankedMemRefType type) const {
567 if (!convertType(t: type.getElementType()))
568 return {};
569 return LLVM::LLVMStructType::getLiteral(context: &getContext(),
570 types: getUnrankedMemRefDescriptorFields());
571}
572
/// Returns the integer address space of the given memref type, or failure
/// when the memory space attribute cannot be mapped to an integer.
FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
  if (!type.getMemorySpace()) // Default memory space -> 0.
    return 0;
  // Run the registered type-attribute conversions (integer spaces map to
  // themselves; see the constructor).
  std::optional<Attribute> converted =
      convertTypeAttribute(type, attr: type.getMemorySpace());
  if (!converted)
    return failure();
  if (!(*converted)) // Conversion to default is 0.
    return 0;
  // Only index or signless-integer attributes denote a valid address space.
  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(Val&: *converted)) {
    if (explicitSpace.getType().isIndex() ||
        explicitSpace.getType().isSignlessInteger())
      return explicitSpace.getInt();
  }
  return failure();
}
590
591// Check if a memref type can be converted to a bare pointer.
592bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
593 if (isa<UnrankedMemRefType>(Val: type))
594 // Unranked memref is not supported in the bare pointer calling convention.
595 return false;
596
597 // Check that the memref has static shape, strides and offset. Otherwise, it
598 // cannot be lowered to a bare pointer.
599 auto memrefTy = cast<MemRefType>(Val&: type);
600 if (!memrefTy.hasStaticShape())
601 return false;
602
603 int64_t offset = 0;
604 SmallVector<int64_t, 4> strides;
605 if (failed(Result: memrefTy.getStridesAndOffset(strides, offset)))
606 return false;
607
608 for (int64_t stride : strides)
609 if (ShapedType::isDynamic(dValue: stride))
610 return false;
611
612 return ShapedType::isStatic(dValue: offset);
613}
614
615/// Convert a memref type to a bare pointer to the memref element type.
616Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
617 if (!canConvertToBarePtr(type))
618 return {};
619 Type elementType = convertType(t: type.getElementType());
620 if (!elementType)
621 return {};
622 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
623 if (failed(Result: addressSpace))
624 return {};
625 return LLVM::LLVMPointerType::get(context: type.getContext(), addressSpace: *addressSpace);
626}
627
/// Convert an n-D vector type to an LLVM vector type:
/// * 0-D `vector<T>` are converted to vector<1xT>
/// * 1-D `vector<axT>` remains as is while,
/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
///   `!llvm.array<ax...array<jxvector<kxT>>>`.
/// As LLVM supports arrays of scalable vectors, this method will also convert
/// n-D scalable vectors provided that only the trailing dim is scalable.
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
  auto elementType = convertType(t: type.getElementType());
  if (!elementType)
    return {};
  // 0-D vectors are lowered to single-element 1-D vectors.
  if (type.getShape().empty())
    return VectorType::get(shape: {1}, elementType);
  // Build the innermost vector from the trailing dimension, which is the only
  // dimension allowed to be scalable.
  Type vectorType = VectorType::get(shape: type.getShape().back(), elementType,
                                    scalableDims: type.getScalableDims().back());
  assert(LLVM::isCompatibleVectorType(vectorType) &&
         "expected vector type compatible with the LLVM dialect");
  // For n-D vector types for which a _non-trailing_ dim is scalable,
  // return a failure. Supporting such cases would require LLVM
  // to support something akin "scalable arrays" of vectors.
  if (llvm::is_contained(Range: type.getScalableDims().drop_back(), Element: true))
    return failure();
  // Wrap the vector in array types for the leading dimensions, innermost
  // dimension last.
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMArrayType::get(elementType: vectorType, numElements: shape[i]);
  return vectorType;
}
655
656/// Convert a type in the context of the default or bare pointer calling
657/// convention. Calling convention sensitive types, such as MemRefType and
658/// UnrankedMemRefType, are converted following the specific rules for the
659/// calling convention. Calling convention independent types are converted
660/// following the default LLVM type conversions.
661Type LLVMTypeConverter::convertCallingConventionType(
662 Type type, bool useBarePtrCallConv) const {
663 if (useBarePtrCallConv)
664 if (auto memrefTy = dyn_cast<BaseMemRefType>(Val&: type))
665 return convertMemRefToBarePtr(type: memrefTy);
666
667 return convertType(t: type);
668}
669
670/// Promote the bare pointers in 'values' that resulted from memrefs to
671/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
672/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
673void LLVMTypeConverter::promoteBarePtrsToDescriptors(
674 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
675 SmallVectorImpl<Value> &values) const {
676 assert(stdTypes.size() == values.size() &&
677 "The number of types and values doesn't match");
678 for (unsigned i = 0, end = values.size(); i < end; ++i)
679 if (auto memrefTy = dyn_cast<MemRefType>(Val: stdTypes[i]))
680 values[i] = MemRefDescriptor::fromStaticShape(builder&: rewriter, loc, typeConverter: *this,
681 type: memrefTy, memory: values[i]);
682}
683
684/// Convert a non-empty list of types of values produced by an operation into an
685/// LLVM-compatible type. In particular, if more than one value is
686/// produced, create a literal structure with elements that correspond to each
687/// of the types converted with `convertType`.
688Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
689 assert(!types.empty() && "expected non-empty list of type");
690 if (types.size() == 1)
691 return convertType(t: types[0]);
692
693 SmallVector<Type> resultTypes;
694 resultTypes.reserve(N: types.size());
695 for (Type type : types) {
696 Type converted = convertType(t: type);
697 if (!converted || !LLVM::isCompatibleType(type: converted))
698 return {};
699 resultTypes.push_back(Elt: converted);
700 }
701
702 return LLVM::LLVMStructType::getLiteral(context: &getContext(), types: resultTypes);
703}
704
705/// Convert a non-empty list of types to be returned from a function into an
706/// LLVM-compatible type. In particular, if more than one value is returned,
707/// create an LLVM dialect structure type with elements that correspond to each
708/// of the types converted with `convertCallingConventionType`.
709Type LLVMTypeConverter::packFunctionResults(TypeRange types,
710 bool useBarePtrCallConv) const {
711 assert(!types.empty() && "expected non-empty list of type");
712
713 useBarePtrCallConv |= options.useBarePtrCallConv;
714 if (types.size() == 1)
715 return convertCallingConventionType(type: types.front(), useBarePtrCallConv);
716
717 SmallVector<Type> resultTypes;
718 resultTypes.reserve(N: types.size());
719 for (auto t : types) {
720 auto converted = convertCallingConventionType(type: t, useBarePtrCallConv);
721 if (!converted || !LLVM::isCompatibleType(type: converted))
722 return {};
723 resultTypes.push_back(Elt: converted);
724 }
725
726 return LLVM::LLVMStructType::getLiteral(context: &getContext(), types: resultTypes);
727}
728
729Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
730 OpBuilder &builder) const {
731 // Alloca with proper alignment. We do not expect optimizations of this
732 // alloca op and so we omit allocating at the entry block.
733 auto ptrType = LLVM::LLVMPointerType::get(context: builder.getContext());
734 Value one = builder.create<LLVM::ConstantOp>(location: loc, args: builder.getI64Type(),
735 args: builder.getIndexAttr(value: 1));
736 Value allocated =
737 builder.create<LLVM::AllocaOp>(location: loc, args&: ptrType, args: operand.getType(), args&: one);
738 // Store into the alloca'ed descriptor.
739 builder.create<LLVM::StoreOp>(location: loc, args&: operand, args&: allocated);
740 return allocated;
741}
742
/// Promotes the converted operands of a call-like operation: with the
/// bare-pointer convention, a ranked memref operand is replaced by its
/// aligned pointer; otherwise, ranked and unranked memref descriptors are
/// unpacked into their individual fields. All other operands pass through
/// unchanged. `opOperands` are the pre-conversion operands (used for their
/// types) and `operands` are the corresponding converted values.
SmallVector<Value, 4>
LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
                                   ValueRange operands, OpBuilder &builder,
                                   bool useBarePtrCallConv) const {
  SmallVector<Value, 4> promotedOperands;
  promotedOperands.reserve(N: operands.size());
  // The converter-wide option can force the bare-pointer convention.
  useBarePtrCallConv |= options.useBarePtrCallConv;
  for (auto it : llvm::zip(t&: opOperands, u&: operands)) {
    auto operand = std::get<0>(t&: it);     // pre-conversion operand
    auto llvmOperand = std::get<1>(t&: it); // converted operand

    if (useBarePtrCallConv) {
      // For the bare-ptr calling convention, we only have to extract the
      // aligned pointer of a memref.
      if (isa<MemRefType>(Val: operand.getType())) {
        MemRefDescriptor desc(llvmOperand);
        llvmOperand = desc.alignedPtr(builder, loc);
      } else if (isa<UnrankedMemRefType>(Val: operand.getType())) {
        llvm_unreachable("Unranked memrefs are not supported");
      }
    } else {
      // Default convention: expand descriptors into their fields in place.
      if (isa<UnrankedMemRefType>(Val: operand.getType())) {
        UnrankedMemRefDescriptor::unpack(builder, loc, packed: llvmOperand,
                                         results&: promotedOperands);
        continue;
      }
      if (auto memrefType = dyn_cast<MemRefType>(Val: operand.getType())) {
        MemRefDescriptor::unpack(builder, loc, packed: llvmOperand, type: memrefType,
                                 results&: promotedOperands);
        continue;
      }
    }

    promotedOperands.push_back(Elt: llvmOperand);
  }
  return promotedOperands;
}
780
781/// Callback to convert function argument types. It converts a MemRef function
782/// argument to a list of non-aggregate types containing descriptor
783/// information, and an UnrankedmemRef function argument to a list containing
784/// the rank and a pointer to a descriptor struct.
785LogicalResult
786mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
787 SmallVectorImpl<Type> &result) {
788 if (auto memref = dyn_cast<MemRefType>(Val&: type)) {
789 // In signatures, Memref descriptors are expanded into lists of
790 // non-aggregate values.
791 auto converted =
792 converter.getMemRefDescriptorFields(type: memref, /*unpackAggregates=*/true);
793 if (converted.empty())
794 return failure();
795 result.append(in_start: converted.begin(), in_end: converted.end());
796 return success();
797 }
798 if (isa<UnrankedMemRefType>(Val: type)) {
799 auto converted = converter.getUnrankedMemRefDescriptorFields();
800 if (converted.empty())
801 return failure();
802 result.append(in_start: converted.begin(), in_end: converted.end());
803 return success();
804 }
805 auto converted = converter.convertType(t: type);
806 if (!converted)
807 return failure();
808 result.push_back(Elt: converted);
809 return success();
810}
811
812/// Callback to convert function argument types. It converts MemRef function
813/// arguments to bare pointers to the MemRef element type.
814LogicalResult
815mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
816 SmallVectorImpl<Type> &result) {
817 auto llvmTy = converter.convertCallingConventionType(
818 type, /*useBarePointerCallConv=*/useBarePtrCallConv: true);
819 if (!llvmTy)
820 return failure();
821
822 result.push_back(Elt: llvmTy);
823 return success();
824}
825

// source code of mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp