1 | //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Swap a `tensor.extract_slice` with the producer of the source if the producer |
10 | // implements the `TilingInterface`. When used in conjunction with tiling this |
11 | // effectively tiles + fuses the producer with its consumer. |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
18 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
19 | #include "mlir/Interfaces/TilingInterface.h" |
20 | |
21 | using namespace mlir; |
22 | |
23 | FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer( |
24 | OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { |
25 | auto producerOp = dyn_cast<TilingInterface>(producer.getOwner()); |
26 | if (!producerOp) |
27 | return failure(); |
28 | |
29 | // `TilingInterface` currently only supports strides being 1. |
30 | if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { |
31 | return !isConstantIntValue(ofr, value: 1); |
32 | })) |
33 | return failure(); |
34 | |
35 | FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue( |
36 | builder, producer.getResultNumber(), sliceOp.getMixedOffsets(), |
37 | sliceOp.getMixedSizes()); |
38 | if (failed(result: tiledResult)) |
39 | return failure(); |
40 | |
41 | return *tiledResult; |
42 | } |
43 |