1 | //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
10 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
11 | |
12 | using namespace mlir; |
13 | |
14 | // For >1-D vector types, extracts the necessary information to iterate over all |
15 | // 1-D subvectors in the underlying llrepresentation of the n-D vector |
16 | // Iterates on the llvm array type until we hit a non-array type (which is |
17 | // asserted to be an llvm vector type). |
18 | LLVM::detail::NDVectorTypeInfo |
19 | LLVM::detail::(VectorType vectorType, |
20 | const LLVMTypeConverter &converter) { |
21 | assert(vectorType.getRank() > 1 && "expected >1D vector type" ); |
22 | NDVectorTypeInfo info; |
23 | info.llvmNDVectorTy = converter.convertType(vectorType); |
24 | if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(type: info.llvmNDVectorTy)) { |
25 | info.llvmNDVectorTy = nullptr; |
26 | return info; |
27 | } |
28 | info.arraySizes.reserve(N: vectorType.getRank() - 1); |
29 | auto llvmTy = info.llvmNDVectorTy; |
30 | while (isa<LLVM::LLVMArrayType>(llvmTy)) { |
31 | info.arraySizes.push_back( |
32 | cast<LLVM::LLVMArrayType>(llvmTy).getNumElements()); |
33 | llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType(); |
34 | } |
35 | if (!LLVM::isCompatibleVectorType(type: llvmTy)) |
36 | return info; |
37 | info.llvm1DVectorTy = llvmTy; |
38 | return info; |
39 | } |
40 | |
41 | // Express `linearIndex` in terms of coordinates of `basis`. |
42 | // Returns the empty vector when linearIndex is out of the range [0, P] where |
43 | // P is the product of all the basis coordinates. |
44 | // |
45 | // Prerequisites: |
46 | // Basis is an array of nonnegative integers (signed type inherited from |
47 | // vector shape type). |
48 | SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis, |
49 | unsigned linearIndex) { |
50 | SmallVector<int64_t, 4> res; |
51 | res.reserve(N: basis.size()); |
52 | for (unsigned basisElement : llvm::reverse(C&: basis)) { |
53 | res.push_back(Elt: linearIndex % basisElement); |
54 | linearIndex = linearIndex / basisElement; |
55 | } |
56 | if (linearIndex > 0) |
57 | return {}; |
58 | std::reverse(first: res.begin(), last: res.end()); |
59 | return res; |
60 | } |
61 | |
62 | // Iterate of linear index, convert to coords space and insert splatted 1-D |
63 | // vector in each position. |
64 | void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, |
65 | OpBuilder &builder, |
66 | function_ref<void(ArrayRef<int64_t>)> fun) { |
67 | unsigned ub = 1; |
68 | for (auto s : info.arraySizes) |
69 | ub *= s; |
70 | for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { |
71 | auto coords = getCoordinates(basis: info.arraySizes, linearIndex); |
72 | // Linear index is out of bounds, we are done. |
73 | if (coords.empty()) |
74 | break; |
75 | assert(coords.size() == info.arraySizes.size()); |
76 | fun(coords); |
77 | } |
78 | } |
79 | |
80 | LogicalResult LLVM::detail::handleMultidimensionalVectors( |
81 | Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, |
82 | std::function<Value(Type, ValueRange)> createOperand, |
83 | ConversionPatternRewriter &rewriter) { |
84 | auto resultNDVectorType = cast<VectorType>(op->getResult(idx: 0).getType()); |
85 | auto resultTypeInfo = |
86 | extractNDVectorTypeInfo(resultNDVectorType, typeConverter); |
87 | auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; |
88 | auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; |
89 | auto loc = op->getLoc(); |
90 | Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy); |
91 | nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { |
92 | // For this unrolled `position` corresponding to the `linearIndex`^th |
93 | // element, extract operand vectors |
94 | SmallVector<Value, 4> extractedOperands; |
95 | for (const auto &operand : llvm::enumerate(First&: operands)) { |
96 | extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( |
97 | loc, operand.value(), position)); |
98 | } |
99 | Value newVal = createOperand(result1DVectorTy, extractedOperands); |
100 | desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position); |
101 | }); |
102 | rewriter.replaceOp(op, newValues: desc); |
103 | return success(); |
104 | } |
105 | |
106 | LogicalResult LLVM::detail::vectorOneToOneRewrite( |
107 | Operation *op, StringRef targetOp, ValueRange operands, |
108 | ArrayRef<NamedAttribute> targetAttrs, |
109 | const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
110 | IntegerOverflowFlags overflowFlags) { |
111 | assert(!operands.empty()); |
112 | |
113 | // Cannot convert ops if their operands are not of LLVM type. |
114 | if (!llvm::all_of(Range: operands.getTypes(), P: isCompatibleType)) |
115 | return failure(); |
116 | |
117 | auto llvmNDVectorTy = operands[0].getType(); |
118 | if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) |
119 | return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, |
120 | rewriter, overflowFlags); |
121 | |
122 | auto callback = [op, targetOp, targetAttrs, overflowFlags, |
123 | &rewriter](Type llvm1DVectorTy, ValueRange operands) { |
124 | Operation *newOp = |
125 | rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), |
126 | operands, llvm1DVectorTy, targetAttrs); |
127 | LLVM::detail::setNativeProperties(newOp, overflowFlags); |
128 | return newOp->getResult(0); |
129 | }; |
130 | |
131 | return handleMultidimensionalVectors(op, operands, typeConverter, callback, |
132 | rewriter); |
133 | } |
134 | |