//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/RWMutex.h"
#include <memory>
#include <numeric>

namespace mlir {
namespace linalg {
namespace {
// clang-format off
/// Convolution converter that applies the following rewrite:
///
/// Before:
///
///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
///                                  strides = dense<2> : tensor<2xi64>}
///     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
///
/// After:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %0 = tensor.empty() : tensor<2x2x6x8xf32>
///   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
///   %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>)
///       permutation = [1, 2, 3, 0]
///   %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
///       ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>)
///       -> tensor<1x2x2x8xf32>
///
/// with an analogous example for the quantized case.
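///
/// For the quantized variant, linalg.conv_2d_nhwc_fhwc_q is rewritten to
/// linalg.conv_2d_nhwc_hwcf_q in the same fashion: only the filter operand is
/// transposed, and the remaining (zero-point) inputs are passed through
/// unchanged.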
// clang-format on
template <typename FHWCConvOp, typename HWCFConvOp>
FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
                                             FHWCConvOp op) {
  // Construct a permutation of the filter tensor dimensions. For a 2D
  // convolution this will be known statically as [1, 2, 3, 0].
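  // It maps an FHWC filter (e.g. tensor<8x2x2x6xf32>) to an HWCF filter
  // (e.g. tensor<2x2x6x8xf32>).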
  SmallVector<int64_t> filterPerm = {1, 2, 3, 0};

  // Create the type for the transposed filter tensor.
  auto filter = op->getOperand(1);
  auto filterTy = cast<ShapedType>(filter.getType());
  SmallVector<int64_t> newFilterShape(filterPerm.size());
  std::generate(std::begin(newFilterShape), std::end(newFilterShape),
                [dim = 0, &filterTy, &filterPerm]() mutable {
                  return filterTy.getShape()[filterPerm[dim++]];
                });
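  // For the running example above this yields the transposed filter shape
  // {2, 2, 6, 8}.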

  // Because linalg.transpose expects an "out" parameter, we need to pass it a
  // destination value of the transposed filter type, so we construct that
  // tensor (or memref) here.
  auto inputType = op->getOperand(0).getType();
  auto elementTy = cast<ShapedType>(inputType).getElementType();
  auto loc = op->getLoc();

  const auto isTensorOp = isa<TensorType>(inputType);
  Value input;
  if (isTensorOp) {
    input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
                .getResult();
  } else {
    input = rewriter
                .create<memref::AllocOp>(
                    loc, MemRefType::get(newFilterShape, elementTy))
                .getResult();
  }

  // We can then construct the transposition on our filter.
  auto transpose =
      rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm);

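  // In the tensor case the transposed filter is the transpose op's result; in
  // the memref case the transpose writes into the buffer allocated above, so
  // that buffer is reused directly.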
  Value newFilter;
  if (isTensorOp) {
    newFilter = transpose.getResult()[0];
  } else {
    newFilter = input;
  }

  SmallVector<Value> newInputs{op.getInputs()};
  // The filter is always the second input argument; the other inputs can be
  // left as they are.
  newInputs[1] = newFilter;
  // The convolution may not define any results (e.g. when operating on
  // memrefs), in which case its out argument is updated in place instead.
  SmallVector<Type> resultTy;
  if (op.getNumResults()) {
    resultTy.push_back(op->getResult(0).getType());
  }
  auto newConv =
      rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
                                  op.getStrides(), op.getDilations());
  rewriter.replaceOp(op, newConv);
  return newConv.getOperation();
}

template <typename FHWCConvOp, typename HWCFConvOp>
class ConvConverter : public OpRewritePattern<FHWCConvOp> {
public:
  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(FHWCConvOp op,
                                PatternRewriter &rewriter) const final {
    if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) {
      return failure();
    }
    return success();
  }
};
} // namespace

FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcOp op) {
  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp,
                               linalg::Conv2DNhwcHwcfOp>(rewriter, op);
}

FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                       linalg::Conv2DNhwcFhwcQOp op) {
  return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp,
                               linalg::Conv2DNhwcHwcfQOp>(rewriter, op);
}

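// A typical driver for these patterns looks roughly as follows (an
// illustrative sketch, not part of this file; assumes it runs inside a pass):
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateTransposeConv2DPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();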
void populateTransposeConv2DPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.insert<
      ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>,
      ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
      context);
}
} // namespace linalg
} // namespace mlir
150 | |