//===- TestXeGPUTransforms.cpp -- Test XeGPU transforms and lowerings -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/GPU/IR/GPUDialect.h"
10#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
11#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
12#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Pass/PassManager.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17using namespace mlir;
18using namespace mlir::xegpu;
19
20namespace {
21
22struct TestXeGPUUnrollingPatterns
23 : public PassWrapper<TestXeGPUUnrollingPatterns,
24 OperationPass<gpu::GPUModuleOp>> {
25 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
26
27 StringRef getArgument() const final {
28 return "test-xegpu-unrolling-patterns";
29 }
30
31 StringRef getDescription() const final {
32 return "Test lowering patterns to unroll ops in the xegpu dialect";
33 }
34
35 void getDependentDialects(::mlir::DialectRegistry &registry) const override {
36 registry.insert<memref::MemRefDialect>();
37 registry.insert<xegpu::XeGPUDialect>();
38 registry.insert<vector::VectorDialect>();
39 }
40
41 TestXeGPUUnrollingPatterns() = default;
42 TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
43 : PassWrapper(pass) {}
44
45 void runOnOperation() override {
46 MLIRContext *ctx = &getContext();
47 xegpu::UnrollOptions options;
48 options.setNativeShapeFn(
49 [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
50 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
51 xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
52 xegpu::TensorDescType tdescTy;
53 if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
54 tdescTy = createNdOp.getType();
55 } else if (auto updateNdOp =
56 dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
57 tdescTy = updateNdOp.getTensorDescType();
58 } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
59 tdescTy = prefetchNdOp.getTensorDescType();
60 } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
61 tdescTy = loadNdOp.getTensorDescType();
62 } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
63 tdescTy = storeNdOp.getTensorDescType();
64 }
65
66 if (auto layout = tdescTy.getLayoutAttr()) {
67 auto inst_data = layout.getInstData();
68 if (inst_data && layout.isSgLayout())
69 return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
70 inst_data.asArrayRef().end());
71 }
72 }
73
74 if (isa<xegpu::DpasOp>(op))
75 return SmallVector<int64_t>{8, 16, 16};
76
77 return std::nullopt;
78 });
79
80 options.setUnrolledTypesFn(
81 [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
82 Type elemTy = type.getElementType();
83 Type newTy;
84
85 // TensorDescType needs to drop the inst_data field in the layout
86 // attribute
87 if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
88 Attribute encoding = tdescTy.getEncoding();
89 auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
90 tdescTy.getLayout());
91 if (layout) {
92 if (layout.getLaneLayout() == nullptr)
93 layout = xegpu::LayoutAttr();
94 else
95 layout = layout.dropInstData();
96 }
97 newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
98 layout);
99 } else {
100 newTy = type.clone(tileShape, elemTy);
101 }
102
103 std::optional<SmallVector<int64_t>> ratio =
104 computeShapeRatio(type.getShape(), tileShape);
105 assert(ratio && "Expecting the ratio to be valid.");
106 return SmallVector<Type>(computeProduct(basis: *ratio), newTy);
107 });
108
109 RewritePatternSet patterns(ctx);
110
111 populateXeGPUUnrollPatterns(patterns, options);
112 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
113 }
114};
115
116} // namespace
117
118namespace mlir {
119namespace test {
120void registerTestXeGPULowerings() {
121 PassRegistration<TestXeGPUUnrollingPatterns>();
122}
123} // namespace test
124} // namespace mlir
125

source code of mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp