1//===- SwapExtractSliceWithFillPatterns.cpp -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10#include "mlir/IR/PatternMatch.h"
11
12using namespace mlir;
13using namespace mlir::linalg;
14
15/// swaps:
16/// `tensor.extract_slice(linalg.fill(%cst, %init))`
17/// with:
18/// `linalg.fill(%cst, tensor.extract_slice(%init))`
19///
20/// when the linalg.fill op have no other users.
21/// This helps to reduce the fill footprint.
22struct SwapExtractSliceOfFill final
23 : public OpRewritePattern<tensor::ExtractSliceOp> {
24 using OpRewritePattern::OpRewritePattern;
25
26 LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
27 PatternRewriter &rewriter) const override {
28 auto fillOp = extractOp.getSource().getDefiningOp<FillOp>();
29 if (!fillOp || !fillOp->hasOneUse())
30 return failure();
31
32 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
33 extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
34 extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
35 extractOp.getMixedStrides());
36 rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(),
37 ValueRange{newExtractOp.getResult()});
38 return success();
39 }
40};
41
42void mlir::linalg::populateSwapExtractSliceWithFillPatterns(
43 RewritePatternSet &patterns) {
44 patterns.add<SwapExtractSliceOfFill>(arg: patterns.getContext());
45}
46

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp