//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Linalg/IR/Linalg.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/Dialect/Tensor/IR/Tensor.h"
13#include "mlir/IR/BuiltinTypes.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/IR/ValueRange.h"
16#include "mlir/Transforms/DialectConversion.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/Support/ErrorHandling.h"
20#include "llvm/Support/RWMutex.h"
21#include <memory>
22#include <numeric>
23
24namespace mlir {
25namespace linalg {
26namespace {
27// clang-format off
28/// Convolution converter that applies the following rewrite:
29///
30/// Before:
31///
32/// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
33/// strides = dense<2> : tensor<2xi64>}
34/// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
35/// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
36///
37/// After:
38///
39/// %cst = arith.constant 0.000000e+00 : f32
40/// %0 = tensor.empty() : tensor<2x2x6x8xf32>
41/// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
42/// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
43/// permutation = [1, 2, 3, 0]
44/// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
45/// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
46/// -> tensor<1x2x2x8xf32>
47///
48/// with an analogous example for the quantized case.
49// clang-format on
50template <typename FHWCConvOp, typename HWCFConvOp>
51FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
52 FHWCConvOp op) {
53 // Construct a permutation of the filter tensor dimensions. For a 2D
54 // convolution this will be known statically as [1, 2, 3, 0].
55 SmallVector<int64_t> filterPerm = {1, 2, 3, 0};
56
57 // Create the type for the transposed filter tensor.
58 auto filter = op->getOperand(1);
59 auto filterTy = cast<ShapedType>(filter.getType());
60 SmallVector<int64_t> newFilterShape(filterPerm.size());
61 std::generate(std::begin(cont&: newFilterShape), std::end(cont&: newFilterShape),
62 [dim = 0, &filterTy, &filterPerm]() mutable {
63 return filterTy.getShape()[filterPerm[dim++]];
64 });
65
66 // Because linalg.transpose expects an "out" parameter we need to pass it a
67 // tensor of zeros of the result type so here we construct that tensor.
68 auto inputType = op->getOperand(0).getType();
69 auto elementTy = cast<ShapedType>(inputType).getElementType();
70 auto loc = op->getLoc();
71
72 const auto isTensorOp = isa<TensorType>(inputType);
73 Value input;
74 if (isTensorOp) {
75
76 input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
77 .getResult();
78 } else {
79 input = rewriter
80 .create<memref::AllocOp>(
81 loc, MemRefType::get(newFilterShape, elementTy))
82 .getResult();
83 }
84
85 // We can then construct the transposition on our filter.
86 auto transpose =
87 rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);
88
89 Value newFilter;
90 if (isTensorOp) {
91 newFilter = transpose.getResult()[0];
92 } else {
93 newFilter = input;
94 }
95
96 SmallVector<Value> newInputs{op.getInputs()};
97 // The filter is always the second input argument, the other inputs can be
98 // left as they are.
99 newInputs[1] = newFilter;
100 // It is possible the convolution doesn't define any results and its
101 // out argument is just used instead.
102 SmallVector<Type> resultTy;
103 if (op.getNumResults()) {
104 resultTy.push_back(Elt: op->getResult(0).getType());
105 }
106 auto newConv =
107 rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
108 op.getStrides(), op.getDilations());
109 rewriter.replaceOp(op, newConv);
110 return newConv.getOperation();
111}
112
113template <typename FHWCConvOp, typename HWCFConvOp>
114class ConvConverter : public OpRewritePattern<FHWCConvOp> {
115public:
116 using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
117 LogicalResult matchAndRewrite(FHWCConvOp op,
118 PatternRewriter &rewriter) const final {
119 if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
120 return failure();
121 }
122 return success();
123 }
124};
125} // namespace
126
127FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
128 linalg::Conv2DNhwcFhwcOp op) {
129
130 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
131 linalg::Conv2DNhwcHwcfOp>(rewriter, op);
132}
133
134FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
135 linalg::Conv2DNhwcFhwcQOp op) {
136
137 return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
138 linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
139}
140
141void populateTransposeConv2DPatterns(RewritePatternSet &patterns) {
142 MLIRContext *context = patterns.getContext();
143 patterns.insert<
144 ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
145 ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
146 context);
147}
148} // namespace linalg
149} // namespace mlir
150

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp