1//===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the X86Vector dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/TypeUtilities.h"
17
18using namespace mlir;
19
20#include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
21
22#include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
23
/// Registers every X86Vector operation with the dialect.
///
/// The template argument list of `addOperations` is expanded from the
/// generated X86Vector.cpp.inc with GET_OP_LIST defined, so the set of
/// registered ops always matches the op definitions.
void x86vector::X86VectorDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
      >();
}
30
31static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
32 const LLVMTypeConverter &typeConverter,
33 RewriterBase &rewriter) {
34 MemRefDescriptor memRefDescriptor(buffer);
35 return memRefDescriptor.bufferPtr(builder&: rewriter, loc, converter: typeConverter, type);
36}
37
38LogicalResult x86vector::MaskCompressOp::verify() {
39 if (getSrc() && getConstantSrc())
40 return emitError(message: "cannot use both src and constant_src");
41
42 if (getSrc() && (getSrc().getType() != getDst().getType()))
43 return emitError(message: "failed to verify that src and dst have same type");
44
45 if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType()))
46 return emitError(
47 message: "failed to verify that constant_src and dst have same type");
48
49 return success();
50}
51
52SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
53 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
54 RewriterBase &rewriter) {
55 auto loc = getLoc();
56 Adaptor adaptor(operands, *this);
57
58 auto opType = adaptor.getA().getType();
59 Value src;
60 if (adaptor.getSrc()) {
61 src = adaptor.getSrc();
62 } else if (adaptor.getConstantSrc()) {
63 src = rewriter.create<LLVM::ConstantOp>(location: loc, args&: opType,
64 args: adaptor.getConstantSrcAttr());
65 } else {
66 auto zeroAttr = rewriter.getZeroAttr(type: opType);
67 src = rewriter.create<LLVM::ConstantOp>(location: loc, args&: opType, args&: zeroAttr);
68 }
69
70 return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
71}
72
73SmallVector<Value>
74x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
75 const LLVMTypeConverter &typeConverter,
76 RewriterBase &rewriter) {
77 SmallVector<Value> intrinsicOperands(operands);
78 // Dot product of all elements, broadcasted to all elements.
79 Value scale =
80 rewriter.create<LLVM::ConstantOp>(location: getLoc(), args: rewriter.getI8Type(), args: 0xff);
81 intrinsicOperands.push_back(Elt: scale);
82
83 return intrinsicOperands;
84}
85
86SmallVector<Value> x86vector::DotInt8Op::getIntrinsicOperands(
87 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
88 RewriterBase &rewriter) {
89 SmallVector<Value> intrinsicOprnds;
90 Adaptor adaptor(operands, *this);
91 intrinsicOprnds.push_back(Elt: adaptor.getW());
92 // Bitcast `a` and `b` to i32
93 Value bitcast_a = rewriter.create<LLVM::BitcastOp>(
94 location: getLoc(),
95 args: VectorType::get(shape: (getA().getType().getShape()[0] / 4),
96 elementType: rewriter.getIntegerType(width: 32)),
97 args: adaptor.getA());
98 intrinsicOprnds.push_back(Elt: bitcast_a);
99 Value bitcast_b = rewriter.create<LLVM::BitcastOp>(
100 location: getLoc(),
101 args: VectorType::get(shape: (getB().getType().getShape()[0] / 4),
102 elementType: rewriter.getIntegerType(width: 32)),
103 args: adaptor.getB());
104 intrinsicOprnds.push_back(Elt: bitcast_b);
105
106 return intrinsicOprnds;
107}
108
109SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
110 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
111 RewriterBase &rewriter) {
112 Adaptor adaptor(operands, *this);
113 return {getMemrefBuffPtr(loc: getLoc(), type: getA().getType(), buffer: adaptor.getA(),
114 typeConverter, rewriter)};
115}
116
117SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
118 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
119 RewriterBase &rewriter) {
120 Adaptor adaptor(operands, *this);
121 return {getMemrefBuffPtr(loc: getLoc(), type: getA().getType(), buffer: adaptor.getA(),
122 typeConverter, rewriter)};
123}
124
125SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
126 ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
127 RewriterBase &rewriter) {
128 Adaptor adaptor(operands, *this);
129 return {getMemrefBuffPtr(loc: getLoc(), type: getA().getType(), buffer: adaptor.getA(),
130 typeConverter, rewriter)};
131}
132
133#define GET_OP_CLASSES
134#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
135

source code of mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp