//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/X86Vector/Transforms.h"
10
11#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12#include "mlir/Conversion/LLVMCommon/Pattern.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
16#include "mlir/IR/BuiltinOps.h"
17#include "mlir/IR/PatternMatch.h"
18
19using namespace mlir;
20using namespace mlir::x86vector;
21
22/// Extracts the "main" vector element type from the given X86Vector operation.
23template <typename OpTy>
24static Type getSrcVectorElementType(OpTy op) {
25 return cast<VectorType>(op.getSrc().getType()).getElementType();
26}
27template <>
28Type getSrcVectorElementType(Vp2IntersectOp op) {
29 return cast<VectorType>(op.getA().getType()).getElementType();
30}
31
32namespace {
33
34/// Base conversion for AVX512 ops that can be lowered to one of the two
35/// intrinsics based on the bitwidth of their "main" vector element type. This
36/// relies on the to-LLVM-dialect conversion helpers to correctly pack the
37/// results of multi-result intrinsic ops.
38template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
39struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
40 explicit LowerToIntrinsic(LLVMTypeConverter &converter)
41 : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
42
43 const LLVMTypeConverter &getTypeConverter() const {
44 return *static_cast<const LLVMTypeConverter *>(
45 OpConversionPattern<OpTy>::getTypeConverter());
46 }
47
48 LogicalResult
49 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
50 ConversionPatternRewriter &rewriter) const override {
51 Type elementType = getSrcVectorElementType<OpTy>(op);
52 unsigned bitwidth = elementType.getIntOrFloatBitWidth();
53 if (bitwidth == 32)
54 return LLVM::detail::oneToOneRewrite(
55 op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
56 op->getAttrs(), getTypeConverter(), rewriter);
57 if (bitwidth == 64)
58 return LLVM::detail::oneToOneRewrite(
59 op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
60 op->getAttrs(), getTypeConverter(), rewriter);
61 return rewriter.notifyMatchFailure(
62 op, "expected 'src' to be either f32 or f64");
63 }
64};
65
66struct MaskCompressOpConversion
67 : public ConvertOpToLLVMPattern<MaskCompressOp> {
68 using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
69
70 LogicalResult
71 matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73 auto opType = adaptor.getA().getType();
74
75 Value src;
76 if (op.getSrc()) {
77 src = adaptor.getSrc();
78 } else if (op.getConstantSrc()) {
79 src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
80 op.getConstantSrcAttr());
81 } else {
82 auto zeroAttr = rewriter.getZeroAttr(type: opType);
83 src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
84 }
85
86 rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
87 src, adaptor.getK());
88
89 return success();
90 }
91};
92
93struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
94 using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;
95
96 LogicalResult
97 matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 auto opType = adaptor.getA().getType();
100 rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
101 return success();
102 }
103};
104
105struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
106 using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;
107
108 LogicalResult
109 matchAndRewrite(DotOp op, OpAdaptor adaptor,
110 ConversionPatternRewriter &rewriter) const override {
111 auto opType = adaptor.getA().getType();
112 Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
113 // Dot product of all elements, broadcasted to all elements.
114 auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
115 Value scale =
116 rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
117 rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
118 adaptor.getB(), scale);
119 return success();
120 }
121};
122
123/// An entry associating the "main" AVX512 op with its instantiations for
124/// vectors of 32-bit and 64-bit elements.
125template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
126struct RegEntry {
127 using MainOp = OpTy;
128 using Intr32Op = Intr32OpTy;
129 using Intr64Op = Intr64OpTy;
130};
131
132/// A container for op association entries facilitating the configuration of
133/// dialect conversion.
134template <typename... Args>
135struct RegistryImpl {
136 /// Registers the patterns specializing the "main" op to one of the
137 /// "intrinsic" ops depending on elemental type.
138 static void registerPatterns(LLVMTypeConverter &converter,
139 RewritePatternSet &patterns) {
140 patterns
141 .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
142 typename Args::Intr64Op>...>(converter);
143 }
144
145 /// Configures the conversion target to lower out "main" ops.
146 static void configureTarget(LLVMConversionTarget &target) {
147 target.addIllegalOp<typename Args::MainOp...>();
148 target.addLegalOp<typename Args::Intr32Op...>();
149 target.addLegalOp<typename Args::Intr64Op...>();
150 }
151};
152
153using Registry = RegistryImpl<
154 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
155 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
156 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
157
158} // namespace
159
160/// Populate the given list with patterns that convert from X86Vector to LLVM.
161void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
162 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
163 Registry::registerPatterns(converter, patterns);
164 patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
165 arg&: converter);
166}
167
168void mlir::configureX86VectorLegalizeForExportTarget(
169 LLVMConversionTarget &target) {
170 Registry::configureTarget(target);
171 target.addLegalOp<MaskCompressIntrOp>();
172 target.addIllegalOp<MaskCompressOp>();
173 target.addLegalOp<RsqrtIntrOp>();
174 target.addIllegalOp<RsqrtOp>();
175 target.addLegalOp<DotIntrOp>();
176 target.addIllegalOp<DotOp>();
177}
178

// Source file: mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp