//===- SwapExtractSliceWithFillPatterns.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
| 8 | |
| 9 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 10 | #include "mlir/IR/PatternMatch.h" |
| 11 | |
| 12 | using namespace mlir; |
| 13 | using namespace mlir::linalg; |
| 14 | |
| 15 | /// swaps: |
| 16 | /// `tensor.extract_slice(linalg.fill(%cst, %init))` |
| 17 | /// with: |
| 18 | /// `linalg.fill(%cst, tensor.extract_slice(%init))` |
| 19 | /// |
| 20 | /// when the linalg.fill op have no other users. |
| 21 | /// This helps to reduce the fill footprint. |
| 22 | struct final |
| 23 | : public OpRewritePattern<tensor::ExtractSliceOp> { |
| 24 | using OpRewritePattern::OpRewritePattern; |
| 25 | |
| 26 | LogicalResult matchAndRewrite(tensor::ExtractSliceOp , |
| 27 | PatternRewriter &rewriter) const override { |
| 28 | auto fillOp = extractOp.getSource().getDefiningOp<FillOp>(); |
| 29 | if (!fillOp || !fillOp->hasOneUse()) |
| 30 | return failure(); |
| 31 | |
| 32 | auto = rewriter.create<tensor::ExtractSliceOp>( |
| 33 | extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], |
| 34 | extractOp.getMixedOffsets(), extractOp.getMixedSizes(), |
| 35 | extractOp.getMixedStrides()); |
| 36 | rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(), |
| 37 | ValueRange{newExtractOp.getResult()}); |
| 38 | return success(); |
| 39 | } |
| 40 | }; |
| 41 | |
| 42 | void mlir::linalg::( |
| 43 | RewritePatternSet &patterns) { |
| 44 | patterns.add<SwapExtractSliceOfFill>(arg: patterns.getContext()); |
| 45 | } |
| 46 | |