1 | //===- XeGPUFoldAliasOps.cpp - XeGPU alias ops folders ----------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/XeGPU/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h" |
12 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
13 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
14 | #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" |
15 | #include "mlir/Pass/Pass.h" |
16 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
17 | #include "llvm/Support/Debug.h" |
18 | |
19 | namespace mlir { |
20 | namespace xegpu { |
21 | #define GEN_PASS_DEF_XEGPUFOLDALIASOPS |
22 | #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" |
23 | } // namespace xegpu |
24 | } // namespace mlir |
25 | |
26 | #define DEBUG_TYPE "xegpu-fold-alias-ops" |
27 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
28 | |
29 | using namespace mlir; |
30 | |
31 | namespace { |
32 | /// Merges subview operation with xegpu.create_nd_tdesc operation. |
33 | class XegpuCreateNdDescOpSubViewOpFolder final |
34 | : public OpRewritePattern<xegpu::CreateNdDescOp> { |
35 | public: |
36 | using OpRewritePattern<xegpu::CreateNdDescOp>::OpRewritePattern; |
37 | |
38 | LogicalResult matchAndRewrite(xegpu::CreateNdDescOp descOp, |
39 | PatternRewriter &rewriter) const override; |
40 | }; |
41 | } // namespace |
42 | |
43 | LogicalResult XegpuCreateNdDescOpSubViewOpFolder::matchAndRewrite( |
44 | xegpu::CreateNdDescOp descOp, PatternRewriter &rewriter) const { |
45 | auto subViewOp = descOp.getSource().getDefiningOp<memref::SubViewOp>(); |
46 | |
47 | if (!subViewOp) |
48 | return rewriter.notifyMatchFailure(descOp, "not a subview producer" ); |
49 | if (!subViewOp.hasUnitStride()) |
50 | return rewriter.notifyMatchFailure(descOp, "requires unit strides" ); |
51 | |
52 | SmallVector<Value> resolvedOffsets; |
53 | affine::resolveIndicesIntoOpWithOffsetsAndStrides( |
54 | rewriter, descOp.getLoc(), subViewOp.getMixedOffsets(), |
55 | subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), |
56 | descOp.getMixedOffsets(), resolvedOffsets); |
57 | |
58 | rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>( |
59 | descOp, descOp.getTensorDesc().getType(), subViewOp.getSource(), |
60 | getAsOpFoldResult(resolvedOffsets)); |
61 | |
62 | return success(); |
63 | } |
64 | |
65 | void xegpu::populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns) { |
66 | patterns.add<XegpuCreateNdDescOpSubViewOpFolder>(arg: patterns.getContext()); |
67 | } |
68 | |
69 | namespace { |
70 | |
71 | struct XeGPUFoldAliasOpsPass final |
72 | : public xegpu::impl::XeGPUFoldAliasOpsBase<XeGPUFoldAliasOpsPass> { |
73 | void runOnOperation() override; |
74 | }; |
75 | |
76 | } // namespace |
77 | |
78 | void XeGPUFoldAliasOpsPass::runOnOperation() { |
79 | RewritePatternSet patterns(&getContext()); |
80 | xegpu::populateXeGPUFoldAliasOpsPatterns(patterns); |
81 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
82 | } |
83 | |