1//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Linalg/IR/Linalg.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/Tensor/IR/Tensor.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/IR/ValueRange.h"
16#include "mlir/Support/LogicalResult.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/Support/ErrorHandling.h"
21#include "llvm/Support/RWMutex.h"
22#include <memory>
23#include <numeric>
24
25namespace mlir {
26namespace linalg {
27namespace {
28// clang-format off
29/// Convolution converter that applies the following rewrite:
30///
31/// Before:
32///
33/// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
34/// strides = dense<2> : tensor<2xi64>}
35/// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
36/// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
37///
38/// After:
39///
40/// %cst = arith.constant 0.000000e+00 : f32
41/// %0 = tensor.empty() : tensor<2x2x6x8xf32>
42/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
43/// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
44/// permutation = [1, 2, 3, 0]
45/// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
46/// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
47/// -> tensor<1x2x2x8xf32>
48///
49/// with an analogous example for the quantized case.
50// clang-format on
51template <typename FHWCConvOp, typename HWCFConvOp>
52FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
53 FHWCConvOp op) {
54 // Construct a permutation of the filter tensor dimensions. For a 2D
55 // convolution this will be known statically as [1, 2, 3, 0].
56 SmallVector<int64_t> filterPerm({1, 2, 3, 0});
57
58 // Create the type for the transposed filter tensor.
59 auto filter = op->getOperand(1);
60 auto filterTy = cast<ShapedType>(filter.getType());
61 SmallVector<int64_t> newFilterShape(filterPerm.size());
62 std::generate(std::begin(cont&: newFilterShape), std::end(cont&: newFilterShape),
63 [dim = 0, &filterTy, &filterPerm]() mutable {
64 return filterTy.getShape()[filterPerm[dim++]];
65 });
66
67 // Because linalg.transpose expects an "out" parameter we need to pass it a
68 // tensor of zeros of the result type so here we construct that tensor.
69 auto inputType = op->getOperand(0).getType();
70 auto elementTy = cast<ShapedType>(inputType).getElementType();
71 auto loc = op->getLoc();
72
73 const auto isTensorOp = isa<TensorType>(inputType);
74 Value input;
75 if (isTensorOp) {
76
77 input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
78 .getResult();
79 } else {
80 input = rewriter
81 .create<memref::AllocOp>(
82 loc, MemRefType::get(newFilterShape, elementTy))
83 .getResult();
84 }
85
86 // We can then construct the transposition on our filter.
87 auto transpose =
88 rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
89
90 Value newFilter;
91 if (isTensorOp) {
92 newFilter = transpose.getResult()[0];
93 } else {
94 newFilter = input;
95 }
96
97 SmallVector<Value> newInputs{op.getInputs()};
98 // The filter is always the second input argument, the other inputs can be
99 // left as they are.
100 newInputs[1] = newFilter;
101 // It is possible the convolution doesn't define any results and its
102 // out argument is just used instead.
103 SmallVector<Type> resultTy;
104 if (op.getNumResults()) {
105 resultTy.push_back(Elt: op->getResult(0).getType());
106 }
107 auto newConv =
108 rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
109 op.getStrides(), op.getDilations());
110 rewriter.replaceOp(op, newConv);
111 return newConv.getOperation();
112}
113
114template <typename FHWCConvOp, typename HWCFConvOp>
115class ConvConverter : public OpRewritePattern<FHWCConvOp> {
116public:
117 using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
118 LogicalResult matchAndRewrite(FHWCConvOp op,
119 PatternRewriter &rewriter) const final {
120 if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
121 return failure();
122 }
123 return success();
124 }
125};
126} // namespace
127
128FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
129 linalg::Conv2DNhwcFhwcOp op) {
130
131 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
132 linalg::Conv2DNhwcHwcfOp>(rewriter, op);
133}
134
135FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
136 linalg::Conv2DNhwcFhwcQOp op) {
137
138 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
139 linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
140}
141
142void populateTranposeConv2DPatterns(RewritePatternSet &patterns) {
143 MLIRContext *context = patterns.getContext();
144 patterns.insert<
145 ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
146 ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
147 context);
148}
149} // namespace linalg
150} // namespace mlir
151

source code of mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp