1//===- SwapExtractSliceWithProducerPatterns.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Swap a `tensor.extract_slice` with the producer of the source if the producer
10// implements the `TilingInterface`. When used in conjunction with tiling this
11// effectively tiles + fuses the producer with its consumer.
12//
13//===----------------------------------------------------------------------===//
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
18#include "mlir/Dialect/Utils/StaticValueUtils.h"
19#include "mlir/Interfaces/TilingInterface.h"
20
21using namespace mlir;
22
23FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
24 OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producer) {
25 auto producerOp = dyn_cast<TilingInterface>(producer.getOwner());
26 if (!producerOp)
27 return failure();
28
29 // `TilingInterface` currently only supports strides being 1.
30 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
31 return failure();
32
33 FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
34 builder, producer.getResultNumber(), sliceOp.getMixedOffsets(),
35 sliceOp.getMixedSizes());
36 if (failed(Result: tiledResult))
37 return failure();
38
39 return *tiledResult;
40}
41
42FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
43 OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp,
44 OpOperand &consumer) {
45 auto consumerOp = dyn_cast<TilingInterface>(consumer.getOwner());
46 if (!consumerOp)
47 return failure();
48
49 // `TilingInterface` currently only supports strides being 1.
50 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
51 return failure();
52
53 FailureOr<TilingResult> tiledResult =
54 consumerOp.getTiledImplementationFromOperandTile(
55 builder, consumer.getOperandNumber(), sliceOp.getMixedOffsets(),
56 sliceOp.getMixedSizes());
57 if (failed(Result: tiledResult))
58 return failure();
59
60 return *tiledResult;
61}
62

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp