1 | //===- TransposeConv2D.cpp - Convolution transposition -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
10 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
11 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
12 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
13 | #include "mlir/IR/BuiltinTypes.h" |
14 | #include "mlir/IR/PatternMatch.h" |
15 | #include "mlir/IR/ValueRange.h" |
16 | #include "mlir/Support/LogicalResult.h" |
17 | #include "mlir/Transforms/DialectConversion.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | #include "llvm/ADT/SmallVector.h" |
20 | #include "llvm/Support/ErrorHandling.h" |
21 | #include "llvm/Support/RWMutex.h" |
22 | #include <memory> |
23 | #include <numeric> |
24 | |
25 | namespace mlir { |
26 | namespace linalg { |
27 | namespace { |
28 | // clang-format off |
29 | /// Convolution converter that applies the following rewrite: |
30 | /// |
31 | /// Before: |
32 | /// |
33 | /// %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, |
34 | /// strides = dense<2> : tensor<2xi64>} |
35 | /// ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>) |
36 | /// outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> |
37 | /// |
38 | /// After: |
39 | /// |
40 | /// %cst = arith.constant 0.000000e+00 : f32 |
41 | /// %0 = tensor.empty() : tensor<2x2x6x8xf32> |
42 | /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32> |
43 | /// %transposed = linalg.transpose ins(%arg1 : tensor<8x2x2x6xf32>) outs(%1 : tensor<2x2x6x8xf32>) |
44 | /// permutation = [1, 2, 3, 0] |
45 | /// %2 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} |
46 | /// ins(%arg0, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%arg2 : tensor<1x2x2x8xf32>) |
47 | /// -> tensor<1x2x2x8xf32> |
48 | /// |
49 | /// with an analogous example for the quantized case. |
50 | // clang-format on |
51 | template <typename FHWCConvOp, typename HWCFConvOp> |
52 | FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, |
53 | FHWCConvOp op) { |
54 | // Construct a permutation of the filter tensor dimensions. For a 2D |
55 | // convolution this will be known statically as [1, 2, 3, 0]. |
56 | SmallVector<int64_t> filterPerm({1, 2, 3, 0}); |
57 | |
58 | // Create the type for the transposed filter tensor. |
59 | auto filter = op->getOperand(1); |
60 | auto filterTy = cast<ShapedType>(filter.getType()); |
61 | SmallVector<int64_t> newFilterShape(filterPerm.size()); |
62 | std::generate(std::begin(cont&: newFilterShape), std::end(cont&: newFilterShape), |
63 | [dim = 0, &filterTy, &filterPerm]() mutable { |
64 | return filterTy.getShape()[filterPerm[dim++]]; |
65 | }); |
66 | |
67 | // Because linalg.transpose expects an "out" parameter we need to pass it a |
68 | // tensor of zeros of the result type so here we construct that tensor. |
69 | auto inputType = op->getOperand(0).getType(); |
70 | auto elementTy = cast<ShapedType>(inputType).getElementType(); |
71 | auto loc = op->getLoc(); |
72 | |
73 | const auto isTensorOp = isa<TensorType>(inputType); |
74 | Value input; |
75 | if (isTensorOp) { |
76 | |
77 | input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy) |
78 | .getResult(); |
79 | } else { |
80 | input = rewriter |
81 | .create<memref::AllocOp>( |
82 | loc, MemRefType::get(newFilterShape, elementTy)) |
83 | .getResult(); |
84 | } |
85 | |
86 | // We can then construct the transposition on our filter. |
87 | auto transpose = |
88 | rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm); |
89 | |
90 | Value newFilter; |
91 | if (isTensorOp) { |
92 | newFilter = transpose.getResult()[0]; |
93 | } else { |
94 | newFilter = input; |
95 | } |
96 | |
97 | SmallVector<Value> newInputs{op.getInputs()}; |
98 | // The filter is always the second input argument, the other inputs can be |
99 | // left as they are. |
100 | newInputs[1] = newFilter; |
101 | // It is possible the convolution doesn't define any results and its |
102 | // out argument is just used instead. |
103 | SmallVector<Type> resultTy; |
104 | if (op.getNumResults()) { |
105 | resultTy.push_back(Elt: op->getResult(0).getType()); |
106 | } |
107 | auto newConv = |
108 | rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(), |
109 | op.getStrides(), op.getDilations()); |
110 | rewriter.replaceOp(op, newConv); |
111 | return newConv.getOperation(); |
112 | } |
113 | |
114 | template <typename FHWCConvOp, typename HWCFConvOp> |
115 | class ConvConverter : public OpRewritePattern<FHWCConvOp> { |
116 | public: |
117 | using OpRewritePattern<FHWCConvOp>::OpRewritePattern; |
118 | LogicalResult matchAndRewrite(FHWCConvOp op, |
119 | PatternRewriter &rewriter) const final { |
120 | if (failed(transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op))) { |
121 | return failure(); |
122 | } |
123 | return success(); |
124 | } |
125 | }; |
126 | } // namespace |
127 | |
128 | FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, |
129 | linalg::Conv2DNhwcFhwcOp op) { |
130 | |
131 | return transposeConv2DHelper<linalg::Conv2DNhwcFhwcOp, |
132 | linalg::Conv2DNhwcHwcfOp>(rewriter, op); |
133 | } |
134 | |
135 | FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, |
136 | linalg::Conv2DNhwcFhwcQOp op) { |
137 | |
138 | return transposeConv2DHelper<linalg::Conv2DNhwcFhwcQOp, |
139 | linalg::Conv2DNhwcHwcfQOp>(rewriter, op); |
140 | } |
141 | |
142 | void populateTranposeConv2DPatterns(RewritePatternSet &patterns) { |
143 | MLIRContext *context = patterns.getContext(); |
144 | patterns.insert< |
145 | ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>, |
146 | ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>( |
147 | context); |
148 | } |
149 | } // namespace linalg |
150 | } // namespace mlir |
151 | |