1 | //===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
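// Swap a `tensor.extract_slice` with the producer of its source when that
// producer implements `TilingInterface`. When used in conjunction with tiling,
// this effectively tiles and fuses the producer into its consumer. The file
// also provides the consumer-side counterpart: tiling a `TilingInterface`
// consumer from the offsets and sizes of a slice describing a tile of one of
// its operands.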
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
18 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
19 | #include "mlir/Interfaces/TilingInterface.h" |
20 | |
21 | using namespace mlir; |
22 | |
23 | FailureOr<TilingResult> tensor::( |
24 | OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) { |
25 | auto producerOp = dyn_cast<TilingInterface>(producer.getOwner()); |
26 | if (!producerOp) |
27 | return failure(); |
28 | |
29 | // `TilingInterface` currently only supports strides being 1. |
30 | if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) |
31 | return failure(); |
32 | |
33 | FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue( |
34 | builder, producer.getResultNumber(), sliceOp.getMixedOffsets(), |
35 | sliceOp.getMixedSizes()); |
36 | if (failed(Result: tiledResult)) |
37 | return failure(); |
38 | |
39 | return *tiledResult; |
40 | } |
41 | |
42 | FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer( |
43 | OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, |
44 | OpOperand &consumer) { |
45 | auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner()); |
46 | if (!consumerOp) |
47 | return failure(); |
48 | |
49 | // `TilingInterface` currently only supports strides being 1. |
50 | if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) |
51 | return failure(); |
52 | |
53 | FailureOr<TilingResult> tiledResult = |
54 | consumerOp.getTiledImplementationFromOperandTile( |
55 | builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(), |
56 | sliceOp.getMixedSizes()); |
57 | if (failed(Result: tiledResult)) |
58 | return failure(); |
59 | |
60 | return *tiledResult; |
61 | } |
62 | |