//===- ConvertVectorToLLVMPass.cpp - Conversion from Vector to LLVM -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h"
10
11#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/AMX/AMXDialect.h"
14#include "mlir/Dialect/AMX/Transforms.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
17#include "mlir/Dialect/ArmNeon/Transforms.h"
18#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
19#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/Tensor/IR/Tensor.h"
24#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
25#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
26#include "mlir/Dialect/X86Vector/Transforms.h"
27#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
28#include "mlir/Pass/Pass.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37using namespace mlir::vector;
38
39namespace {
40struct ConvertVectorToLLVMPass
41 : public impl::ConvertVectorToLLVMPassBase<ConvertVectorToLLVMPass> {
42
43 using Base::Base;
44
45 // Override explicitly to allow conditional dialect dependence.
46 void getDependentDialects(DialectRegistry &registry) const override {
47 registry.insert<LLVM::LLVMDialect>();
48 registry.insert<arith::ArithDialect>();
49 registry.insert<memref::MemRefDialect>();
50 registry.insert<tensor::TensorDialect>();
51 if (armNeon)
52 registry.insert<arm_neon::ArmNeonDialect>();
53 if (armSVE)
54 registry.insert<arm_sve::ArmSVEDialect>();
55 if (amx)
56 registry.insert<amx::AMXDialect>();
57 if (x86Vector)
58 registry.insert<x86vector::X86VectorDialect>();
59 }
60 void runOnOperation() override;
61};
62} // namespace
63
64void ConvertVectorToLLVMPass::runOnOperation() {
65 // Perform progressive lowering of operations on slices and all contraction
66 // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
67 // applies folding and DCE.
68 {
69 RewritePatternSet patterns(&getContext());
70 populateVectorToVectorCanonicalizationPatterns(patterns);
71 populateVectorBitCastLoweringPatterns(patterns);
72 populateVectorBroadcastLoweringPatterns(patterns);
73 populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
74 populateVectorMaskOpLoweringPatterns(patterns);
75 populateVectorShapeCastLoweringPatterns(patterns);
76 populateVectorInterleaveLoweringPatterns(patterns);
77 populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
78 // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
79 populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
80 populateVectorMaskMaterializationPatterns(patterns,
81 force32BitVectorIndices);
82 populateVectorInsertExtractStridedSliceTransforms(patterns);
83 populateVectorStepLoweringPatterns(patterns);
84 populateVectorRankReducingFMAPattern(patterns);
85 populateVectorGatherLoweringPatterns(patterns);
86 if (armI8MM) {
87 if (armNeon)
88 arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
89 if (armSVE)
90 populateLowerContractionToSVEI8MMPatternPatterns(patterns);
91 }
92 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
93 }
94
95 // Convert to the LLVM IR dialect.
96 LowerToLLVMOptions options(&getContext());
97 LLVMTypeConverter converter(&getContext(), options);
98 RewritePatternSet patterns(&getContext());
99 populateVectorTransferLoweringPatterns(patterns);
100 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
101 populateVectorToLLVMConversionPatterns(
102 converter, patterns, reassociateFPReductions, force32BitVectorIndices,
103 useVectorAlignment);
104 populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
105
106 // Architecture specific augmentations.
107 LLVMConversionTarget target(getContext());
108 target.addLegalDialect<arith::ArithDialect>();
109 target.addLegalDialect<memref::MemRefDialect>();
110 target.addLegalOp<UnrealizedConversionCastOp>();
111
112 if (armNeon) {
113 // TODO: we may or may not want to include in-dialect lowering to
114 // LLVM-compatible operations here. So far, all operations in the dialect
115 // can be translated to LLVM IR so there is no conversion necessary.
116 target.addLegalDialect<arm_neon::ArmNeonDialect>();
117 }
118 if (armSVE) {
119 configureArmSVELegalizeForExportTarget(target);
120 populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
121 }
122 if (amx) {
123 configureAMXLegalizeForExportTarget(target);
124 populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
125 }
126 if (x86Vector) {
127 configureX86VectorLegalizeForExportTarget(target);
128 populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
129 }
130
131 if (failed(
132 applyPartialConversion(getOperation(), target, std::move(patterns))))
133 signalPassFailure();
134}
135

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp