1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/DialectImplementation.h"
12#include "llvm/ADT/TypeSwitch.h"
13
14namespace mlir {
15namespace xegpu {
16
/// Registers all XeGPU types, operations, and attributes with this dialect.
///
/// The class lists are not written by hand: defining GET_TYPEDEF_LIST,
/// GET_OP_LIST, or GET_ATTRDEF_LIST before including the corresponding
/// TableGen-generated .cpp.inc file makes that include expand to the
/// comma-separated list of generated class names, which is used directly as
/// the template argument list of addTypes / addOperations / addAttributes.
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
31
32//===----------------------------------------------------------------------===//
33// XeGPU_TensorDescAttr
34//===----------------------------------------------------------------------===//
35TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
36 xegpu::MemoryScope memory_scope,
37 int array_length, bool boundary_check,
38 bool scattered) {
39 auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
40 auto lengthAttr =
41 IntegerAttr::get(IntegerType::get(context, 64), array_length);
42 auto boundaryAttr = BoolAttr::get(context, boundary_check);
43 auto scatteredAttr = BoolAttr::get(context, scattered);
44 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr);
45}
46
47//===----------------------------------------------------------------------===//
48// XeGPU_TensorDescType
49//===----------------------------------------------------------------------===//
50mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
51 llvm::SmallVector<int64_t> shape;
52 mlir::Type elementType;
53 mlir::FailureOr<mlir::Attribute> encoding;
54
55 // Parse literal '<'
56 if (parser.parseLess())
57 return {};
58
59 auto shapeLoc = parser.getCurrentLocation();
60 if (mlir::failed(parser.parseDimensionList(shape))) {
61 parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
62 return {};
63 }
64
65 auto elemTypeLoc = parser.getCurrentLocation();
66 if (mlir::failed(parser.parseType(elementType))) {
67 parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
68 return {};
69 }
70
71 // parse optional attributes
72 if (mlir::succeeded(parser.parseOptionalComma())) {
73 encoding = mlir::FieldParser<mlir::Attribute>::parse(parser);
74 if (mlir::failed(encoding)) {
75 parser.emitError(
76 parser.getCurrentLocation(),
77 "Failed to parse the attribute field for TensorDescType.\n");
78 return {};
79 }
80 }
81
82 // Parse literal '>'
83 if (parser.parseGreater())
84 return {};
85
86 return TensorDescType::get(parser.getContext(), shape, elementType,
87 encoding.value_or(mlir::Attribute()));
88}
89
90void TensorDescType::print(::mlir::AsmPrinter &printer) const {
91 printer << "<";
92
93 auto shape = getShape();
94 for (int64_t dim : shape) {
95 if (mlir::ShapedType::isDynamic(dim))
96 printer << '?';
97 else
98 printer << dim;
99 printer << 'x';
100 }
101
102 printer << getElementType();
103
104 if (auto encoding = getEncoding())
105 printer << ", " << encoding;
106
107 printer << ">";
108}
109
110TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
111 mlir::Type elementType, bool scattered,
112 int array_length, MemoryScope memory_scope,
113 bool boundary_check) {
114 auto context = elementType.getContext();
115 auto attr = TensorDescAttr::get(context, memory_scope, array_length,
116 boundary_check, scattered);
117 return Base::get(context, shape, elementType, attr);
118}
119
120} // namespace xegpu
121} // namespace mlir
122
123#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
124#define GET_ATTRDEF_CLASSES
125#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
126#define GET_TYPEDEF_CLASSES
127#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
128

source code of mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp