1 | //===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
10 | #include "MemRefDescriptor.h" |
11 | #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" |
12 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
13 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
14 | #include "llvm/ADT/ScopeExit.h" |
15 | #include "llvm/Support/Threading.h" |
16 | #include <memory> |
17 | #include <mutex> |
18 | #include <optional> |
19 | |
20 | using namespace mlir; |
21 | |
22 | SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() { |
23 | { |
24 | // Most of the time, the entry already exists in the map. |
25 | std::shared_lock<decltype(callStackMutex)> lock(callStackMutex, |
26 | std::defer_lock); |
27 | if (getContext().isMultithreadingEnabled()) |
28 | lock.lock(); |
29 | auto recursiveStack = conversionCallStack.find(Val: llvm::get_threadid()); |
30 | if (recursiveStack != conversionCallStack.end()) |
31 | return *recursiveStack->second; |
32 | } |
33 | |
34 | // First time this thread gets here, we have to get an exclusive access to |
35 | // inset in the map |
36 | std::unique_lock<decltype(callStackMutex)> lock(callStackMutex); |
37 | auto recursiveStackInserted = conversionCallStack.insert(KV: std::make_pair( |
38 | x: llvm::get_threadid(), y: std::make_unique<SmallVector<Type>>())); |
39 | return *recursiveStackInserted.first->second; |
40 | } |
41 | |
42 | /// Create an LLVMTypeConverter using default LowerToLLVMOptions. |
43 | LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, |
44 | const DataLayoutAnalysis *analysis) |
45 | : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {} |
46 | |
47 | /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. |
48 | LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, |
49 | const LowerToLLVMOptions &options, |
50 | const DataLayoutAnalysis *analysis) |
51 | : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options), |
52 | dataLayoutAnalysis(analysis) { |
53 | assert(llvmDialect && "LLVM IR dialect is not registered" ); |
54 | |
55 | // Register conversions for the builtin types. |
56 | addConversion(callback: [&](ComplexType type) { return convertComplexType(type); }); |
57 | addConversion(callback: [&](FloatType type) { return convertFloatType(type); }); |
58 | addConversion([&](FunctionType type) { return convertFunctionType(type); }); |
59 | addConversion([&](IndexType type) { return convertIndexType(type); }); |
60 | addConversion([&](IntegerType type) { return convertIntegerType(type); }); |
61 | addConversion([&](MemRefType type) { return convertMemRefType(type); }); |
62 | addConversion( |
63 | [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); |
64 | addConversion(callback: [&](VectorType type) -> std::optional<Type> { |
65 | FailureOr<Type> llvmType = convertVectorType(type: type); |
66 | if (failed(result: llvmType)) |
67 | return std::nullopt; |
68 | return llvmType; |
69 | }); |
70 | |
71 | // LLVM-compatible types are legal, so add a pass-through conversion. Do this |
72 | // before the conversions below since conversions are attempted in reverse |
73 | // order and those should take priority. |
74 | addConversion(callback: [](Type type) { |
75 | return LLVM::isCompatibleType(type) ? std::optional<Type>(type) |
76 | : std::nullopt; |
77 | }); |
78 | |
79 | addConversion(callback: [&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results) |
80 | -> std::optional<LogicalResult> { |
81 | // Fastpath for types that won't be converted by this callback anyway. |
82 | if (LLVM::isCompatibleType(type)) { |
83 | results.push_back(type); |
84 | return success(); |
85 | } |
86 | |
87 | if (type.isIdentified()) { |
88 | auto convertedType = LLVM::LLVMStructType::getIdentified( |
89 | context: type.getContext(), name: ("_Converted." + type.getName()).str()); |
90 | |
91 | SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack(); |
92 | if (llvm::count(Range&: recursiveStack, Element: type)) { |
93 | results.push_back(Elt: convertedType); |
94 | return success(); |
95 | } |
96 | recursiveStack.push_back(type); |
97 | auto popConversionCallStack = llvm::make_scope_exit( |
98 | F: [&recursiveStack]() { recursiveStack.pop_back(); }); |
99 | |
100 | SmallVector<Type> convertedElemTypes; |
101 | convertedElemTypes.reserve(N: type.getBody().size()); |
102 | if (failed(result: convertTypes(types: type.getBody(), results&: convertedElemTypes))) |
103 | return std::nullopt; |
104 | |
105 | // If the converted type has not been initialized yet, just set its body |
106 | // to be the converted arguments and return. |
107 | if (!convertedType.isInitialized()) { |
108 | if (failed( |
109 | convertedType.setBody(convertedElemTypes, type.isPacked()))) { |
110 | return failure(); |
111 | } |
112 | results.push_back(Elt: convertedType); |
113 | return success(); |
114 | } |
115 | |
116 | // If it has been initialized, has the same body and packed bit, just use |
117 | // it. This ensures that recursive structs keep being recursive rather |
118 | // than including a non-updated name. |
119 | if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) && |
120 | convertedType.isPacked() == type.isPacked()) { |
121 | results.push_back(Elt: convertedType); |
122 | return success(); |
123 | } |
124 | |
125 | return failure(); |
126 | } |
127 | |
128 | SmallVector<Type> convertedSubtypes; |
129 | convertedSubtypes.reserve(N: type.getBody().size()); |
130 | if (failed(result: convertTypes(types: type.getBody(), results&: convertedSubtypes))) |
131 | return std::nullopt; |
132 | |
133 | results.push_back(Elt: LLVM::LLVMStructType::getLiteral( |
134 | context: type.getContext(), types: convertedSubtypes, isPacked: type.isPacked())); |
135 | return success(); |
136 | }); |
137 | addConversion(callback: [&](LLVM::LLVMArrayType type) -> std::optional<Type> { |
138 | if (auto element = convertType(type.getElementType())) |
139 | return LLVM::LLVMArrayType::get(element, type.getNumElements()); |
140 | return std::nullopt; |
141 | }); |
142 | addConversion(callback: [&](LLVM::LLVMFunctionType type) -> std::optional<Type> { |
143 | Type convertedResType = convertType(type.getReturnType()); |
144 | if (!convertedResType) |
145 | return std::nullopt; |
146 | |
147 | SmallVector<Type> convertedArgTypes; |
148 | convertedArgTypes.reserve(N: type.getNumParams()); |
149 | if (failed(convertTypes(types: type.getParams(), results&: convertedArgTypes))) |
150 | return std::nullopt; |
151 | |
152 | return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes, |
153 | type.isVarArg()); |
154 | }); |
155 | |
156 | // Materialization for memrefs creates descriptor structs from individual |
157 | // values constituting them, when descriptors are used, i.e. more than one |
158 | // value represents a memref. |
159 | addArgumentMaterialization( |
160 | callback: [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, |
161 | Location loc) -> std::optional<Value> { |
162 | if (inputs.size() == 1) |
163 | return std::nullopt; |
164 | return UnrankedMemRefDescriptor::pack(builder, loc, converter: *this, type: resultType, |
165 | values: inputs); |
166 | }); |
167 | addArgumentMaterialization(callback: [&](OpBuilder &builder, MemRefType resultType, |
168 | ValueRange inputs, |
169 | Location loc) -> std::optional<Value> { |
170 | // TODO: bare ptr conversion could be handled here but we would need a way |
171 | // to distinguish between FuncOp and other regions. |
172 | if (inputs.size() == 1) |
173 | return std::nullopt; |
174 | return MemRefDescriptor::pack(builder, loc, converter: *this, type: resultType, values: inputs); |
175 | }); |
176 | // Add generic source and target materializations to handle cases where |
177 | // non-LLVM types persist after an LLVM conversion. |
178 | addSourceMaterialization(callback: [&](OpBuilder &builder, Type resultType, |
179 | ValueRange inputs, |
180 | Location loc) -> std::optional<Value> { |
181 | if (inputs.size() != 1) |
182 | return std::nullopt; |
183 | |
184 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
185 | .getResult(0); |
186 | }); |
187 | addTargetMaterialization(callback: [&](OpBuilder &builder, Type resultType, |
188 | ValueRange inputs, |
189 | Location loc) -> std::optional<Value> { |
190 | if (inputs.size() != 1) |
191 | return std::nullopt; |
192 | |
193 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
194 | .getResult(0); |
195 | }); |
196 | |
197 | // Integer memory spaces map to themselves. |
198 | addTypeAttributeConversion( |
199 | [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; }); |
200 | } |
201 | |
202 | /// Returns the MLIR context. |
203 | MLIRContext &LLVMTypeConverter::getContext() const { |
204 | return *getDialect()->getContext(); |
205 | } |
206 | |
207 | Type LLVMTypeConverter::getIndexType() const { |
208 | return IntegerType::get(&getContext(), getIndexTypeBitwidth()); |
209 | } |
210 | |
211 | unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const { |
212 | return options.dataLayout.getPointerSizeInBits(AS: addressSpace); |
213 | } |
214 | |
215 | Type LLVMTypeConverter::convertIndexType(IndexType type) const { |
216 | return getIndexType(); |
217 | } |
218 | |
219 | Type LLVMTypeConverter::convertIntegerType(IntegerType type) const { |
220 | return IntegerType::get(&getContext(), type.getWidth()); |
221 | } |
222 | |
223 | Type LLVMTypeConverter::convertFloatType(FloatType type) const { |
224 | if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() || |
225 | type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ()) |
226 | return IntegerType::get(&getContext(), type.getWidth()); |
227 | return type; |
228 | } |
229 | |
230 | // Convert a `ComplexType` to an LLVM type. The result is a complex number |
231 | // struct with entries for the |
232 | // 1. real part and for the |
233 | // 2. imaginary part. |
234 | Type LLVMTypeConverter::convertComplexType(ComplexType type) const { |
235 | auto elementType = convertType(type.getElementType()); |
236 | return LLVM::LLVMStructType::getLiteral(context: &getContext(), |
237 | types: {elementType, elementType}); |
238 | } |
239 | |
240 | // Except for signatures, MLIR function types are converted into LLVM |
241 | // pointer-to-function types. |
242 | Type LLVMTypeConverter::convertFunctionType(FunctionType type) const { |
243 | return LLVM::LLVMPointerType::get(type.getContext()); |
244 | } |
245 | |
246 | // Function types are converted to LLVM Function types by recursively converting |
247 | // argument and result types. If MLIR Function has zero results, the LLVM |
248 | // Function has one VoidType result. If MLIR Function has more than one result, |
249 | // they are into an LLVM StructType in their order of appearance. |
250 | Type LLVMTypeConverter::convertFunctionSignature( |
251 | FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv, |
252 | LLVMTypeConverter::SignatureConversion &result) const { |
253 | // Select the argument converter depending on the calling convention. |
254 | useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; |
255 | auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter |
256 | : structFuncArgTypeConverter; |
257 | // Convert argument types one by one and check for errors. |
258 | for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) { |
259 | SmallVector<Type, 8> converted; |
260 | if (failed(funcArgConverter(*this, type, converted))) |
261 | return {}; |
262 | result.addInputs(idx, converted); |
263 | } |
264 | |
265 | // If function does not return anything, create the void result type, |
266 | // if it returns on element, convert it, otherwise pack the result types into |
267 | // a struct. |
268 | Type resultType = |
269 | funcTy.getNumResults() == 0 |
270 | ? LLVM::LLVMVoidType::get(ctx: &getContext()) |
271 | : packFunctionResults(types: funcTy.getResults(), useBarePointerCallConv: useBarePtrCallConv); |
272 | if (!resultType) |
273 | return {}; |
274 | return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(), |
275 | isVariadic); |
276 | } |
277 | |
278 | /// Converts the function type to a C-compatible format, in particular using |
279 | /// pointers to memref descriptors for arguments. |
280 | std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType> |
281 | LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const { |
282 | SmallVector<Type, 4> inputs; |
283 | |
284 | Type resultType = type.getNumResults() == 0 |
285 | ? LLVM::LLVMVoidType::get(ctx: &getContext()) |
286 | : packFunctionResults(types: type.getResults()); |
287 | if (!resultType) |
288 | return {}; |
289 | |
290 | auto ptrType = LLVM::LLVMPointerType::get(type.getContext()); |
291 | auto structType = dyn_cast<LLVM::LLVMStructType>(Val&: resultType); |
292 | if (structType) { |
293 | // Struct types cannot be safely returned via C interface. Make this a |
294 | // pointer argument, instead. |
295 | inputs.push_back(Elt: ptrType); |
296 | resultType = LLVM::LLVMVoidType::get(ctx: &getContext()); |
297 | } |
298 | |
299 | for (Type t : type.getInputs()) { |
300 | auto converted = convertType(t); |
301 | if (!converted || !LLVM::isCompatibleType(converted)) |
302 | return {}; |
303 | if (isa<MemRefType, UnrankedMemRefType>(t)) |
304 | converted = ptrType; |
305 | inputs.push_back(converted); |
306 | } |
307 | |
308 | return {LLVM::LLVMFunctionType::get(resultType, inputs), structType}; |
309 | } |
310 | |
311 | /// Convert a memref type into a list of LLVM IR types that will form the |
312 | /// memref descriptor. The result contains the following types: |
313 | /// 1. The pointer to the allocated data buffer, followed by |
314 | /// 2. The pointer to the aligned data buffer, followed by |
315 | /// 3. A lowered `index`-type integer containing the distance between the |
316 | /// beginning of the buffer and the first element to be accessed through the |
317 | /// view, followed by |
318 | /// 4. An array containing as many `index`-type integers as the rank of the |
319 | /// MemRef: the array represents the size, in number of elements, of the memref |
320 | /// along the given dimension. For constant MemRef dimensions, the |
321 | /// corresponding size entry is a constant whose runtime value must match the |
322 | /// static value, followed by |
323 | /// 5. A second array containing as many `index`-type integers as the rank of |
324 | /// the MemRef: the second array represents the "stride" (in tensor abstraction |
325 | /// sense), i.e. the number of consecutive elements of the underlying buffer. |
326 | /// TODO: add assertions for the static cases. |
327 | /// |
328 | /// If `unpackAggregates` is set to true, the arrays described in (4) and (5) |
329 | /// are expanded into individual index-type elements. |
330 | /// |
331 | /// template <typename Elem, typename Index, size_t Rank> |
332 | /// struct { |
333 | /// Elem *allocatedPtr; |
334 | /// Elem *alignedPtr; |
335 | /// Index offset; |
336 | /// Index sizes[Rank]; // omitted when rank == 0 |
337 | /// Index strides[Rank]; // omitted when rank == 0 |
338 | /// }; |
339 | SmallVector<Type, 5> |
340 | LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, |
341 | bool unpackAggregates) const { |
342 | if (!isStrided(type)) { |
343 | emitError( |
344 | UnknownLoc::get(type.getContext()), |
345 | "conversion to strided form failed either due to non-strided layout " |
346 | "maps (which should have been normalized away) or other reasons" ); |
347 | return {}; |
348 | } |
349 | |
350 | Type elementType = convertType(type.getElementType()); |
351 | if (!elementType) |
352 | return {}; |
353 | |
354 | FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type: type); |
355 | if (failed(result: addressSpace)) { |
356 | emitError(UnknownLoc::get(type.getContext()), |
357 | "conversion of memref memory space " ) |
358 | << type.getMemorySpace() |
359 | << " to integer address space " |
360 | "failed. Consider adding memory space conversions." ; |
361 | return {}; |
362 | } |
363 | auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); |
364 | |
365 | auto indexTy = getIndexType(); |
366 | |
367 | SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy}; |
368 | auto rank = type.getRank(); |
369 | if (rank == 0) |
370 | return results; |
371 | |
372 | if (unpackAggregates) |
373 | results.insert(results.end(), 2 * rank, indexTy); |
374 | else |
375 | results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank)); |
376 | return results; |
377 | } |
378 | |
379 | unsigned |
380 | LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, |
381 | const DataLayout &layout) const { |
382 | // Compute the descriptor size given that of its components indicated above. |
383 | unsigned space = *getMemRefAddressSpace(type: type); |
384 | return 2 * llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8) + |
385 | (1 + 2 * type.getRank()) * layout.getTypeSize(t: getIndexType()); |
386 | } |
387 | |
388 | /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that |
389 | /// packs the descriptor fields as defined by `getMemRefDescriptorFields`. |
390 | Type LLVMTypeConverter::convertMemRefType(MemRefType type) const { |
391 | // When converting a MemRefType to a struct with descriptor fields, do not |
392 | // unpack the `sizes` and `strides` arrays. |
393 | SmallVector<Type, 5> types = |
394 | getMemRefDescriptorFields(type: type, /*unpackAggregates=*/false); |
395 | if (types.empty()) |
396 | return {}; |
397 | return LLVM::LLVMStructType::getLiteral(context: &getContext(), types); |
398 | } |
399 | |
400 | /// Convert an unranked memref type into a list of non-aggregate LLVM IR types |
401 | /// that will form the unranked memref descriptor. In particular, the fields |
402 | /// for an unranked memref descriptor are: |
403 | /// 1. index-typed rank, the dynamic rank of this MemRef |
404 | /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be |
405 | /// stack allocated (alloca) copy of a MemRef descriptor that got casted to |
406 | /// be unranked. |
407 | SmallVector<Type, 2> |
408 | LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const { |
409 | return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())}; |
410 | } |
411 | |
412 | unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize( |
413 | UnrankedMemRefType type, const DataLayout &layout) const { |
414 | // Compute the descriptor size given that of its components indicated above. |
415 | unsigned space = *getMemRefAddressSpace(type: type); |
416 | return layout.getTypeSize(t: getIndexType()) + |
417 | llvm::divideCeil(Numerator: getPointerBitwidth(addressSpace: space), Denominator: 8); |
418 | } |
419 | |
420 | Type LLVMTypeConverter::convertUnrankedMemRefType( |
421 | UnrankedMemRefType type) const { |
422 | if (!convertType(type.getElementType())) |
423 | return {}; |
424 | return LLVM::LLVMStructType::getLiteral(context: &getContext(), |
425 | types: getUnrankedMemRefDescriptorFields()); |
426 | } |
427 | |
428 | FailureOr<unsigned> |
429 | LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const { |
430 | if (!type.getMemorySpace()) // Default memory space -> 0. |
431 | return 0; |
432 | std::optional<Attribute> converted = |
433 | convertTypeAttribute(type, attr: type.getMemorySpace()); |
434 | if (!converted) |
435 | return failure(); |
436 | if (!(*converted)) // Conversion to default is 0. |
437 | return 0; |
438 | if (auto explicitSpace = llvm::dyn_cast_if_present<IntegerAttr>(*converted)) |
439 | return explicitSpace.getInt(); |
440 | return failure(); |
441 | } |
442 | |
443 | // Check if a memref type can be converted to a bare pointer. |
444 | bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { |
445 | if (isa<UnrankedMemRefType>(Val: type)) |
446 | // Unranked memref is not supported in the bare pointer calling convention. |
447 | return false; |
448 | |
449 | // Check that the memref has static shape, strides and offset. Otherwise, it |
450 | // cannot be lowered to a bare pointer. |
451 | auto memrefTy = cast<MemRefType>(type); |
452 | if (!memrefTy.hasStaticShape()) |
453 | return false; |
454 | |
455 | int64_t offset = 0; |
456 | SmallVector<int64_t, 4> strides; |
457 | if (failed(getStridesAndOffset(memrefTy, strides, offset))) |
458 | return false; |
459 | |
460 | for (int64_t stride : strides) |
461 | if (ShapedType::isDynamic(stride)) |
462 | return false; |
463 | |
464 | return !ShapedType::isDynamic(offset); |
465 | } |
466 | |
467 | /// Convert a memref type to a bare pointer to the memref element type. |
468 | Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const { |
469 | if (!canConvertToBarePtr(type)) |
470 | return {}; |
471 | Type elementType = convertType(t: type.getElementType()); |
472 | if (!elementType) |
473 | return {}; |
474 | FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type); |
475 | if (failed(result: addressSpace)) |
476 | return {}; |
477 | return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace); |
478 | } |
479 | |
480 | /// Convert an n-D vector type to an LLVM vector type: |
481 | /// * 0-D `vector<T>` are converted to vector<1xT> |
482 | /// * 1-D `vector<axT>` remains as is while, |
483 | /// * n>1 `vector<ax...xkxT>` convert via an (n-1)-D array type to |
484 | /// `!llvm.array<ax...array<jxvector<kxT>>>`. |
485 | /// Returns failure for n-D scalable vector types as LLVM does not support |
486 | /// arrays of scalable vectors. |
487 | FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const { |
488 | auto elementType = convertType(type.getElementType()); |
489 | if (!elementType) |
490 | return {}; |
491 | if (type.getShape().empty()) |
492 | return VectorType::get({1}, elementType); |
493 | Type vectorType = VectorType::get(type.getShape().back(), elementType, |
494 | type.getScalableDims().back()); |
495 | assert(LLVM::isCompatibleVectorType(vectorType) && |
496 | "expected vector type compatible with the LLVM dialect" ); |
497 | // Only the trailing dimension can be scalable. |
498 | if (llvm::is_contained(type.getScalableDims().drop_back(), true)) |
499 | return failure(); |
500 | auto shape = type.getShape(); |
501 | for (int i = shape.size() - 2; i >= 0; --i) |
502 | vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); |
503 | return vectorType; |
504 | } |
505 | |
506 | /// Convert a type in the context of the default or bare pointer calling |
507 | /// convention. Calling convention sensitive types, such as MemRefType and |
508 | /// UnrankedMemRefType, are converted following the specific rules for the |
509 | /// calling convention. Calling convention independent types are converted |
510 | /// following the default LLVM type conversions. |
511 | Type LLVMTypeConverter::convertCallingConventionType( |
512 | Type type, bool useBarePtrCallConv) const { |
513 | if (useBarePtrCallConv) |
514 | if (auto memrefTy = dyn_cast<BaseMemRefType>(Val&: type)) |
515 | return convertMemRefToBarePtr(type: memrefTy); |
516 | |
517 | return convertType(t: type); |
518 | } |
519 | |
520 | /// Promote the bare pointers in 'values' that resulted from memrefs to |
521 | /// descriptors. 'stdTypes' holds they types of 'values' before the conversion |
522 | /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). |
523 | void LLVMTypeConverter::promoteBarePtrsToDescriptors( |
524 | ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes, |
525 | SmallVectorImpl<Value> &values) const { |
526 | assert(stdTypes.size() == values.size() && |
527 | "The number of types and values doesn't match" ); |
528 | for (unsigned i = 0, end = values.size(); i < end; ++i) |
529 | if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i])) |
530 | values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, |
531 | memrefTy, values[i]); |
532 | } |
533 | |
534 | /// Convert a non-empty list of types of values produced by an operation into an |
535 | /// LLVM-compatible type. In particular, if more than one value is |
536 | /// produced, create a literal structure with elements that correspond to each |
537 | /// of the types converted with `convertType`. |
538 | Type LLVMTypeConverter::packOperationResults(TypeRange types) const { |
539 | assert(!types.empty() && "expected non-empty list of type" ); |
540 | if (types.size() == 1) |
541 | return convertType(t: types[0]); |
542 | |
543 | SmallVector<Type> resultTypes; |
544 | resultTypes.reserve(N: types.size()); |
545 | for (Type type : types) { |
546 | Type converted = convertType(t: type); |
547 | if (!converted || !LLVM::isCompatibleType(type: converted)) |
548 | return {}; |
549 | resultTypes.push_back(Elt: converted); |
550 | } |
551 | |
552 | return LLVM::LLVMStructType::getLiteral(context: &getContext(), types: resultTypes); |
553 | } |
554 | |
555 | /// Convert a non-empty list of types to be returned from a function into an |
556 | /// LLVM-compatible type. In particular, if more than one value is returned, |
557 | /// create an LLVM dialect structure type with elements that correspond to each |
558 | /// of the types converted with `convertCallingConventionType`. |
559 | Type LLVMTypeConverter::packFunctionResults(TypeRange types, |
560 | bool useBarePtrCallConv) const { |
561 | assert(!types.empty() && "expected non-empty list of type" ); |
562 | |
563 | useBarePtrCallConv |= options.useBarePtrCallConv; |
564 | if (types.size() == 1) |
565 | return convertCallingConventionType(type: types.front(), useBarePtrCallConv); |
566 | |
567 | SmallVector<Type> resultTypes; |
568 | resultTypes.reserve(N: types.size()); |
569 | for (auto t : types) { |
570 | auto converted = convertCallingConventionType(type: t, useBarePtrCallConv); |
571 | if (!converted || !LLVM::isCompatibleType(type: converted)) |
572 | return {}; |
573 | resultTypes.push_back(Elt: converted); |
574 | } |
575 | |
576 | return LLVM::LLVMStructType::getLiteral(context: &getContext(), types: resultTypes); |
577 | } |
578 | |
579 | Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, |
580 | OpBuilder &builder) const { |
581 | // Alloca with proper alignment. We do not expect optimizations of this |
582 | // alloca op and so we omit allocating at the entry block. |
583 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
584 | Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(), |
585 | builder.getIndexAttr(1)); |
586 | Value allocated = |
587 | builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one); |
588 | // Store into the alloca'ed descriptor. |
589 | builder.create<LLVM::StoreOp>(loc, operand, allocated); |
590 | return allocated; |
591 | } |
592 | |
593 | SmallVector<Value, 4> |
594 | LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, |
595 | ValueRange operands, OpBuilder &builder, |
596 | bool useBarePtrCallConv) const { |
597 | SmallVector<Value, 4> promotedOperands; |
598 | promotedOperands.reserve(N: operands.size()); |
599 | useBarePtrCallConv |= options.useBarePtrCallConv; |
600 | for (auto it : llvm::zip(t&: opOperands, u&: operands)) { |
601 | auto operand = std::get<0>(t&: it); |
602 | auto llvmOperand = std::get<1>(t&: it); |
603 | |
604 | if (useBarePtrCallConv) { |
605 | // For the bare-ptr calling convention, we only have to extract the |
606 | // aligned pointer of a memref. |
607 | if (dyn_cast<MemRefType>(operand.getType())) { |
608 | MemRefDescriptor desc(llvmOperand); |
609 | llvmOperand = desc.alignedPtr(builder, loc); |
610 | } else if (isa<UnrankedMemRefType>(Val: operand.getType())) { |
611 | llvm_unreachable("Unranked memrefs are not supported" ); |
612 | } |
613 | } else { |
614 | if (isa<UnrankedMemRefType>(Val: operand.getType())) { |
615 | UnrankedMemRefDescriptor::unpack(builder, loc, packed: llvmOperand, |
616 | results&: promotedOperands); |
617 | continue; |
618 | } |
619 | if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) { |
620 | MemRefDescriptor::unpack(builder, loc, packed: llvmOperand, type: memrefType, |
621 | results&: promotedOperands); |
622 | continue; |
623 | } |
624 | } |
625 | |
626 | promotedOperands.push_back(Elt: llvmOperand); |
627 | } |
628 | return promotedOperands; |
629 | } |
630 | |
631 | /// Callback to convert function argument types. It converts a MemRef function |
632 | /// argument to a list of non-aggregate types containing descriptor |
633 | /// information, and an UnrankedmemRef function argument to a list containing |
634 | /// the rank and a pointer to a descriptor struct. |
635 | LogicalResult |
636 | mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, |
637 | SmallVectorImpl<Type> &result) { |
638 | if (auto memref = dyn_cast<MemRefType>(type)) { |
639 | // In signatures, Memref descriptors are expanded into lists of |
640 | // non-aggregate values. |
641 | auto converted = |
642 | converter.getMemRefDescriptorFields(type: memref, /*unpackAggregates=*/true); |
643 | if (converted.empty()) |
644 | return failure(); |
645 | result.append(converted.begin(), converted.end()); |
646 | return success(); |
647 | } |
648 | if (isa<UnrankedMemRefType>(Val: type)) { |
649 | auto converted = converter.getUnrankedMemRefDescriptorFields(); |
650 | if (converted.empty()) |
651 | return failure(); |
652 | result.append(in_start: converted.begin(), in_end: converted.end()); |
653 | return success(); |
654 | } |
655 | auto converted = converter.convertType(t: type); |
656 | if (!converted) |
657 | return failure(); |
658 | result.push_back(Elt: converted); |
659 | return success(); |
660 | } |
661 | |
662 | /// Callback to convert function argument types. It converts MemRef function |
663 | /// arguments to bare pointers to the MemRef element type. |
664 | LogicalResult |
665 | mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, |
666 | SmallVectorImpl<Type> &result) { |
667 | auto llvmTy = converter.convertCallingConventionType( |
668 | type, /*useBarePointerCallConv=*/useBarePtrCallConv: true); |
669 | if (!llvmTy) |
670 | return failure(); |
671 | |
672 | result.push_back(Elt: llvmTy); |
673 | return success(); |
674 | } |
675 | |