//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "MemRefDescriptor.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Threading.h"
#include <memory>
#include <mutex>
#include <optional>

using namespace mlir;

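/// Returns the per-thread stack of struct types that are currently being
/// converted. This stack is used to detect, and correctly convert,
/// self-referential struct types.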
SmallVector<Type> &LLVMTypeConverter::getCurrentThreadRecursiveStack() {
  {
    // Most of the time, the entry already exists in the map.
    std::shared_lock<decltype(callStackMutex)> lock(callStackMutex,
                                                    std::defer_lock);
    if (getContext().isMultithreadingEnabled())
      lock.lock();
    auto recursiveStack = conversionCallStack.find(llvm::get_threadid());
    if (recursiveStack != conversionCallStack.end())
      return *recursiveStack->second;
  }

  // The first time a thread gets here, we have to acquire exclusive access to
  // insert into the map.
  std::unique_lock<decltype(callStackMutex)> lock(callStackMutex);
  auto recursiveStackInserted = conversionCallStack.insert(std::make_pair(
      llvm::get_threadid(), std::make_unique<SmallVector<Type>>()));
  return *recursiveStackInserted.first->second;
}

/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const DataLayoutAnalysis *analysis)
    : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {}

/// Helper function that checks if the given value range is a bare pointer.
static bool isBarePointer(ValueRange values) {
  return values.size() == 1 &&
         isa<LLVM::LLVMPointerType>(values.front().getType());
}

/// Pack SSA values into an unranked memref descriptor struct.
static Value packUnrankedMemRefDesc(OpBuilder &builder,
                                    UnrankedMemRefType resultType,
                                    ValueRange inputs, Location loc,
                                    const LLVMTypeConverter &converter) {
  // Note: Bare pointers are not supported for unranked memrefs because a
  // memref descriptor cannot be built just from a bare pointer.
  if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
    return Value();
  return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
                                        inputs);
}

/// Pack SSA values into a ranked memref descriptor struct.
static Value packRankedMemRefDesc(OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs, Location loc,
                                  const LLVMTypeConverter &converter) {
  assert(resultType && "expected non-null result type");
  if (isBarePointer(inputs))
    return MemRefDescriptor::fromStaticShape(builder, loc, converter,
                                             resultType, inputs[0]);
  if (TypeRange(inputs) ==
      converter.getMemRefDescriptorFields(resultType,
                                          /*unpackAggregates=*/true))
    return MemRefDescriptor::pack(builder, loc, converter, resultType, inputs);
  // The inputs are neither a bare pointer nor an unpacked memref descriptor.
  // This materialization function cannot be used.
  return Value();
}

/// MemRef descriptor elements -> UnrankedMemRefType
static Value unrankedMemRefMaterialization(OpBuilder &builder,
                                           UnrankedMemRefType resultType,
                                           ValueRange inputs, Location loc,
                                           const LLVMTypeConverter &converter) {
  // A source materialization must return a value of type
  // `resultType`, so insert a cast from the memref descriptor type
  // (!llvm.struct) to the original memref type.
  Value packed =
      packUnrankedMemRefDesc(builder, resultType, inputs, loc, converter);
  if (!packed)
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
      .getResult(0);
}

/// MemRef descriptor elements -> MemRefType
static Value rankedMemRefMaterialization(OpBuilder &builder,
                                         MemRefType resultType,
                                         ValueRange inputs, Location loc,
                                         const LLVMTypeConverter &converter) {
  // A source materialization must return a value of type `resultType`,
  // so insert a cast from the memref descriptor type (!llvm.struct) to the
  // original memref type.
  Value packed =
      packRankedMemRefDesc(builder, resultType, inputs, loc, converter);
  if (!packed)
    return Value();
  return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
      .getResult(0);
}

/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const LowerToLLVMOptions &options,
                                     const DataLayoutAnalysis *analysis)
    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()), options(options),
      dataLayoutAnalysis(analysis) {
  assert(llvmDialect && "LLVM IR dialect is not registered");

  // Register conversions for the builtin types.
  addConversion([&](ComplexType type) { return convertComplexType(type); });
  addConversion([&](FloatType type) { return convertFloatType(type); });
  addConversion([&](FunctionType type) { return convertFunctionType(type); });
  addConversion([&](IndexType type) { return convertIndexType(type); });
  addConversion([&](IntegerType type) { return convertIntegerType(type); });
  addConversion([&](MemRefType type) { return convertMemRefType(type); });
  addConversion(
      [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
  addConversion([&](VectorType type) -> std::optional<Type> {
    FailureOr<Type> llvmType = convertVectorType(type);
    if (failed(llvmType))
      return std::nullopt;
    return llvmType;
  });

  // LLVM-compatible types are legal, so add a pass-through conversion. Do this
  // before the conversions below since conversions are attempted in reverse
  // order and those should take priority.
  addConversion([](Type type) {
    return LLVM::isCompatibleType(type) ? std::optional<Type>(type)
                                        : std::nullopt;
  });

  addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results)
                    -> std::optional<LogicalResult> {
    // Fastpath for types that won't be converted by this callback anyway.
    if (LLVM::isCompatibleType(type)) {
      results.push_back(type);
      return success();
    }

    if (type.isIdentified()) {
      auto convertedType = LLVM::LLVMStructType::getIdentified(
          type.getContext(), ("_Converted." + type.getName()).str());

      SmallVectorImpl<Type> &recursiveStack = getCurrentThreadRecursiveStack();
      if (llvm::count(recursiveStack, type)) {
        results.push_back(convertedType);
        return success();
      }
      recursiveStack.push_back(type);
      auto popConversionCallStack = llvm::make_scope_exit(
          [&recursiveStack]() { recursiveStack.pop_back(); });

      SmallVector<Type> convertedElemTypes;
      convertedElemTypes.reserve(type.getBody().size());
      if (failed(convertTypes(type.getBody(), convertedElemTypes)))
        return std::nullopt;

      // If the converted type has not been initialized yet, just set its body
      // to be the converted arguments and return.
      if (!convertedType.isInitialized()) {
        if (failed(
                convertedType.setBody(convertedElemTypes, type.isPacked()))) {
          return failure();
        }
        results.push_back(convertedType);
        return success();
      }

      // If it has been initialized, has the same body and packed bit, just use
      // it. This ensures that recursive structs keep being recursive rather
      // than including a non-updated name.
      if (TypeRange(convertedType.getBody()) == TypeRange(convertedElemTypes) &&
          convertedType.isPacked() == type.isPacked()) {
        results.push_back(convertedType);
        return success();
      }

      return failure();
    }

    SmallVector<Type> convertedSubtypes;
    convertedSubtypes.reserve(type.getBody().size());
    if (failed(convertTypes(type.getBody(), convertedSubtypes)))
      return std::nullopt;

    results.push_back(LLVM::LLVMStructType::getLiteral(
        type.getContext(), convertedSubtypes, type.isPacked()));
    return success();
  });
  addConversion([&](LLVM::LLVMArrayType type) -> std::optional<Type> {
    if (auto element = convertType(type.getElementType()))
      return LLVM::LLVMArrayType::get(element, type.getNumElements());
    return std::nullopt;
  });
  addConversion([&](LLVM::LLVMFunctionType type) -> std::optional<Type> {
    Type convertedResType = convertType(type.getReturnType());
    if (!convertedResType)
      return std::nullopt;

    SmallVector<Type> convertedArgTypes;
    convertedArgTypes.reserve(type.getNumParams());
    if (failed(convertTypes(type.getParams(), convertedArgTypes)))
      return std::nullopt;

    return LLVM::LLVMFunctionType::get(convertedResType, convertedArgTypes,
                                       type.isVarArg());
  });

  // Add generic source and target materializations to handle cases where
  // non-LLVM types persist after an LLVM conversion.
  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });
  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc) {
    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  });

  // Source materializations convert from the new block argument types
  // (multiple SSA values that make up a memref descriptor) back to the
  // original block argument type.
  addSourceMaterialization([&](OpBuilder &builder,
                               UnrankedMemRefType resultType, ValueRange inputs,
                               Location loc) {
    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
                                         *this);
  });
  addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
                               ValueRange inputs, Location loc) {
    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
  });

  // Bare pointer -> Packed MemRef descriptor
  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs, Location loc,
                               Type originalType) -> Value {
    // The original MemRef type is required to build a MemRef descriptor
    // because the sizes/strides of the MemRef cannot be inferred from just the
    // bare pointer.
    if (!originalType)
      return Value();
    if (resultType != convertType(originalType))
      return Value();
    if (auto memrefType = dyn_cast<MemRefType>(originalType))
      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
                                    *this);
    return Value();
  });

  // Integer memory spaces map to themselves.
  addTypeAttributeConversion(
      [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
}

/// Returns the MLIR context.
MLIRContext &LLVMTypeConverter::getContext() const {
  return *getDialect()->getContext();
}

Type LLVMTypeConverter::getIndexType() const {
  return IntegerType::get(&getContext(), getIndexTypeBitwidth());
}

unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) const {
  return options.dataLayout.getPointerSizeInBits(addressSpace);
}

Type LLVMTypeConverter::convertIndexType(IndexType type) const {
  return getIndexType();
}

Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
  return IntegerType::get(&getContext(), type.getWidth());
}

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
  // Valid LLVM float types are used directly.
  if (LLVM::isCompatibleType(type))
    return type;

  // F4, F6, F8 types are converted to integer types with the same bit width.
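  // For example, f8E4M3FN becomes i8 and f4E2M1FN becomes i4.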
  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
          Float8E8M0FNUType>(type))
    return IntegerType::get(&getContext(), type.getWidth());

  // Other floating-point types: A custom type conversion rule must be
  // specified by the user.
  return Type();
}

// Convert a `ComplexType` to an LLVM type. The result is a struct containing
// two entries:
// 1. the real part, and
// 2. the imaginary part.
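// For example, complex<f32> converts to !llvm.struct<(f32, f32)>.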
Type LLVMTypeConverter::convertComplexType(ComplexType type) const {
  auto elementType = convertType(type.getElementType());
  return LLVM::LLVMStructType::getLiteral(&getContext(),
                                          {elementType, elementType});
}

// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) const {
  return LLVM::LLVMPointerType::get(type.getContext());
}

/// Returns the `llvm.byval` or `llvm.byref` attributes that are present in the
/// function arguments. Returns an empty container if none of these attributes
/// are found in any of the arguments.
static void
filterByValRefArgAttrs(FunctionOpInterface funcOp,
                       SmallVectorImpl<std::optional<NamedAttribute>> &result) {
  assert(result.empty() && "Unexpected non-empty output");
  result.resize(funcOp.getNumArguments(), std::nullopt);
  bool foundByValByRefAttrs = false;
  for (int argIdx : llvm::seq(funcOp.getNumArguments())) {
    for (NamedAttribute namedAttr : funcOp.getArgAttrs(argIdx)) {
      if ((namedAttr.getName() == LLVM::LLVMDialect::getByValAttrName() ||
           namedAttr.getName() == LLVM::LLVMDialect::getByRefAttrName())) {
        foundByValByRefAttrs = true;
        result[argIdx] = namedAttr;
        break;
      }
    }
  }

  if (!foundByValByRefAttrs)
    result.clear();
}

// Function types are converted to LLVM function types by recursively
// converting argument and result types. If the MLIR function has zero results,
// the LLVM function has one VoidType result. If the MLIR function has more
// than one result, the results are packed into an LLVM StructType in their
// order of appearance.
// If `byValRefNonPtrAttrs` is provided, converted types of `llvm.byval` and
// `llvm.byref` function arguments which are not LLVM pointers are overridden
// with LLVM pointers. `llvm.byval` and `llvm.byref` arguments that were
// already converted to LLVM pointer types are removed from
// `byValRefNonPtrAttrs`.
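// For example, with the default calling convention and a 64-bit index type,
// (memref<f32>) -> (i32, i64) converts to
// !llvm.func<struct<(i32, i64)> (ptr, ptr, i64)>: the rank-0 memref argument
// expands to its descriptor fields and the two results are packed into a
// struct.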
Type LLVMTypeConverter::convertFunctionSignatureImpl(
    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result,
    SmallVectorImpl<std::optional<NamedAttribute>> *byValRefNonPtrAttrs) const {
  // Select the argument converter depending on the calling convention.
  useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
  auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                             : structFuncArgTypeConverter;
  // Convert argument types one by one and check for errors.
  for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
    SmallVector<Type, 8> converted;
    if (failed(funcArgConverter(*this, type, converted)))
      return {};

    // Rewrite the converted type of an `llvm.byval` or `llvm.byref` function
    // argument that was not converted to an LLVM pointer type.
    if (byValRefNonPtrAttrs != nullptr && !byValRefNonPtrAttrs->empty() &&
        converted.size() == 1 && (*byValRefNonPtrAttrs)[idx].has_value()) {
      // If the argument was already converted to an LLVM pointer type, we stop
      // tracking it as it doesn't need more processing.
      if (isa<LLVM::LLVMPointerType>(converted[0]))
        (*byValRefNonPtrAttrs)[idx] = std::nullopt;
      else
        converted[0] = LLVM::LLVMPointerType::get(&getContext());
    }

    result.addInputs(idx, converted);
  }

  // If the function does not return anything, create the void result type;
  // if it returns one element, convert it; otherwise pack the result types
  // into a struct.
  Type resultType =
      funcTy.getNumResults() == 0
          ? LLVM::LLVMVoidType::get(&getContext())
          : packFunctionResults(funcTy.getResults(), useBarePtrCallConv);
  if (!resultType)
    return {};
  return LLVM::LLVMFunctionType::get(resultType, result.getConvertedTypes(),
                                     isVariadic);
}

Type LLVMTypeConverter::convertFunctionSignature(
    FunctionType funcTy, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result) const {
  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
                                      result,
                                      /*byValRefNonPtrAttrs=*/nullptr);
}

Type LLVMTypeConverter::convertFunctionSignature(
    FunctionOpInterface funcOp, bool isVariadic, bool useBarePtrCallConv,
    LLVMTypeConverter::SignatureConversion &result,
    SmallVectorImpl<std::optional<NamedAttribute>> &byValRefNonPtrAttrs) const {
  // Gather all `llvm.byval` and `llvm.byref` function arguments. Only those
  // that were not converted to LLVM pointer types will be returned for further
  // processing.
  filterByValRefArgAttrs(funcOp, byValRefNonPtrAttrs);
  auto funcTy = cast<FunctionType>(funcOp.getFunctionType());
  return convertFunctionSignatureImpl(funcTy, isVariadic, useBarePtrCallConv,
                                      result, &byValRefNonPtrAttrs);
}

/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
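/// For example, (memref<f32>) -> (i32, i64) converts to the pair
/// {!llvm.func<void (ptr, ptr)>, !llvm.struct<(i32, i64)>}: the result struct
/// is returned through a pointer argument prepended to the inputs, and the
/// memref argument is passed as a pointer to its descriptor.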
std::pair<LLVM::LLVMFunctionType, LLVM::LLVMStructType>
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
  SmallVector<Type, 4> inputs;

  Type resultType = type.getNumResults() == 0
                        ? LLVM::LLVMVoidType::get(&getContext())
                        : packFunctionResults(type.getResults());
  if (!resultType)
    return {};

  auto ptrType = LLVM::LLVMPointerType::get(type.getContext());
  auto structType = dyn_cast<LLVM::LLVMStructType>(resultType);
  if (structType) {
    // Struct types cannot be safely returned via C interface. Make this a
    // pointer argument, instead.
    inputs.push_back(ptrType);
    resultType = LLVM::LLVMVoidType::get(&getContext());
  }

  for (Type t : type.getInputs()) {
    auto converted = convertType(t);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    if (isa<MemRefType, UnrankedMemRefType>(t))
      converted = ptrType;
    inputs.push_back(converted);
  }

  return {LLVM::LLVMFunctionType::get(resultType, inputs), structType};
}

/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. The result contains the following types:
/// 1. The pointer to the allocated data buffer, followed by
/// 2. The pointer to the aligned data buffer, followed by
/// 3. A lowered `index`-type integer containing the distance between the
/// beginning of the buffer and the first element to be accessed through the
/// view, followed by
/// 4. An array containing as many `index`-type integers as the rank of the
/// MemRef: the array represents the size, in number of elements, of the memref
/// along the given dimension. For constant MemRef dimensions, the
/// corresponding size entry is a constant whose runtime value must match the
/// static value, followed by
/// 5. A second array containing as many `index`-type integers as the rank of
/// the MemRef: the second array represents the "stride" (in tensor abstraction
/// sense), i.e. the number of consecutive elements of the underlying buffer.
/// TODO: add assertions for the static cases.
///
/// If `unpackAggregates` is set to true, the arrays described in (4) and (5)
/// are expanded into individual index-type elements.
///
/// template <typename Elem, typename Index, size_t Rank>
/// struct {
///   Elem *allocatedPtr;
///   Elem *alignedPtr;
///   Index offset;
///   Index sizes[Rank]; // omitted when rank == 0
///   Index strides[Rank]; // omitted when rank == 0
/// };
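///
/// For example, memref<4x?xf32> yields (ptr, ptr, i64, i64, i64, i64, i64)
/// when `unpackAggregates` is true, and (ptr, ptr, i64, array<2 x i64>,
/// array<2 x i64>) otherwise (assuming a 64-bit index type).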
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                             bool unpackAggregates) const {
  if (!type.isStrided()) {
    emitError(
        UnknownLoc::get(type.getContext()),
        "conversion to strided form failed either due to non-strided layout "
        "maps (which should have been normalized away) or other reasons");
    return {};
  }

  Type elementType = convertType(type.getElementType());
  if (!elementType)
    return {};

  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(addressSpace)) {
    emitError(UnknownLoc::get(type.getContext()),
              "conversion of memref memory space ")
        << type.getMemorySpace()
        << " to integer address space "
           "failed. Consider adding memory space conversions.";
    return {};
  }
  auto ptrTy = LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);

  auto indexTy = getIndexType();

  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
  auto rank = type.getRank();
  if (rank == 0)
    return results;

  if (unpackAggregates)
    results.insert(results.end(), 2 * rank, indexTy);
  else
    results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank));
  return results;
}

unsigned
LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type,
                                           const DataLayout &layout) const {
  // Compute the descriptor size given that of its components indicated above.
  unsigned space = *getMemRefAddressSpace(type);
  return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) +
         (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType());
}

/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that
/// packs the descriptor fields as defined by `getMemRefDescriptorFields`.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) const {
  // When converting a MemRefType to a struct with descriptor fields, do not
  // unpack the `sizes` and `strides` arrays.
  SmallVector<Type, 5> types =
      getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
  if (types.empty())
    return {};
  return LLVM::LLVMStructType::getLiteral(&getContext(), types);
}

/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, the fields
/// for an unranked memref descriptor are:
/// 1. index-typed rank, the dynamic rank of this MemRef
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
/// a stack-allocated (alloca) copy of a MemRef descriptor that was cast to
/// be unranked.
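///
/// For example, memref<*xf32> converts to !llvm.struct<(i64, ptr)> (assuming
/// a 64-bit index type).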
SmallVector<Type, 2>
LLVMTypeConverter::getUnrankedMemRefDescriptorFields() const {
  return {getIndexType(), LLVM::LLVMPointerType::get(&getContext())};
}

unsigned LLVMTypeConverter::getUnrankedMemRefDescriptorSize(
    UnrankedMemRefType type, const DataLayout &layout) const {
  // Compute the descriptor size given that of its components indicated above.
  unsigned space = *getMemRefAddressSpace(type);
  return layout.getTypeSize(getIndexType()) +
         llvm::divideCeil(getPointerBitwidth(space), 8);
}

Type LLVMTypeConverter::convertUnrankedMemRefType(
    UnrankedMemRefType type) const {
  if (!convertType(type.getElementType()))
    return {};
  return LLVM::LLVMStructType::getLiteral(&getContext(),
                                          getUnrankedMemRefDescriptorFields());
}

FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
  if (!type.getMemorySpace()) // Default memory space -> 0.
    return 0;
  std::optional<Attribute> converted =
      convertTypeAttribute(type, type.getMemorySpace());
  if (!converted)
    return failure();
  if (!(*converted)) // Conversion to default is 0.
    return 0;
  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
    if (explicitSpace.getType().isIndex() ||
        explicitSpace.getType().isSignlessInteger())
      return explicitSpace.getInt();
  }
  return failure();
}

// Check if a memref type can be converted to a bare pointer.
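// For example, memref<4x8xf32> (static shape, strides, and offset) can be
// lowered to a bare pointer, whereas memref<?xf32> cannot because its shape
// is not known at compile time.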
bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
  if (isa<UnrankedMemRefType>(type))
    // Unranked memref is not supported in the bare pointer calling convention.
    return false;

  // Check that the memref has static shape, strides and offset. Otherwise, it
  // cannot be lowered to a bare pointer.
  auto memrefTy = cast<MemRefType>(type);
  if (!memrefTy.hasStaticShape())
    return false;

  int64_t offset = 0;
  SmallVector<int64_t, 4> strides;
  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
    return false;

  for (int64_t stride : strides)
    if (ShapedType::isDynamic(stride))
      return false;

  return !ShapedType::isDynamic(offset);
}

/// Convert a memref type to a bare pointer to the memref element type.
Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) const {
  if (!canConvertToBarePtr(type))
    return {};
  Type elementType = convertType(type.getElementType());
  if (!elementType)
    return {};
  FailureOr<unsigned> addressSpace = getMemRefAddressSpace(type);
  if (failed(addressSpace))
    return {};
  return LLVM::LLVMPointerType::get(type.getContext(), *addressSpace);
}

/// Convert an n-D vector type to an LLVM vector type:
/// * 0-D `vector<T>` is converted to `vector<1xT>`,
/// * 1-D `vector<axT>` remains as is, while
/// * n>1 `vector<ax...xkxT>` converts via an (n-1)-D array type to
///   `!llvm.array<ax...array<jxvector<kxT>>>`.
/// As LLVM supports arrays of scalable vectors, this method will also convert
/// n-D scalable vectors provided that only the trailing dim is scalable.
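/// For example, vector<4x8xf32> converts to !llvm.array<4 x vector<8xf32>>,
/// and vector<4x[8]xf32> converts to !llvm.array<4 x vector<[8]xf32>>.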
FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
  auto elementType = convertType(type.getElementType());
  if (!elementType)
    return {};
  if (type.getShape().empty())
    return VectorType::get({1}, elementType);
  Type vectorType = VectorType::get(type.getShape().back(), elementType,
                                    type.getScalableDims().back());
  assert(LLVM::isCompatibleVectorType(vectorType) &&
         "expected vector type compatible with the LLVM dialect");
  // For n-D vector types for which a _non-trailing_ dim is scalable,
  // return a failure. Supporting such cases would require LLVM
  // to support something akin to "scalable arrays" of vectors.
  if (llvm::is_contained(type.getScalableDims().drop_back(), true))
    return failure();
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
  return vectorType;
}

/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type LLVMTypeConverter::convertCallingConventionType(
    Type type, bool useBarePtrCallConv) const {
  if (useBarePtrCallConv)
    if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
      return convertMemRefToBarePtr(memrefTy);

  return convertType(type);
}

/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
    SmallVectorImpl<Value> &values) const {
  assert(stdTypes.size() == values.size() &&
         "The number of types and values doesn't match");
  for (unsigned i = 0, end = values.size(); i < end; ++i)
    if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
                                                    memrefTy, values[i]);
}

/// Convert a non-empty list of types of values produced by an operation into
/// an LLVM-compatible type. In particular, if more than one value is produced,
/// create a literal structure with elements that correspond to each of the
/// types converted with `convertType`.
Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
  assert(!types.empty() && "expected non-empty list of type");
  if (types.size() == 1)
    return convertType(types[0]);

  SmallVector<Type> resultTypes;
  resultTypes.reserve(types.size());
  for (Type type : types) {
    Type converted = convertType(type);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    resultTypes.push_back(converted);
  }

  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}

/// Convert a non-empty list of types to be returned from a function into an
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
Type LLVMTypeConverter::packFunctionResults(TypeRange types,
                                            bool useBarePtrCallConv) const {
  assert(!types.empty() && "expected non-empty list of type");

  useBarePtrCallConv |= options.useBarePtrCallConv;
  if (types.size() == 1)
    return convertCallingConventionType(types.front(), useBarePtrCallConv);

  SmallVector<Type> resultTypes;
  resultTypes.reserve(types.size());
  for (auto t : types) {
    auto converted = convertCallingConventionType(t, useBarePtrCallConv);
    if (!converted || !LLVM::isCompatibleType(converted))
      return {};
    resultTypes.push_back(converted);
  }

  return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}

Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
                                                    OpBuilder &builder) const {
  // Alloca with proper alignment. We do not expect optimizations of this
  // alloca op and so we omit allocating at the entry block.
  auto ptrType = LLVM::LLVMPointerType::get(builder.getContext());
  Value one = builder.create<LLVM::ConstantOp>(loc, builder.getI64Type(),
                                               builder.getIndexAttr(1));
  Value allocated =
      builder.create<LLVM::AllocaOp>(loc, ptrType, operand.getType(), one);
  // Store into the alloca'ed descriptor.
  builder.create<LLVM::StoreOp>(loc, operand, allocated);
  return allocated;
}

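/// Promote the operands of a call-like operation to match the converted
/// function signature: under the default convention, memref descriptors are
/// unpacked into their individual fields; under the bare-pointer convention,
/// only the aligned pointer is extracted from each ranked memref descriptor.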
SmallVector<Value, 4>
LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
                                   ValueRange operands, OpBuilder &builder,
                                   bool useBarePtrCallConv) const {
  SmallVector<Value, 4> promotedOperands;
  promotedOperands.reserve(operands.size());
  useBarePtrCallConv |= options.useBarePtrCallConv;
  for (auto it : llvm::zip(opOperands, operands)) {
    auto operand = std::get<0>(it);
    auto llvmOperand = std::get<1>(it);

    if (useBarePtrCallConv) {
      // For the bare-ptr calling convention, we only have to extract the
      // aligned pointer of a memref.
      if (isa<MemRefType>(operand.getType())) {
        MemRefDescriptor desc(llvmOperand);
        llvmOperand = desc.alignedPtr(builder, loc);
      } else if (isa<UnrankedMemRefType>(operand.getType())) {
        llvm_unreachable("Unranked memrefs are not supported");
      }
    } else {
      if (isa<UnrankedMemRefType>(operand.getType())) {
        UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
                                         promotedOperands);
        continue;
      }
      if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
        MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
                                 promotedOperands);
        continue;
      }
    }

    promotedOperands.push_back(llvmOperand);
  }
  return promotedOperands;
}

/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
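/// For example, a memref<?xf32> argument expands to the five entries
/// (ptr, ptr, i64, i64, i64), and a memref<*xf32> argument expands to
/// (i64, ptr).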
LogicalResult
mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                 SmallVectorImpl<Type> &result) {
  if (auto memref = dyn_cast<MemRefType>(type)) {
    // In signatures, Memref descriptors are expanded into lists of
    // non-aggregate values.
    auto converted =
        converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  if (isa<UnrankedMemRefType>(type)) {
    auto converted = converter.getUnrankedMemRefDescriptorFields();
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  auto converted = converter.convertType(type);
  if (!converted)
    return failure();
  result.push_back(converted);
  return success();
}

/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
                                  SmallVectorImpl<Type> &result) {
  auto llvmTy = converter.convertCallingConventionType(
      type, /*useBarePtrCallConv=*/true);
  if (!llvmTy)
    return failure();

  result.push_back(llvmTy);
  return success();
}