1//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
10#include "MemRefDescriptor.h"
11#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
12#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "llvm/ADT/ScopeExit.h"
15#include "llvm/Support/Threading.h"
16#include <memory>
17#include <mutex>
18#include <optional>
19
20using namespace mlir;
21
22SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
23 {
24 // Most of the time, the entry already exists in the map.
25 std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
26 std::defer_lock);
27 if (getContext().isMultithreadingEnabled())
28 lock.lock();
29 auto recursiveStack = conversionCallStack.find(Val: llvm::get_threadid());
30 if (recursiveStack != conversionCallStack.end())
31 return *recursiveStack->second;
32 }
33
34 // First time this thread gets here, we have to get an exclusive access to
35 // inset in the map
36 std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
37 auto recursiveStackInserted = conversionCallStack.insert(KV: std::make_pair(
38 x: llvm::get_threadid(), y: std::make_unique<SmallVector<Type>>()));
39 return *recursiveStackInserted.first->second;
40}
41
42/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
43LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
44 const DataLayoutAnalysis *analysis)
45 : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}
46
47/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
48LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
49 const LowerToLLVMOptions &options,
50 const DataLayoutAnalysis *analysis)
51 : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
52 dataLayoutAnalysis(analysis) {
53 assert(llvmDialect && "LLVM IR dialect is not registered");
54
55 // Register conversions for the builtin types.
56 addConversion(callback: [&](ComplexType type) { return convertComplexType(type); });
57 addConversion(callback: [&](FloatType type) { return convertFloatType(type); });
58 addConversion([&](FunctionType type) { return convertFunctionType(type); });
59 addConversion([&](IndexType type) { return convertIndexType(type); });
60 addConversion([&](IntegerType type) { return convertIntegerType(type); });
61 addConversion([&](MemRefType type) { return convertMemRefType(type); });
62 addConversion(
63 [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
64 addConversion(callback: [&](VectorType type) -> std::optional<Type> {
65 FailureOr<Type> llvmType = convertVectorType(type: type);
66 if (failed(result: llvmType))
67 return std::nullopt;
68 return llvmType;
69 });
70
71 // LLVM-compatible types are legal, so add a pass-through conversion. Do this
72 // before the conversions below since conversions are attempted in reverse
73 // order and those should take priority.
74 addConversion(callback: [](Type type) {
75 return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
76 : std::nullopt;
77 });
78
79 addConversion(callback: [&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
80 -> std::optional<LogicalResult> {
81 // Fastpath for types that won't be converted by this callback anyway.
82 if (LLVM::isCompatibleType(type)) {
83 results.push_back(type);
84 return success();
85 }
86
87 if (type.isIdentified()) {
88 auto convertedType = LLVM::LLVMStructType::getIdentified(
89 context: type.getContext(), name: ("_Converted." + type.getName()).str());
90
91 SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
92 if (llvm::count(Range&: recursiveStack, Element: type)) {
93 results.push_back(Elt: convertedType);
94 return success();
95 }
96 recursiveStack.push_back(type);
97 auto popConversionCallStack = llvm::make_scope_exit(
98 F: [&recursiveStack]() { recursiveStack.pop_back(); });
99
100 SmallVector<Type> convertedElemTypes;
101 convertedElemTypes.reserve(N: type.getBody().size());
102 if (failed(result: convertTypes(types: type.getBody(), results&: convertedElemTypes)))
103 return std::nullopt;
104
105 // If the converted type has not been initialized yet, just set its body
106 // to be the converted arguments and return.
107 if (!convertedType.isInitialized()) {
108 if (failed(
109 convertedType.setBody(convertedElemTypes, type.isPacked()))) {
110 return failure();
111 }
112 results.push_back(Elt: convertedType);
113 return success();
114 }
115
116 // If it has been initialized, has the same body and packed bit, just use
117 // it. This ensures that recursive structs keep being recursive rather
118 // than including a non-updated name.
119 if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
120 convertedType.isPacked() == type.isPacked()) {
121 results.push_back(Elt: convertedType);
122 return success();
123 }
124
125 return failure();
126 }
127
128 SmallVector<Type> convertedSubtypes;
129 convertedSubtypes.reserve(N: type.getBody().size());
130 if (failed(result: convertTypes(types: type.getBody(), results&: convertedSubtypes)))
131 return std::nullopt;
132
133 results.push_back(Elt: LLVM::LLVMStructType::getLiteral(
134 context: type.getContext(), types: convertedSubtypes, isPacked: type.isPacked()));
135 return success();
136 });
137 addConversion(callback: [&](LLVM::LLVMArrayType type) -> std::optional<Type> {
138 if (auto element = convertType(type.getElementType()))
139 return LLVM::LLVMArrayType::get(element, type.getNumElements());
140 return std::nullopt;
141 });
142 addConversion(callback: [&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
143 Type convertedResType = convertType(type.getReturnType());
144 if (!convertedResType)
145 return std::nullopt;
146
147 SmallVector<Type> convertedArgTypes;
148 convertedArgTypes.reserve(N: type.getNumParams());
149 if (failed(convertTypes(types: type.getParams(), results&: convertedArgTypes)))
150 return std::nullopt;
151
152 return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
153 type.isVarArg());
154 });
155
156 // Materialization for memrefs creates descriptor structs from individual
157 // values constituting them, when descriptors are used, i.e. more than one
158 // value represents a memref.
159 addArgumentMaterialization(
160 callback: [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
161 Location loc) -> std::optional<Value> {
162 if (inputs.size() == 1)
163 return std::nullopt;
164 return UnrankedMemRefDescriptor::pack(builder, loc, converter: *this, type: resultType,
165 values: inputs);
166 });
167 addArgumentMaterialization(callback: [&](OpBuilder &builder, MemRefType resultType,
168 ValueRange inputs,
169 Location loc) -> std::optional<Value> {
170 // TODO: bare ptr conversion could be handled here but we would need a way
171 // to distinguish between FuncOp and other regions.
172 if (inputs.size() == 1)
173 return std::nullopt;
174 return MemRefDescriptor::pack(builder, loc, converter: *this, type: resultType, values: inputs);
175 });
176 // Add generic source and target materializations to handle cases where
177 // non-LLVM types persist after an LLVM conversion.
178 addSourceMaterialization(callback: [&](OpBuilder &builder, Type resultType,
179 ValueRange inputs,
180 Location loc) -> std::optional<Value> {
181 if (inputs.size() != 1)
182 return std::nullopt;
183
184 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
185 .getResult(0);
186 });
187 addTargetMaterialization(callback: [&](OpBuilder &builder, Type resultType,
188 ValueRange inputs,
189 Location loc) -> std::optional<Value> {
190 if (inputs.size() != 1)
191 return std::nullopt;
192
193 return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
194 .getResult(0);
195 });
196
197 // Integer memory spaces map to themselves.
198 addTypeAttributeConversion(
199 [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
200}
201
202/// Returns the MLIR context.
203MLIRContext &LLVMTypeConverter::getContext() const {
204 return *getDialect()->getContext();
205}
206
207Type LLVMTypeConverter::getIndexType() const {
208 return IntegerType::get(&getContext(), getIndexTypeBitwidth());
209}
210
211unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
212 return options.dataLayout.getPointerSizeInBits(AS: addressSpace);
213}
214
215Type LLVMTypeConverter::convertIndexType(IndexType type) const {
216 return getIndexType();
217}
218
219Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
220 return IntegerType::get(&getContext(), type.getWidth());
221}
222
223Type LLVMTypeConverter::convertFloatType(FloatType type) const {
224 if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
225 type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
226 return IntegerType::get(&getContext(), type.getWidth());
227 return type;
228}
229
230// Convert a `ComplexType` to an LLVM type. The result is a complex number
231// struct with entries for the
232// 1. real part and for the
233// 2. imaginary part.
234Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
235 auto elementType = convertType(type.getElementType());
236 return LLVM::LLVMStructType::getLiteral(context: &getContext(),
237 types: {elementType, elementType});
238}
239
240// Except for signatures, MLIR function types are converted into LLVM
241// pointer-to-function types.
242Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
243 return LLVM::LLVMPointerType::get(type.getContext());
244}
245
246// Function types are converted to LLVM Function types by recursively converting
247// argument and result types. If MLIR Function has zero results, the LLVM
248// Function has one VoidType result. If MLIR Function has more than one result,
249// they are into an LLVM StructType in their order of appearance.
250Type LLVMTypeConverter::convertFunctionSignature(
251 FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
252 LLVMTypeConverter::SignatureConversion &result) const {
253 // Select the argument converter depending on the calling convention.
254 useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
255 auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
256 : structFuncArgTypeConverter;
257 // Convert argument types one by one and check for errors.
258 for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
259 SmallVector<Type, 8> converted;
260 if (failed(funcArgConverter(*this, type, converted)))
261 return {};
262 result.addInputs(idx, converted);
263 }
264
265 // If function does not return anything, create the void result type,
266 // if it returns on element, convert it, otherwise pack the result types into
267 // a struct.
268 Type resultType =
269 funcTy.getNumResults() == 0
270 ? LLVM::LLVMVoidType::get(ctx: &getContext())
271 : packFunctionResults(types: funcTy.getResults(), useBarePointerCallConv: useBarePtrCallConv);
272 if (!resultType)
273 return {};
274 return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
275 isVariadic);
276}
277
278/// Converts the function type to a C-compatible format, in particular using
279/// pointers to memref descriptors for arguments.
280std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
281LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
282 SmallVector<Type, 4> inputs;
283
284 Type resultType = type.getNumResults() == 0
285 ? LLVM::LLVMVoidType::get(ctx: &getContext())
286 : packFunctionResults(types: type.getResults());
287 if (!resultType)
288 return {};
289
290 auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
291 auto structType = dyn_cast<LLVM::LLVMStructType>(Val&: resultType);
292 if (structType) {
293 // Struct types cannot be safely returned via C interface. Make this a
294 // pointer argument, instead.
295 inputs.push_back(Elt: ptrType);
296 resultType = LLVM::LLVMVoidType::get(ctx: &getContext());
297 }
298
299 for (Type t : type.getInputs()) {
300 auto converted = convertType(t);
301 if (!converted || !LLVM::isCompatibleType(converted))
302 return {};
303 if (isa<MemRefType, UnrankedMemRefType>(t))
304 converted = ptrType;
305 inputs.push_back(converted);
306 }
307
308 return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
309}
310
311/// Convert a memref type into a list of LLVM IR types that will form the
312/// memref descriptor. The result contains the following types:
313/// 1. The pointer to the allocated data buffer, followed by
314/// 2. The pointer to the aligned data buffer, followed by
315/// 3. A lowered `index`-type integer containing the distance between the
316/// beginning of the buffer and the first element to be accessed through the
317/// view, followed by
318/// 4. An array containing as many `index`-type integers as the rank of the
319/// MemRef: the array represents the size, in number of elements, of the memref
320/// along the given dimension. For constant MemRef dimensions, the
321/// corresponding size entry is a constant whose runtime value must match the
322/// static value, followed by
323/// 5. A second array containing as many `index`-type integers as the rank of
324/// the MemRef: the second array represents the "stride" (in tensor abstraction
325/// sense), i.e. the number of consecutive elements of the underlying buffer.
326/// TODO: add assertions for the static cases.
327///
328/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
329/// are expanded into individual index-type elements.
330///
331/// template <typename Elem, typename Index, size_t Rank>
332/// struct {
333/// Elem *allocatedPtr;
334/// Elem *alignedPtr;
335/// Index offset;
336/// Index sizes[Rank]; // omitted when rank == 0
337/// Index strides[Rank]; // omitted when rank == 0
338/// };
339SmallVector<Type, 5>
340LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
341 bool unpackAggregates) const {
342 if (!isStrided(type)) {
343 emitError(
344 UnknownLoc::get(type.getContext()),
345 "conversion to strided form failed either due to non-strided layout "
346 "maps (which should have been normalized away) or other reasons");
347 return {};
348 }
349
350 Type elementType = convertType(type.getElementType());
351 if (!elementType)
352 return {};
353
354 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type: type);
355 if (failed(result: addressSpace)) {
356 emitError(UnknownLoc::get(type.getContext()),
357 "conversion of memref memory space ")
358 << type.getMemorySpace()
359 << " to integer address space "
360 "failed. Consider adding memory space conversions.";
361 return {};
362 }
363 auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
364
365 auto indexTy = getIndexType();
366
367 SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
368 auto rank = type.getRank();
369 if (rank == 0)
370 return results;
371
372 if (unpackAggregates)
373 results.insert(results.end(), 2 * rank, indexTy);
374 else
375 results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
376 return results;
377}
378
379unsigned
380LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
381 const DataLayout &layout) const {
382 // Compute the descriptor size given that of its components indicated above.
383 unsigned space = *getMemRefAddressSpace(type: type);
384 return 2 * llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8) +
385 (1 + 2 * type.getRank()) * layout.getTypeSize(t: getIndexType());
386}
387
388/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
389/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
390Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
391 // When converting a MemRefType to a struct with descriptor fields, do not
392 // unpack the `sizes` and `strides` arrays.
393 SmallVector<Type, 5> types =
394 getMemRefDescriptorFields(type: type, /*unpackAggregates=*/false);
395 if (types.empty())
396 return {};
397 return LLVM::LLVMStructType::getLiteral(context: &getContext(), types);
398}
399
400/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
401/// that will form the unranked memref descriptor. In particular, the fields
402/// for an unranked memref descriptor are:
403/// 1. index-typed rank, the dynamic rank of this MemRef
404/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
405/// stack allocated (alloca) copy of a MemRef descriptor that got casted to
406/// be unranked.
407SmallVector<Type, 2>
408LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
409 return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
410}
411
412unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
413 UnrankedMemRefType type, const DataLayout &layout) const {
414 // Compute the descriptor size given that of its components indicated above.
415 unsigned space = *getMemRefAddressSpace(type: type);
416 return layout.getTypeSize(t: getIndexType()) +
417 llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8);
418}
419
420Type LLVMTypeConverter::convertUnrankedMemRefType(
421 UnrankedMemRefType type) const {
422 if (!convertType(type.getElementType()))
423 return {};
424 return LLVM::LLVMStructType::getLiteral(context: &getContext(),
425 types: getUnrankedMemRefDescriptorFields());
426}
427
428FailureOr<unsigned>
429LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
430 if (!type.getMemorySpace()) // Default memory space -> 0.
431 return 0;
432 std::optional<Attribute> converted =
433 convertTypeAttribute(type, attr: type.getMemorySpace());
434 if (!converted)
435 return failure();
436 if (!(*converted)) // Conversion to default is 0.
437 return 0;
438 if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted))
439 return explicitSpace.getInt();
440 return failure();
441}
442
443// Check if a memref type can be converted to a bare pointer.
444bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
445 if (isa<UnrankedMemRefType>(Val: type))
446 // Unranked memref is not supported in the bare pointer calling convention.
447 return false;
448
449 // Check that the memref has static shape, strides and offset. Otherwise, it
450 // cannot be lowered to a bare pointer.
451 auto memrefTy = cast<MemRefType>(type);
452 if (!memrefTy.hasStaticShape())
453 return false;
454
455 int64_t offset = 0;
456 SmallVector<int64_t, 4> strides;
457 if (failed(getStridesAndOffset(memrefTy, strides, offset)))
458 return false;
459
460 for (int64_t stride : strides)
461 if (ShapedType::isDynamic(stride))
462 return false;
463
464 return !ShapedType::isDynamic(offset);
465}
466
467/// Convert a memref type to a bare pointer to the memref element type.
468Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
469 if (!canConvertToBarePtr(type))
470 return {};
471 Type elementType = convertType(t: type.getElementType());
472 if (!elementType)
473 return {};
474 FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
475 if (failed(result: addressSpace))
476 return {};
477 return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
478}
479
480/// Convert an n-D vector type to an LLVM vector type:
481/// * 0-D `vector<T>` are converted to vector<1xT>
482/// * 1-D `vector<axT>` remains as is while,
483/// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to
484/// `!llvm.array<ax...array<jxvector<kxT>>>`.
485/// Returns failure for n-D scalable vector types as LLVM does not support
486/// arrays of scalable vectors.
487FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
488 auto elementType = convertType(type.getElementType());
489 if (!elementType)
490 return {};
491 if (type.getShape().empty())
492 return VectorType::get({1}, elementType);
493 Type vectorType = VectorType::get(type.getShape().back(), elementType,
494 type.getScalableDims().back());
495 assert(LLVM::isCompatibleVectorType(vectorType) &&
496 "expected vector type compatible with the LLVM dialect");
497 // Only the trailing dimension can be scalable.
498 if (llvm::is_contained(type.getScalableDims().drop_back(), true))
499 return failure();
500 auto shape = type.getShape();
501 for (int i = shape.size() - 2; i >= 0; --i)
502 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
503 return vectorType;
504}
505
506/// Convert a type in the context of the default or bare pointer calling
507/// convention. Calling convention sensitive types, such as MemRefType and
508/// UnrankedMemRefType, are converted following the specific rules for the
509/// calling convention. Calling convention independent types are converted
510/// following the default LLVM type conversions.
511Type LLVMTypeConverter::convertCallingConventionType(
512 Type type, bool useBarePtrCallConv) const {
513 if (useBarePtrCallConv)
514 if (auto memrefTy = dyn_cast<BaseMemRefType>(Val&: type))
515 return convertMemRefToBarePtr(type: memrefTy);
516
517 return convertType(t: type);
518}
519
520/// Promote the bare pointers in 'values' that resulted from memrefs to
521/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
522/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
523void LLVMTypeConverter::promoteBarePtrsToDescriptors(
524 ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
525 SmallVectorImpl<Value> &values) const {
526 assert(stdTypes.size() == values.size() &&
527 "The number of types and values doesn't match");
528 for (unsigned i = 0, end = values.size(); i < end; ++i)
529 if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
530 values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
531 memrefTy, values[i]);
532}
533
534/// Convert a non-empty list of types of values produced by an operation into an
535/// LLVM-compatible type. In particular, if more than one value is
536/// produced, create a literal structure with elements that correspond to each
537/// of the types converted with `convertType`.
538Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
539 assert(!types.empty() && "expected non-empty list of type");
540 if (types.size() == 1)
541 return convertType(t: types[0]);
542
543 SmallVector<Type> resultTypes;
544 resultTypes.reserve(N: types.size());
545 for (Type type : types) {
546 Type converted = convertType(t: type);
547 if (!converted || !LLVM::isCompatibleType(type: converted))
548 return {};
549 resultTypes.push_back(Elt: converted);
550 }
551
552 return LLVM::LLVMStructType::getLiteral(context: &getContext(), types: resultTypes);
553}
554
555/// Convert a non-empty list of types to be returned from a function into an
556/// LLVM-compatible type. In particular, if more than one value is returned,
557/// create an LLVM dialect structure type with elements that correspond to each
558/// of the types converted with `convertCallingConventionType`.
559Type LLVMTypeConverter::packFunctionResults(TypeRange types,
560 bool useBarePtrCallConv) const {
561 assert(!types.empty() && "expected non-empty list of type");
562
563 useBarePtrCallConv |= options.useBarePtrCallConv;
564 if (types.size() == 1)
565 return convertCallingConventionType(type: types.front(), useBarePtrCallConv);
566
567 SmallVector<Type> resultTypes;
568 resultTypes.reserve(N: types.size());
569 for (auto t : types) {
570 auto converted = convertCallingConventionType(type: t, useBarePtrCallConv);
571 if (!converted || !LLVM::isCompatibleType(type: converted))
572 return {};
573 resultTypes.push_back(Elt: converted);
574 }
575
576 return LLVM::LLVMStructType::getLiteral(context: &getContext(), types: resultTypes);
577}
578
579Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
580 OpBuilder &builder) const {
581 // Alloca with proper alignment. We do not expect optimizations of this
582 // alloca op and so we omit allocating at the entry block.
583 auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
584 Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
585 builder.getIndexAttr(1));
586 Value allocated =
587 builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
588 // Store into the alloca'ed descriptor.
589 builder.create<LLVM::StoreOp>(loc, operand, allocated);
590 return allocated;
591}
592
593SmallVector<Value, 4>
594LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
595 ValueRange operands, OpBuilder &builder,
596 bool useBarePtrCallConv) const {
597 SmallVector<Value, 4> promotedOperands;
598 promotedOperands.reserve(N: operands.size());
599 useBarePtrCallConv |= options.useBarePtrCallConv;
600 for (auto it : llvm::zip(t&: opOperands, u&: operands)) {
601 auto operand = std::get<0>(t&: it);
602 auto llvmOperand = std::get<1>(t&: it);
603
604 if (useBarePtrCallConv) {
605 // For the bare-ptr calling convention, we only have to extract the
606 // aligned pointer of a memref.
607 if (dyn_cast<MemRefType>(operand.getType())) {
608 MemRefDescriptor desc(llvmOperand);
609 llvmOperand = desc.alignedPtr(builder, loc);
610 } else if (isa<UnrankedMemRefType>(Val: operand.getType())) {
611 llvm_unreachable("Unranked memrefs are not supported");
612 }
613 } else {
614 if (isa<UnrankedMemRefType>(Val: operand.getType())) {
615 UnrankedMemRefDescriptor::unpack(builder, loc, packed: llvmOperand,
616 results&: promotedOperands);
617 continue;
618 }
619 if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
620 MemRefDescriptor::unpack(builder, loc, packed: llvmOperand, type: memrefType,
621 results&: promotedOperands);
622 continue;
623 }
624 }
625
626 promotedOperands.push_back(Elt: llvmOperand);
627 }
628 return promotedOperands;
629}
630
631/// Callback to convert function argument types. It converts a MemRef function
632/// argument to a list of non-aggregate types containing descriptor
633/// information, and an UnrankedmemRef function argument to a list containing
634/// the rank and a pointer to a descriptor struct.
635LogicalResult
636mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
637 SmallVectorImpl<Type> &result) {
638 if (auto memref = dyn_cast<MemRefType>(type)) {
639 // In signatures, Memref descriptors are expanded into lists of
640 // non-aggregate values.
641 auto converted =
642 converter.getMemRefDescriptorFields(type: memref, /*unpackAggregates=*/true);
643 if (converted.empty())
644 return failure();
645 result.append(converted.begin(), converted.end());
646 return success();
647 }
648 if (isa<UnrankedMemRefType>(Val: type)) {
649 auto converted = converter.getUnrankedMemRefDescriptorFields();
650 if (converted.empty())
651 return failure();
652 result.append(in_start: converted.begin(), in_end: converted.end());
653 return success();
654 }
655 auto converted = converter.convertType(t: type);
656 if (!converted)
657 return failure();
658 result.push_back(Elt: converted);
659 return success();
660}
661
662/// Callback to convert function argument types. It converts MemRef function
663/// arguments to bare pointers to the MemRef element type.
664LogicalResult
665mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
666 SmallVectorImpl<Type> &result) {
667 auto llvmTy = converter.convertCallingConventionType(
668 type, /*useBarePointerCallConv=*/useBarePtrCallConv: true);
669 if (!llvmTy)
670 return failure();
671
672 result.push_back(Elt: llvmTy);
673 return success();
674}
675

source code of mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp