1 | //===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the X86Vector dialect and its operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/X86Vector/X86VectorDialect.h" |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
16 | #include "mlir/IR/Builders.h" |
17 | #include "mlir/IR/OpImplementation.h" |
18 | #include "mlir/IR/TypeUtilities.h" |
19 | #include "mlir/Interfaces/InferTypeOpInterface.h" |
20 | |
21 | using namespace mlir; |
22 | |
23 | #include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc" |
24 | |
25 | #include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc" |
26 | |
27 | void x86vector::X86VectorDialect::initialize() { |
28 | addOperations< |
29 | #define GET_OP_LIST |
30 | #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc" |
31 | >(); |
32 | } |
33 | |
34 | static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, |
35 | const LLVMTypeConverter &typeConverter, |
36 | RewriterBase &rewriter) { |
37 | MemRefDescriptor memRefDescriptor(buffer); |
38 | return memRefDescriptor.bufferPtr(builder&: rewriter, loc, converter: typeConverter, type: type); |
39 | } |
40 | |
41 | LogicalResult x86vector::MaskCompressOp::verify() { |
42 | if (getSrc() && getConstantSrc()) |
43 | return emitError("cannot use both src and constant_src" ); |
44 | |
45 | if (getSrc() && (getSrc().getType() != getDst().getType())) |
46 | return emitError("failed to verify that src and dst have same type" ); |
47 | |
48 | if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType())) |
49 | return emitError( |
50 | "failed to verify that constant_src and dst have same type" ); |
51 | |
52 | return success(); |
53 | } |
54 | |
55 | SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands( |
56 | ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, |
57 | RewriterBase &rewriter) { |
58 | auto loc = getLoc(); |
59 | Adaptor adaptor(operands, *this); |
60 | |
61 | auto opType = adaptor.getA().getType(); |
62 | Value src; |
63 | if (adaptor.getSrc()) { |
64 | src = adaptor.getSrc(); |
65 | } else if (adaptor.getConstantSrc()) { |
66 | src = rewriter.create<LLVM::ConstantOp>(loc, opType, |
67 | adaptor.getConstantSrcAttr()); |
68 | } else { |
69 | auto zeroAttr = rewriter.getZeroAttr(opType); |
70 | src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr); |
71 | } |
72 | |
73 | return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()}; |
74 | } |
75 | |
76 | SmallVector<Value> |
77 | x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands, |
78 | const LLVMTypeConverter &typeConverter, |
79 | RewriterBase &rewriter) { |
80 | SmallVector<Value> intrinsicOperands(operands); |
81 | // Dot product of all elements, broadcasted to all elements. |
82 | Value scale = |
83 | rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff); |
84 | intrinsicOperands.push_back(scale); |
85 | |
86 | return intrinsicOperands; |
87 | } |
88 | |
89 | SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands( |
90 | ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, |
91 | RewriterBase &rewriter) { |
92 | Adaptor adaptor(operands, *this); |
93 | return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(), |
94 | typeConverter, rewriter)}; |
95 | } |
96 | |
97 | SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands( |
98 | ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, |
99 | RewriterBase &rewriter) { |
100 | Adaptor adaptor(operands, *this); |
101 | return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(), |
102 | typeConverter, rewriter)}; |
103 | } |
104 | |
105 | SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands( |
106 | ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter, |
107 | RewriterBase &rewriter) { |
108 | Adaptor adaptor(operands, *this); |
109 | return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(), |
110 | typeConverter, rewriter)}; |
111 | } |
112 | |
113 | #define GET_OP_CLASSES |
114 | #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc" |
115 | |