1//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
10
11#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/AMX/AMXDialect.h"
14#include "mlir/Dialect/AMX/Transforms.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17#include "mlir/Dialect/ArmNeon/Transforms.h"
18#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
19#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Tensor/IR/Tensor.h"
23#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
24#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
25#include "mlir/Dialect/X86Vector/Transforms.h"
26#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29
30namespace mlir {
31#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
32#include "mlir/Conversion/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36using namespace mlir::vector;
37
38namespace {
39struct ConvertVectorToLLVMPass
40 : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
41
42 using Base::Base;
43
44 // Override explicitly to allow conditional dialect dependence.
45 void getDependentDialects(DialectRegistry &registry) const override {
46 registry.insert<LLVM::LLVMDialect>();
47 registry.insert<arith::ArithDialect>();
48 registry.insert<memref::MemRefDialect>();
49 registry.insert<tensor::TensorDialect>();
50 if (armNeon)
51 registry.insert<arm_neon::ArmNeonDialect>();
52 if (armSVE)
53 registry.insert<arm_sve::ArmSVEDialect>();
54 if (amx)
55 registry.insert<amx::AMXDialect>();
56 if (x86Vector)
57 registry.insert<x86vector::X86VectorDialect>();
58 }
59 void runOnOperation() override;
60};
61} // namespace
62
63void ConvertVectorToLLVMPass::runOnOperation() {
64 // Perform progressive lowering of operations on slices and all contraction
65 // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
66 // applies folding and DCE.
67 {
68 RewritePatternSet patterns(&getContext());
69 populateVectorToVectorCanonicalizationPatterns(patterns);
70 populateVectorBitCastLoweringPatterns(patterns);
71 populateVectorBroadcastLoweringPatterns(patterns);
72 populateVectorContractLoweringPatterns(patterns, vectorContractLoweringOption: vectorContractLowering);
73 populateVectorMaskOpLoweringPatterns(patterns);
74 populateVectorShapeCastLoweringPatterns(patterns);
75 populateVectorInterleaveLoweringPatterns(patterns);
76 populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
77 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
78 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
79 populateVectorMaskMaterializationPatterns(patterns,
80 force32BitVectorIndices);
81 populateVectorInsertExtractStridedSliceTransforms(patterns);
82 populateVectorStepLoweringPatterns(patterns);
83 populateVectorRankReducingFMAPattern(patterns);
84 populateVectorGatherLoweringPatterns(patterns);
85 if (armI8MM) {
86 if (armNeon)
87 arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
88 if (armSVE)
89 populateLowerContractionToSVEI8MMPatternPatterns(patterns);
90 }
91 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
92 }
93
94 // Convert to the LLVM IR dialect.
95 LowerToLLVMOptions options(&getContext());
96 LLVMTypeConverter converter(&getContext(), options);
97 RewritePatternSet patterns(&getContext());
98 populateVectorTransferLoweringPatterns(patterns);
99 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
100 populateVectorToLLVMConversionPatterns(
101 converter, patterns, reassociateFPReductions, force32BitVectorIndices,
102 useVectorAlignment);
103 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
104
105 // Architecture specific augmentations.
106 LLVMConversionTarget target(getContext());
107 target.addLegalDialect<arith::ArithDialect>();
108 target.addLegalDialect<memref::MemRefDialect>();
109 target.addLegalOp<UnrealizedConversionCastOp>();
110
111 if (armNeon) {
112 // TODO: we may or may not want to include in-dialect lowering to
113 // LLVM-compatible operations here. So far, all operations in the dialect
114 // can be translated to LLVM IR so there is no conversion necessary.
115 target.addLegalDialect<arm_neon::ArmNeonDialect>();
116 }
117 if (armSVE) {
118 configureArmSVELegalizeForExportTarget(target);
119 populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
120 }
121 if (amx) {
122 configureAMXLegalizeForExportTarget(target);
123 populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
124 }
125 if (x86Vector) {
126 configureX86VectorLegalizeForExportTarget(target);
127 populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
128 }
129
130 if (failed(
131 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
132 signalPassFailure();
133}
134

// source code of mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp