1 | //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/X86Vector/Transforms.h" |
10 | |
11 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
12 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "mlir/Dialect/X86Vector/X86VectorDialect.h" |
16 | #include "mlir/IR/BuiltinOps.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::x86vector; |
21 | |
22 | /// Extracts the "main" vector element type from the given X86Vector operation. |
23 | template <typename OpTy> |
24 | static Type getSrcVectorElementType(OpTy op) { |
25 | return cast<VectorType>(op.getSrc().getType()).getElementType(); |
26 | } |
27 | template <> |
28 | Type getSrcVectorElementType(Vp2IntersectOp op) { |
29 | return cast<VectorType>(op.getA().getType()).getElementType(); |
30 | } |
31 | |
32 | namespace { |
33 | |
34 | /// Base conversion for AVX512 ops that can be lowered to one of the two |
35 | /// intrinsics based on the bitwidth of their "main" vector element type. This |
36 | /// relies on the to-LLVM-dialect conversion helpers to correctly pack the |
37 | /// results of multi-result intrinsic ops. |
38 | template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy> |
39 | struct LowerToIntrinsic : public OpConversionPattern<OpTy> { |
40 | explicit LowerToIntrinsic(LLVMTypeConverter &converter) |
41 | : OpConversionPattern<OpTy>(converter, &converter.getContext()) {} |
42 | |
43 | const LLVMTypeConverter &getTypeConverter() const { |
44 | return *static_cast<const LLVMTypeConverter *>( |
45 | OpConversionPattern<OpTy>::getTypeConverter()); |
46 | } |
47 | |
48 | LogicalResult |
49 | matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, |
50 | ConversionPatternRewriter &rewriter) const override { |
51 | Type elementType = getSrcVectorElementType<OpTy>(op); |
52 | unsigned bitwidth = elementType.getIntOrFloatBitWidth(); |
53 | if (bitwidth == 32) |
54 | return LLVM::detail::oneToOneRewrite( |
55 | op, Intr32OpTy::getOperationName(), adaptor.getOperands(), |
56 | op->getAttrs(), getTypeConverter(), rewriter); |
57 | if (bitwidth == 64) |
58 | return LLVM::detail::oneToOneRewrite( |
59 | op, Intr64OpTy::getOperationName(), adaptor.getOperands(), |
60 | op->getAttrs(), getTypeConverter(), rewriter); |
61 | return rewriter.notifyMatchFailure( |
62 | op, "expected 'src' to be either f32 or f64" ); |
63 | } |
64 | }; |
65 | |
66 | struct MaskCompressOpConversion |
67 | : public ConvertOpToLLVMPattern<MaskCompressOp> { |
68 | using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern; |
69 | |
70 | LogicalResult |
71 | matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, |
72 | ConversionPatternRewriter &rewriter) const override { |
73 | auto opType = adaptor.getA().getType(); |
74 | |
75 | Value src; |
76 | if (op.getSrc()) { |
77 | src = adaptor.getSrc(); |
78 | } else if (op.getConstantSrc()) { |
79 | src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType, |
80 | op.getConstantSrcAttr()); |
81 | } else { |
82 | auto zeroAttr = rewriter.getZeroAttr(type: opType); |
83 | src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr); |
84 | } |
85 | |
86 | rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(), |
87 | src, adaptor.getK()); |
88 | |
89 | return success(); |
90 | } |
91 | }; |
92 | |
93 | struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> { |
94 | using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern; |
95 | |
96 | LogicalResult |
97 | matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, |
98 | ConversionPatternRewriter &rewriter) const override { |
99 | auto opType = adaptor.getA().getType(); |
100 | rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA()); |
101 | return success(); |
102 | } |
103 | }; |
104 | |
105 | struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> { |
106 | using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern; |
107 | |
108 | LogicalResult |
109 | matchAndRewrite(DotOp op, OpAdaptor adaptor, |
110 | ConversionPatternRewriter &rewriter) const override { |
111 | auto opType = adaptor.getA().getType(); |
112 | Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); |
113 | // Dot product of all elements, broadcasted to all elements. |
114 | auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff)); |
115 | Value scale = |
116 | rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr); |
117 | rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(), |
118 | adaptor.getB(), scale); |
119 | return success(); |
120 | } |
121 | }; |
122 | |
/// Associates a "main" AVX512 op with the two intrinsic ops it may be lowered
/// to:
/// - `Intr32Op` for vectors of 32-bit elements.
/// - `Intr64Op` for vectors of 64-bit elements.
template <class MainOpT, class Intr32OpT, class Intr64OpT>
struct RegEntry {
  using Intr64Op = Intr64OpT;
  using Intr32Op = Intr32OpT;
  using MainOp = MainOpT;
};
131 | |
132 | /// A container for op association entries facilitating the configuration of |
133 | /// dialect conversion. |
134 | template <typename... Args> |
135 | struct RegistryImpl { |
136 | /// Registers the patterns specializing the "main" op to one of the |
137 | /// "intrinsic" ops depending on elemental type. |
138 | static void registerPatterns(LLVMTypeConverter &converter, |
139 | RewritePatternSet &patterns) { |
140 | patterns |
141 | .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op, |
142 | typename Args::Intr64Op>...>(converter); |
143 | } |
144 | |
145 | /// Configures the conversion target to lower out "main" ops. |
146 | static void configureTarget(LLVMConversionTarget &target) { |
147 | target.addIllegalOp<typename Args::MainOp...>(); |
148 | target.addLegalOp<typename Args::Intr32Op...>(); |
149 | target.addLegalOp<typename Args::Intr64Op...>(); |
150 | } |
151 | }; |
152 | |
153 | using Registry = RegistryImpl< |
154 | RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>, |
155 | RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>, |
156 | RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>; |
157 | |
158 | } // namespace |
159 | |
160 | /// Populate the given list with patterns that convert from X86Vector to LLVM. |
161 | void mlir::populateX86VectorLegalizeForLLVMExportPatterns( |
162 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
163 | Registry::registerPatterns(converter, patterns); |
164 | patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>( |
165 | arg&: converter); |
166 | } |
167 | |
168 | void mlir::configureX86VectorLegalizeForExportTarget( |
169 | LLVMConversionTarget &target) { |
170 | Registry::configureTarget(target); |
171 | target.addLegalOp<MaskCompressIntrOp>(); |
172 | target.addIllegalOp<MaskCompressOp>(); |
173 | target.addLegalOp<RsqrtIntrOp>(); |
174 | target.addIllegalOp<RsqrtOp>(); |
175 | target.addLegalOp<DotIntrOp>(); |
176 | target.addIllegalOp<DotOp>(); |
177 | } |
178 | |