//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::xegpu;
20namespace {
21
22#define DEBUG_TYPE "test-xegpu-unroll"
23#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
24#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
25
26struct TestXeGPUUnrollingPatterns
27 : public PassWrapper<TestXeGPUUnrollingPatterns,
28 OperationPass<gpu::GPUModuleOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
30
31 StringRef getArgument() const final {
32 return "test-xegpu-unrolling-patterns";
33 }
34
35 StringRef getDescription() const final {
36 return "Test lowering patterns to unroll ops in the xegpu dialect";
37 }
38
39 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
40 registry.insert<memref::MemRefDialect>();
41 registry.insert<xegpu::XeGPUDialect>();
42 registry.insert<vector::VectorDialect>();
43 }
44
45 TestXeGPUUnrollingPatterns() = default;
46 TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
47 : PassWrapper(pass) {}
48
49 void runOnOperation() override {
50 MLIRContext *ctx = &getContext();
51 xegpu::UnrollOptions options;
52 options.setNativeShapeFn(
53 [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
54 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
55 xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
56 xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
57 xegpu::LoadGatherOp, xegpu::StoreScatterOp>(Val: op)) {
58 xegpu::TensorDescType tdescTy;
59 if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(Val: op)) {
60 tdescTy = createNdOp.getType();
61 } else if (auto updateNdOp =
62 dyn_cast<xegpu::UpdateNdOffsetOp>(Val: op)) {
63 tdescTy = updateNdOp.getTensorDescType();
64 } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(Val: op)) {
65 tdescTy = prefetchNdOp.getTensorDescType();
66 } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(Val: op)) {
67 tdescTy = loadNdOp.getTensorDescType();
68 } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(Val: op)) {
69 tdescTy = storeNdOp.getTensorDescType();
70 } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(Val: op)) {
71 tdescTy = createOp.getType();
72 } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(Val: op)) {
73 tdescTy = updateOp.getTensorDescType();
74 } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(Val: op)) {
75 tdescTy = prefetchOp.getTensorDescType();
76 } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(Val: op)) {
77 tdescTy = loadOp.getTensorDescType();
78 } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(Val: op)) {
79 tdescTy = storeOp.getTensorDescType();
80 }
81
82 if (auto layout = tdescTy.getLayoutAttr()) {
83 auto inst_data = layout.getInstData();
84 if (inst_data && layout.isSgLayout())
85 return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
86 inst_data.asArrayRef().end());
87 }
88 }
89
90 if (isa<xegpu::DpasOp>(Val: op))
91 return SmallVector<int64_t>{8, 16, 16};
92
93 return std::nullopt;
94 });
95
96 options.setUnrolledTypesFn(
97 [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
98 Type elemTy = type.getElementType();
99 Type newTy;
100
101 // TensorDescType needs to drop the inst_data field in the layout
102 // attribute
103 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(Val&: type)) {
104 Attribute encoding = tdescTy.getEncoding();
105 auto layout = tdescTy.getLayoutAttr();
106
107 // If the encoding is a ScatterTensorDescAttr, we need to
108 // potentially adjust the chunk size based on the inst_data.
109 if (tdescTy.isScattered()) {
110 int64_t chunkSize = tdescTy.getChunkSizeAsInt();
111
112 if (chunkSize > 1) {
113 int64_t blockedChunkSize = chunkSize;
114 auto instData = layout.getInstData();
115 if (!instData.empty())
116 blockedChunkSize = instData.asArrayRef().back();
117
118 // To create a new attribute with a different chunk_size:
119 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
120 context: ctx, memory_space: tdescTy.getMemorySpace(), chunk_size: blockedChunkSize);
121
122 encoding = newEncoding;
123 }
124 }
125 if (layout) {
126 if (layout.getLaneLayout() == nullptr)
127 layout = xegpu::LayoutAttr();
128 else
129 layout = layout.dropInstData();
130 }
131
132 newTy = xegpu::TensorDescType::get(context: ctx, shape: tileShape, elementType: elemTy, encoding,
133 layout);
134
135 } else {
136 newTy = type.clone(shape: tileShape, elementType: elemTy);
137 }
138
139 std::optional<SmallVector<int64_t>> ratio =
140 computeShapeRatio(shape: type.getShape(), subShape: tileShape);
141 assert(ratio && "Expecting the ratio to be valid.");
142 return SmallVector<Type>(computeProduct(basis: *ratio), newTy);
143 });
144
145 RewritePatternSet patterns(ctx);
146
147 populateXeGPUUnrollPatterns(patterns, options);
148 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
149 }
150};
151
152} // namespace
154namespace mlir {
155namespace test {
156void registerTestXeGPULowerings() {
157 PassRegistration<TestXeGPUUnrollingPatterns>();
158}
159} // namespace test
160} // namespace mlir

// (File: mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp)