1 | //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/DialectImplementation.h" |
12 | #include "llvm/ADT/TypeSwitch.h" |
13 | |
14 | namespace mlir { |
15 | namespace xegpu { |
16 | |
/// Registers every type, operation, and attribute of the XeGPU dialect.
///
/// Each `.cpp.inc` fragment is included with a GET_*_LIST macro defined so
/// that it expands to the list of class names used as the template arguments
/// of addTypes / addOperations / addAttributes. These fragments are
/// presumably TableGen-generated, per the usual MLIR convention. The same
/// fragments are included again at the bottom of this file with the
/// GET_*_CLASSES macros to pull in the class definitions themselves.
void XeGPUDialect::initialize() {
  // Dialect types (XeGPUTypes.cpp.inc, list mode).
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  // Dialect operations (XeGPU.cpp.inc, list mode).
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  // Dialect attributes (XeGPUAttrs.cpp.inc, list mode).
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // XeGPU_TensorDescAttr |
34 | //===----------------------------------------------------------------------===// |
35 | TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context, |
36 | xegpu::MemoryScope memory_scope, |
37 | int array_length, bool boundary_check, |
38 | bool scattered) { |
39 | auto scopeAttr = MemoryScopeAttr::get(context, memory_scope); |
40 | auto lengthAttr = |
41 | IntegerAttr::get(IntegerType::get(context, 64), array_length); |
42 | auto boundaryAttr = BoolAttr::get(context, boundary_check); |
43 | auto scatteredAttr = BoolAttr::get(context, scattered); |
44 | return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr); |
45 | } |
46 | |
47 | //===----------------------------------------------------------------------===// |
48 | // XeGPU_TensorDescType |
49 | //===----------------------------------------------------------------------===// |
50 | mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { |
51 | llvm::SmallVector<int64_t> shape; |
52 | mlir::Type elementType; |
53 | mlir::FailureOr<mlir::Attribute> encoding; |
54 | |
55 | // Parse literal '<' |
56 | if (parser.parseLess()) |
57 | return {}; |
58 | |
59 | auto shapeLoc = parser.getCurrentLocation(); |
60 | if (mlir::failed(parser.parseDimensionList(shape))) { |
61 | parser.emitError(shapeLoc, "failed to parse parameter 'shape'" ); |
62 | return {}; |
63 | } |
64 | |
65 | auto elemTypeLoc = parser.getCurrentLocation(); |
66 | if (mlir::failed(parser.parseType(elementType))) { |
67 | parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'" ); |
68 | return {}; |
69 | } |
70 | |
71 | // parse optional attributes |
72 | if (mlir::succeeded(parser.parseOptionalComma())) { |
73 | encoding = mlir::FieldParser<mlir::Attribute>::parse(parser); |
74 | if (mlir::failed(encoding)) { |
75 | parser.emitError( |
76 | parser.getCurrentLocation(), |
77 | "Failed to parse the attribute field for TensorDescType.\n" ); |
78 | return {}; |
79 | } |
80 | } |
81 | |
82 | // Parse literal '>' |
83 | if (parser.parseGreater()) |
84 | return {}; |
85 | |
86 | return TensorDescType::get(parser.getContext(), shape, elementType, |
87 | encoding.value_or(mlir::Attribute())); |
88 | } |
89 | |
90 | void TensorDescType::print(::mlir::AsmPrinter &printer) const { |
91 | printer << "<" ; |
92 | |
93 | auto shape = getShape(); |
94 | for (int64_t dim : shape) { |
95 | if (mlir::ShapedType::isDynamic(dim)) |
96 | printer << '?'; |
97 | else |
98 | printer << dim; |
99 | printer << 'x'; |
100 | } |
101 | |
102 | printer << getElementType(); |
103 | |
104 | if (auto encoding = getEncoding()) |
105 | printer << ", " << encoding; |
106 | |
107 | printer << ">" ; |
108 | } |
109 | |
110 | TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, |
111 | mlir::Type elementType, bool scattered, |
112 | int array_length, MemoryScope memory_scope, |
113 | bool boundary_check) { |
114 | auto context = elementType.getContext(); |
115 | auto attr = TensorDescAttr::get(context, memory_scope, array_length, |
116 | boundary_check, scattered); |
117 | return Base::get(context, shape, elementType, attr); |
118 | } |
119 | |
120 | } // namespace xegpu |
121 | } // namespace mlir |
122 | |
123 | #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc> |
124 | #define GET_ATTRDEF_CLASSES |
125 | #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc> |
126 | #define GET_TYPEDEF_CLASSES |
127 | #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc> |
128 | |