1 | //===- TosaMakeBroadcastable.cpp ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Insert reshape ops on the inputs of elementwise binary ops (and select)
// where needed so that all operands have the same rank.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
14 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
16 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
17 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
18 | #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" |
19 | #include "mlir/Pass/Pass.h" |
20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | |
22 | namespace mlir { |
23 | namespace tosa { |
24 | #define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE |
25 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
26 | } // namespace tosa |
27 | } // namespace mlir |
28 | |
29 | using namespace mlir; |
30 | using namespace mlir::tosa; |
31 | |
32 | namespace { |
33 | |
34 | /// Common code to create the reshape op where necessary to make the rank of the |
35 | /// operations equal. input1 and input2 will be updated when the rank has |
36 | /// changed. The caller is expected to use these to rewrite the original |
37 | /// operator with the RESHAPE now in the graph. |
38 | /// return failure when (1) no reshape needed, or (2) output_type is specified |
39 | /// and it has different rank |
40 | LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, |
41 | RankedTensorType outputType, Value &input1, |
42 | Value &input2) { |
43 | auto input1Ty = dyn_cast<RankedTensorType>(input1.getType()); |
44 | auto input2Ty = dyn_cast<RankedTensorType>(input2.getType()); |
45 | |
46 | if (!input1Ty || !input2Ty) { |
47 | return rewriter.notifyMatchFailure(loc, "input not a ranked tensor" ); |
48 | } |
49 | |
50 | int64_t input1Rank = input1Ty.getRank(); |
51 | int64_t input2Rank = input2Ty.getRank(); |
52 | |
53 | if (input1Rank == input2Rank) |
54 | return rewriter.notifyMatchFailure(loc, |
55 | "cannot rewrite as its already correct" ); |
56 | |
57 | Value input1Copy = input1; |
58 | Value input2Copy = input2; |
59 | if (EqualizeRanks(rewriter, loc, input1&: input1Copy, input2&: input2Copy).failed()) { |
60 | return rewriter.notifyMatchFailure(loc, "failed to reshape inputs" ); |
61 | } |
62 | |
63 | // Verify the rank agrees with the output type if the output type is ranked. |
64 | if (outputType) { |
65 | if (outputType.getRank() != |
66 | llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() || |
67 | outputType.getRank() != |
68 | llvm::cast<RankedTensorType>(input2Copy.getType()).getRank()) |
69 | return rewriter.notifyMatchFailure( |
70 | loc, "the reshaped type doesn't agrees with the ranked output type" ); |
71 | } |
72 | |
73 | input1 = input1Copy; |
74 | input2 = input2Copy; |
75 | |
76 | return success(); |
77 | } |
78 | |
79 | template <typename OpTy> |
80 | struct ConvertTosaOp : public OpRewritePattern<OpTy> { |
81 | using OpRewritePattern<OpTy>::OpRewritePattern; |
82 | |
83 | LogicalResult matchAndRewrite(OpTy tosaBinaryOp, |
84 | PatternRewriter &rewriter) const override { |
85 | |
86 | Value input1 = tosaBinaryOp.getInput1(); |
87 | Value input2 = tosaBinaryOp.getInput2(); |
88 | Value output = tosaBinaryOp.getResult(); |
89 | |
90 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
91 | if (!outputType) |
92 | return failure(); |
93 | |
94 | if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, |
95 | input1, input2) |
96 | .failed()) |
97 | return failure(); |
98 | |
99 | rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2); |
100 | |
101 | return success(); |
102 | } |
103 | }; |
104 | |
105 | // The MulOp has an extra parameter 'shift' not present in other elementwise |
106 | // binary ops, that necessitates special handling of its builder. |
107 | template <> |
108 | struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> { |
109 | using OpRewritePattern<tosa::MulOp>::OpRewritePattern; |
110 | |
111 | LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp, |
112 | PatternRewriter &rewriter) const override { |
113 | |
114 | Value input1 = tosaBinaryOp.getInput1(); |
115 | Value input2 = tosaBinaryOp.getInput2(); |
116 | int32_t shift = tosaBinaryOp.getShift(); |
117 | Value output = tosaBinaryOp.getResult(); |
118 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
119 | if (!outputType) |
120 | return failure(); |
121 | |
122 | if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, |
123 | input1, input2) |
124 | .failed()) |
125 | return failure(); |
126 | |
127 | rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1, |
128 | input2, shift); |
129 | |
130 | return success(); |
131 | } |
132 | }; |
133 | |
134 | // The ArithmeticRightShiftOp has an extra parameter 'round' not present in |
135 | // other elementwise binary ops, that necessitates special handling of its |
136 | // builder. |
137 | template <> |
138 | struct ConvertTosaOp<tosa::ArithmeticRightShiftOp> |
139 | : public OpRewritePattern<tosa::ArithmeticRightShiftOp> { |
140 | using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern; |
141 | |
142 | LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp, |
143 | PatternRewriter &rewriter) const override { |
144 | |
145 | Value input1 = tosaBinaryOp.getInput1(); |
146 | Value input2 = tosaBinaryOp.getInput2(); |
147 | int32_t round = tosaBinaryOp.getRound(); |
148 | Value output = tosaBinaryOp.getResult(); |
149 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
150 | if (!outputType) |
151 | return failure(); |
152 | |
153 | if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, |
154 | input1, input2) |
155 | .failed()) |
156 | return failure(); |
157 | |
158 | rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>( |
159 | tosaBinaryOp, outputType, input1, input2, round); |
160 | |
161 | return success(); |
162 | } |
163 | }; |
164 | |
165 | template <> |
166 | struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> { |
167 | using OpRewritePattern<tosa::SelectOp>::OpRewritePattern; |
168 | |
169 | LogicalResult matchAndRewrite(tosa::SelectOp tosaOp, |
170 | PatternRewriter &rewriter) const override { |
171 | |
172 | Value input1 = tosaOp.getPred(); |
173 | Value input2 = tosaOp.getOnTrue(); |
174 | Value input3 = tosaOp.getOnFalse(); |
175 | Value output = tosaOp.getResult(); |
176 | |
177 | auto outputType = dyn_cast<RankedTensorType>(output.getType()); |
178 | if (!outputType) |
179 | return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor" ); |
180 | |
181 | // Apply broadcasting to each pair of inputs separately, and chain them as |
182 | // compound as below so that the broadcasting happens all at once. |
183 | bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, |
184 | input1, input2) |
185 | .succeeded(); |
186 | |
187 | bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, |
188 | input1, input3) |
189 | .succeeded(); |
190 | |
191 | bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType, |
192 | input2, input3) |
193 | .succeeded(); |
194 | |
195 | if (!reshaped1 && !reshaped2 && !reshaped3) |
196 | return rewriter.notifyMatchFailure( |
197 | tosaOp, |
198 | "cannot rewrite as the rank of all operands is already aligned" ); |
199 | |
200 | int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank(); |
201 | int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank(); |
202 | int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank(); |
203 | int32_t outputRank = outputType.getRank(); |
204 | |
205 | if ((result1Rank != result2Rank) || (result2Rank != result3Rank) || |
206 | (result1Rank != outputRank)) |
207 | return rewriter.notifyMatchFailure( |
208 | tosaOp, "not all ranks are aligned with each other" ); |
209 | |
210 | rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1, |
211 | input2, input3); |
212 | |
213 | return success(); |
214 | } |
215 | }; |
216 | } // namespace |
217 | |
218 | namespace { |
219 | /// Pass that enables broadcast by making all input arrays have the same |
220 | /// number of dimensions. Insert RESHAPE operations to lower rank operand |
221 | struct TosaMakeBroadcastable |
222 | : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> { |
223 | public: |
224 | void runOnOperation() override { |
225 | auto func = getOperation(); |
226 | RewritePatternSet patterns(func.getContext()); |
227 | MLIRContext *ctx = func.getContext(); |
228 | // Add the generated patterns to the list. |
229 | patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx); |
230 | patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx); |
231 | patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx); |
232 | patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx); |
233 | patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx); |
234 | patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx); |
235 | patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx); |
236 | patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx); |
237 | patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx); |
238 | patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx); |
239 | patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx); |
240 | patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx); |
241 | patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx); |
242 | patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx); |
243 | patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx); |
244 | patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx); |
245 | patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx); |
246 | patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx); |
247 | patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx); |
248 | patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx); |
249 | (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); |
250 | } |
251 | }; |
252 | } // namespace |
253 | |
254 | std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() { |
255 | return std::make_unique<TosaMakeBroadcastable>(); |
256 | } |
257 | |