1//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Insert RESHAPE ops on the inputs of elementwise binary ops (and of SELECT)
// where needed so that all operands have the same rank.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/Tensor/IR/Tensor.h"
15#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16#include "mlir/Dialect/Tosa/Transforms/Passes.h"
17#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
18#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22namespace mlir {
23namespace tosa {
24#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
25#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
26} // namespace tosa
27} // namespace mlir
28
29using namespace mlir;
30using namespace mlir::tosa;
31
32namespace {
33
34/// Common code to create the reshape op where necessary to make the rank of the
35/// operations equal. input1 and input2 will be updated when the rank has
36/// changed. The caller is expected to use these to rewrite the original
37/// operator with the RESHAPE now in the graph.
38/// return failure when (1) no reshape needed, or (2) output_type is specified
39/// and it has different rank
40LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
41 RankedTensorType outputType, Value &input1,
42 Value &input2) {
43 auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
44 auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());
45
46 if (!input1Ty || !input2Ty) {
47 return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
48 }
49
50 int64_t input1Rank = input1Ty.getRank();
51 int64_t input2Rank = input2Ty.getRank();
52
53 if (input1Rank == input2Rank)
54 return rewriter.notifyMatchFailure(loc,
55 "cannot rewrite as its already correct");
56
57 Value input1Copy = input1;
58 Value input2Copy = input2;
59 if (EqualizeRanks(rewriter, loc, input1&: input1Copy, input2&: input2Copy).failed()) {
60 return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
61 }
62
63 // Verify the rank agrees with the output type if the output type is ranked.
64 if (outputType) {
65 if (outputType.getRank() !=
66 llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
67 outputType.getRank() !=
68 llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
69 return rewriter.notifyMatchFailure(
70 loc, "the reshaped type doesn't agrees with the ranked output type");
71 }
72
73 input1 = input1Copy;
74 input2 = input2Copy;
75
76 return success();
77}
78
79template <typename OpTy>
80struct ConvertTosaOp : public OpRewritePattern<OpTy> {
81 using OpRewritePattern<OpTy>::OpRewritePattern;
82
83 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
84 PatternRewriter &rewriter) const override {
85
86 Value input1 = tosaBinaryOp.getInput1();
87 Value input2 = tosaBinaryOp.getInput2();
88 Value output = tosaBinaryOp.getResult();
89
90 auto outputType = dyn_cast<RankedTensorType>(output.getType());
91 if (!outputType)
92 return failure();
93
94 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
95 input1, input2)
96 .failed())
97 return failure();
98
99 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1, input2);
100
101 return success();
102 }
103};
104
105// The MulOp has an extra parameter 'shift' not present in other elementwise
106// binary ops, that necessitates special handling of its builder.
107template <>
108struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
109 using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
110
111 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
112 PatternRewriter &rewriter) const override {
113
114 Value input1 = tosaBinaryOp.getInput1();
115 Value input2 = tosaBinaryOp.getInput2();
116 int32_t shift = tosaBinaryOp.getShift();
117 Value output = tosaBinaryOp.getResult();
118 auto outputType = dyn_cast<RankedTensorType>(output.getType());
119 if (!outputType)
120 return failure();
121
122 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
123 input1, input2)
124 .failed())
125 return failure();
126
127 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
128 input2, shift);
129
130 return success();
131 }
132};
133
134// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
135// other elementwise binary ops, that necessitates special handling of its
136// builder.
137template <>
138struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
139 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
140 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
141
142 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
143 PatternRewriter &rewriter) const override {
144
145 Value input1 = tosaBinaryOp.getInput1();
146 Value input2 = tosaBinaryOp.getInput2();
147 int32_t round = tosaBinaryOp.getRound();
148 Value output = tosaBinaryOp.getResult();
149 auto outputType = dyn_cast<RankedTensorType>(output.getType());
150 if (!outputType)
151 return failure();
152
153 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
154 input1, input2)
155 .failed())
156 return failure();
157
158 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
159 tosaBinaryOp, outputType, input1, input2, round);
160
161 return success();
162 }
163};
164
165template <>
166struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
167 using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;
168
169 LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
170 PatternRewriter &rewriter) const override {
171
172 Value input1 = tosaOp.getPred();
173 Value input2 = tosaOp.getOnTrue();
174 Value input3 = tosaOp.getOnFalse();
175 Value output = tosaOp.getResult();
176
177 auto outputType = dyn_cast<RankedTensorType>(output.getType());
178 if (!outputType)
179 return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");
180
181 // Apply broadcasting to each pair of inputs separately, and chain them as
182 // compound as below so that the broadcasting happens all at once.
183 bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
184 input1, input2)
185 .succeeded();
186
187 bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
188 input1, input3)
189 .succeeded();
190
191 bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
192 input2, input3)
193 .succeeded();
194
195 if (!reshaped1 && !reshaped2 && !reshaped3)
196 return rewriter.notifyMatchFailure(
197 tosaOp,
198 "cannot rewrite as the rank of all operands is already aligned");
199
200 int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
201 int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
202 int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
203 int32_t outputRank = outputType.getRank();
204
205 if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
206 (result1Rank != outputRank))
207 return rewriter.notifyMatchFailure(
208 tosaOp, "not all ranks are aligned with each other");
209
210 rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
211 input2, input3);
212
213 return success();
214 }
215};
216} // namespace
217
218namespace {
219/// Pass that enables broadcast by making all input arrays have the same
220/// number of dimensions. Insert RESHAPE operations to lower rank operand
221struct TosaMakeBroadcastable
222 : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
223public:
224 void runOnOperation() override {
225 auto func = getOperation();
226 RewritePatternSet patterns(func.getContext());
227 MLIRContext *ctx = func.getContext();
228 // Add the generated patterns to the list.
229 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
230 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
231 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
232 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
233 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
234 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
235 patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
236 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
237 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
238 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
239 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
240 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
241 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
242 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
243 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
244 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
245 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
246 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
247 patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
248 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
249 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
250 }
251};
252} // namespace
253
254std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
255 return std::make_unique<TosaMakeBroadcastable>();
256}
257

source code of mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp