//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" |
10 | #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" |
11 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
12 | #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" |
13 | #include "mlir/Pass/Pass.h" |
14 | #include "mlir/Pass/PassManager.h" |
15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::xegpu; |
19 | |
20 | namespace { |
21 | |
22 | struct TestXeGPUUnrollingPatterns |
23 | : public PassWrapper<TestXeGPUUnrollingPatterns, |
24 | OperationPass<gpu::GPUModuleOp>> { |
25 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns) |
26 | |
27 | StringRef getArgument() const final { |
28 | return "test-xegpu-unrolling-patterns" ; |
29 | } |
30 | |
31 | StringRef getDescription() const final { |
32 | return "Test lowering patterns to unroll ops in the xegpu dialect" ; |
33 | } |
34 | |
35 | void getDependentDialects(::mlir::DialectRegistry ®istry) const override { |
36 | registry.insert<memref::MemRefDialect>(); |
37 | registry.insert<xegpu::XeGPUDialect>(); |
38 | registry.insert<vector::VectorDialect>(); |
39 | } |
40 | |
41 | TestXeGPUUnrollingPatterns() = default; |
42 | TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass) |
43 | : PassWrapper(pass) {} |
44 | |
45 | void runOnOperation() override { |
46 | MLIRContext *ctx = &getContext(); |
47 | xegpu::UnrollOptions options; |
48 | options.setNativeShapeFn( |
49 | [&](Operation *op) -> std::optional<SmallVector<int64_t>> { |
50 | if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, |
51 | xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) { |
52 | xegpu::TensorDescType tdescTy; |
53 | if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) { |
54 | tdescTy = createNdOp.getType(); |
55 | } else if (auto updateNdOp = |
56 | dyn_cast<xegpu::UpdateNdOffsetOp>(op)) { |
57 | tdescTy = updateNdOp.getTensorDescType(); |
58 | } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) { |
59 | tdescTy = prefetchNdOp.getTensorDescType(); |
60 | } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) { |
61 | tdescTy = loadNdOp.getTensorDescType(); |
62 | } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) { |
63 | tdescTy = storeNdOp.getTensorDescType(); |
64 | } |
65 | |
66 | if (auto layout = tdescTy.getLayoutAttr()) { |
67 | auto inst_data = layout.getInstData(); |
68 | if (inst_data && layout.isSgLayout()) |
69 | return SmallVector<int64_t>(inst_data.asArrayRef().begin(), |
70 | inst_data.asArrayRef().end()); |
71 | } |
72 | } |
73 | |
74 | if (isa<xegpu::DpasOp>(op)) |
75 | return SmallVector<int64_t>{8, 16, 16}; |
76 | |
77 | return std::nullopt; |
78 | }); |
79 | |
80 | options.setUnrolledTypesFn( |
81 | [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> { |
82 | Type elemTy = type.getElementType(); |
83 | Type newTy; |
84 | |
85 | // TensorDescType needs to drop the inst_data field in the layout |
86 | // attribute |
87 | if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) { |
88 | Attribute encoding = tdescTy.getEncoding(); |
89 | auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>( |
90 | tdescTy.getLayout()); |
91 | if (layout) { |
92 | if (layout.getLaneLayout() == nullptr) |
93 | layout = xegpu::LayoutAttr(); |
94 | else |
95 | layout = layout.dropInstData(); |
96 | } |
97 | newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, |
98 | layout); |
99 | } else { |
100 | newTy = type.clone(tileShape, elemTy); |
101 | } |
102 | |
103 | std::optional<SmallVector<int64_t>> ratio = |
104 | computeShapeRatio(type.getShape(), tileShape); |
105 | assert(ratio && "Expecting the ratio to be valid." ); |
106 | return SmallVector<Type>(computeProduct(basis: *ratio), newTy); |
107 | }); |
108 | |
109 | RewritePatternSet patterns(ctx); |
110 | |
111 | populateXeGPUUnrollPatterns(patterns, options); |
112 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
113 | } |
114 | }; |
115 | |
116 | } // namespace |
117 | |
118 | namespace mlir { |
119 | namespace test { |
120 | void registerTestXeGPULowerings() { |
121 | PassRegistration<TestXeGPUUnrollingPatterns>(); |
122 | } |
123 | } // namespace test |
124 | } // namespace mlir |
125 | |