| 1 | //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
| 10 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | |
| 14 | // For >1-D vector types, extracts the necessary information to iterate over all |
| 15 | // 1-D subvectors in the underlying llrepresentation of the n-D vector |
| 16 | // Iterates on the llvm array type until we hit a non-array type (which is |
| 17 | // asserted to be an llvm vector type). |
| 18 | LLVM::detail::NDVectorTypeInfo |
| 19 | LLVM::detail::(VectorType vectorType, |
| 20 | const LLVMTypeConverter &converter) { |
| 21 | assert(vectorType.getRank() > 1 && "expected >1D vector type" ); |
| 22 | NDVectorTypeInfo info; |
| 23 | info.llvmNDVectorTy = converter.convertType(vectorType); |
| 24 | if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(type: info.llvmNDVectorTy)) { |
| 25 | info.llvmNDVectorTy = nullptr; |
| 26 | return info; |
| 27 | } |
| 28 | info.arraySizes.reserve(N: vectorType.getRank() - 1); |
| 29 | auto llvmTy = info.llvmNDVectorTy; |
| 30 | while (isa<LLVM::LLVMArrayType>(llvmTy)) { |
| 31 | info.arraySizes.push_back( |
| 32 | cast<LLVM::LLVMArrayType>(llvmTy).getNumElements()); |
| 33 | llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType(); |
| 34 | } |
| 35 | if (!LLVM::isCompatibleVectorType(type: llvmTy)) |
| 36 | return info; |
| 37 | info.llvm1DVectorTy = llvmTy; |
| 38 | return info; |
| 39 | } |
| 40 | |
| 41 | // Express `linearIndex` in terms of coordinates of `basis`. |
| 42 | // Returns the empty vector when linearIndex is out of the range [0, P] where |
| 43 | // P is the product of all the basis coordinates. |
| 44 | // |
| 45 | // Prerequisites: |
| 46 | // Basis is an array of nonnegative integers (signed type inherited from |
| 47 | // vector shape type). |
| 48 | SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis, |
| 49 | unsigned linearIndex) { |
| 50 | SmallVector<int64_t, 4> res; |
| 51 | res.reserve(N: basis.size()); |
| 52 | for (unsigned basisElement : llvm::reverse(C&: basis)) { |
| 53 | res.push_back(Elt: linearIndex % basisElement); |
| 54 | linearIndex = linearIndex / basisElement; |
| 55 | } |
| 56 | if (linearIndex > 0) |
| 57 | return {}; |
| 58 | std::reverse(first: res.begin(), last: res.end()); |
| 59 | return res; |
| 60 | } |
| 61 | |
| 62 | // Iterate of linear index, convert to coords space and insert splatted 1-D |
| 63 | // vector in each position. |
| 64 | void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, |
| 65 | OpBuilder &builder, |
| 66 | function_ref<void(ArrayRef<int64_t>)> fun) { |
| 67 | unsigned ub = 1; |
| 68 | for (auto s : info.arraySizes) |
| 69 | ub *= s; |
| 70 | for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { |
| 71 | auto coords = getCoordinates(basis: info.arraySizes, linearIndex); |
| 72 | // Linear index is out of bounds, we are done. |
| 73 | if (coords.empty()) |
| 74 | break; |
| 75 | assert(coords.size() == info.arraySizes.size()); |
| 76 | fun(coords); |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | LogicalResult LLVM::detail::handleMultidimensionalVectors( |
| 81 | Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, |
| 82 | std::function<Value(Type, ValueRange)> createOperand, |
| 83 | ConversionPatternRewriter &rewriter) { |
| 84 | auto resultNDVectorType = cast<VectorType>(op->getResult(idx: 0).getType()); |
| 85 | auto resultTypeInfo = |
| 86 | extractNDVectorTypeInfo(resultNDVectorType, typeConverter); |
| 87 | auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; |
| 88 | auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; |
| 89 | auto loc = op->getLoc(); |
| 90 | Value desc = rewriter.create<LLVM::PoisonOp>(loc, resultNDVectoryTy); |
| 91 | nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { |
| 92 | // For this unrolled `position` corresponding to the `linearIndex`^th |
| 93 | // element, extract operand vectors |
| 94 | SmallVector<Value, 4> extractedOperands; |
| 95 | for (const auto &operand : llvm::enumerate(First&: operands)) { |
| 96 | extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( |
| 97 | loc, operand.value(), position)); |
| 98 | } |
| 99 | Value newVal = createOperand(result1DVectorTy, extractedOperands); |
| 100 | desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position); |
| 101 | }); |
| 102 | rewriter.replaceOp(op, newValues: desc); |
| 103 | return success(); |
| 104 | } |
| 105 | |
| 106 | LogicalResult LLVM::detail::vectorOneToOneRewrite( |
| 107 | Operation *op, StringRef targetOp, ValueRange operands, |
| 108 | ArrayRef<NamedAttribute> targetAttrs, |
| 109 | const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, |
| 110 | IntegerOverflowFlags overflowFlags) { |
| 111 | assert(!operands.empty()); |
| 112 | |
| 113 | // Cannot convert ops if their operands are not of LLVM type. |
| 114 | if (!llvm::all_of(Range: operands.getTypes(), P: isCompatibleType)) |
| 115 | return failure(); |
| 116 | |
| 117 | auto llvmNDVectorTy = operands[0].getType(); |
| 118 | if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy)) |
| 119 | return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, |
| 120 | rewriter, overflowFlags); |
| 121 | |
| 122 | auto callback = [op, targetOp, targetAttrs, overflowFlags, |
| 123 | &rewriter](Type llvm1DVectorTy, ValueRange operands) { |
| 124 | Operation *newOp = |
| 125 | rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), |
| 126 | operands, llvm1DVectorTy, targetAttrs); |
| 127 | LLVM::detail::setNativeProperties(newOp, overflowFlags); |
| 128 | return newOp->getResult(0); |
| 129 | }; |
| 130 | |
| 131 | return handleMultidimensionalVectors(op, operands, typeConverter, callback, |
| 132 | rewriter); |
| 133 | } |
| 134 | |