1//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/Index/IR/IndexOps.h"
17#include "mlir/Dialect/Linalg/IR/Linalg.h"
18#include "mlir/Dialect/Math/IR/Math.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/Tensor/IR/Tensor.h"
21#include "mlir/Dialect/Tosa/IR/TosaOps.h"
22#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/Dialect/Utils/StaticValueUtils.h"
25#include "mlir/IR/ImplicitLocOpBuilder.h"
26#include "mlir/IR/Matchers.h"
27#include "mlir/IR/OpDefinition.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/Transforms/DialectConversion.h"
30#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/Sequence.h"
33
34#include <numeric>
35#include <type_traits>
36
37using namespace mlir;
38using namespace mlir::tosa;
39
40// Helper function to materialize the semantically correct compare and select
41// operations given a binary operation with a specific NaN propagation mode.
42//
43// In the case of "PROPAGATE" semantics no compare and selection is required and
44// this function does nothing.
45//
// In the case of "IGNORE" semantics this function materializes a comparison of
// the current operands to the op which will return true for any NaN
// argument and then selects between the non-NaN operation argument and the
// calculated result based on whether the lhs or rhs is NaN or not. In pseudo
// code:
//
// binary<op>(lhs, rhs):
//   result = op(lhs, rhs)
//   if lhs == NaN return rhs
//   if rhs == NaN return lhs
//   return result
//
// In the case that the op is operating on non floating point types we ignore
// the attribute completely, this is consistent with the TOSA spec which has
// the following wording: "This attribute is ignored by non floating-point
// types."
62template <typename OpTy>
63static Value
64materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
65 Value lhs, Value rhs, Value result) {
66 // NaN propagation has no meaning for non floating point types.
67 if (!isa<FloatType>(Val: getElementTypeOrSelf(val: lhs)))
68 return result;
69
70 auto nanMode = op.getNanMode();
71 if (nanMode == "PROPAGATE")
72 return result;
73
74 // Unordered comparison of NaN against itself will always return true.
75 Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
76 op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
77 Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
78 op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
79 Value rhsOrResult =
80 rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
81 return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
82 rhsOrResult);
83}
84
85static Value createLinalgBodyCalculationForElementwiseOp(
86 Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
87 ConversionPatternRewriter &rewriter) {
88 Location loc = op->getLoc();
89 auto elementTy =
90 cast<ShapedType>(op->getOperand(0).getType()).getElementType();
91
92 // tosa::AbsOp
93 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
94 return rewriter.create<math::AbsFOp>(loc, resultTypes, args);
95
96 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
97 auto zero = rewriter.create<arith::ConstantOp>(
98 loc, rewriter.getZeroAttr(elementTy));
99 auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
100 return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
101 }
102
103 // tosa::AddOp
104 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
105 return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
106
107 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
108 return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
109
110 // tosa::SubOp
111 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
112 return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
113
114 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
115 return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
116
117 // tosa::IntDivOp
118 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
119 return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
120
121 // tosa::ReciprocalOp
122 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
123 auto one =
124 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
125 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
126 }
127
128 // tosa::MulOp
129 if (isa<tosa::MulOp>(op)) {
130 auto shiftVal = cast<tosa::MulOp>(op).getShift();
131 DenseElementsAttr shiftElem;
132 if (!matchPattern(shiftVal, m_Constant(bind_value: &shiftElem))) {
133 (void)rewriter.notifyMatchFailure(arg&: op, msg: "shift value of mul not found");
134 return nullptr;
135 }
136
137 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
138
139 if (isa<FloatType>(elementTy)) {
140 if (shift != 0) {
141 (void)rewriter.notifyMatchFailure(arg&: op,
142 msg: "Cannot have shift value for float");
143 return nullptr;
144 }
145 return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
146 }
147
148 if (isa<IntegerType>(elementTy)) {
149 Value a = args[0];
150 Value b = args[1];
151
152 if (shift > 0) {
153 auto shiftConst =
154 rewriter.create<arith::ConstantIntOp>(location: loc, args&: shift, /*bitwidth=*/args: 8);
155 if (!a.getType().isInteger(32))
156 a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
157
158 if (!b.getType().isInteger(32))
159 b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
160
161 auto result = rewriter.create<tosa::ApplyScaleOp>(
162 loc, rewriter.getI32Type(), a, b, shiftConst,
163 rewriter.getStringAttr("SINGLE_ROUND"));
164
165 if (elementTy.isInteger(32))
166 return result;
167
168 return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
169 }
170
171 int aWidth = a.getType().getIntOrFloatBitWidth();
172 int bWidth = b.getType().getIntOrFloatBitWidth();
173 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
174
175 if (aWidth < cWidth)
176 a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
177 if (bWidth < cWidth)
178 b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
179
180 return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
181 }
182 }
183
184 // tosa::NegateOp
185 if (isa<tosa::NegateOp>(op)) {
186 auto negate = cast<tosa::NegateOp>(op);
187
188 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
189 if (failed(Result: maybeInZp)) {
190 (void)rewriter.notifyMatchFailure(
191 arg&: op, msg: "input1 zero point cannot be statically determined");
192 return nullptr;
193 }
194
195 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
196 if (failed(Result: maybeOutZp)) {
197 (void)rewriter.notifyMatchFailure(
198 arg&: op, msg: "output zero point cannot be statically determined");
199 return nullptr;
200 }
201
202 int64_t inZp = *maybeInZp;
203 int64_t outZp = *maybeOutZp;
204
205 if (isa<FloatType>(elementTy))
206 return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
207
208 if (isa<IntegerType>(elementTy)) {
209 if (!inZp && !outZp) {
210 auto constant = rewriter.create<arith::ConstantOp>(
211 loc, IntegerAttr::get(elementTy, 0));
212 return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
213 args[0]);
214 }
215
216 // Compute the maximum value that can occur in the intermediate buffer.
217 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
218 const int64_t zpAdd = inZp + outZp;
219 const int64_t maxValue =
220 APInt::getSignedMaxValue(numBits: inputBitWidth).getSExtValue() +
221 std::abs(i: zpAdd) + 1;
222
223 // Convert that maximum value into the maximum bitwidth needed to
224 // represent it. We assume 48-bit numbers may be supported further in
225 // the pipeline.
226 int intermediateBitWidth = 64;
227 if (maxValue <= APInt::getSignedMaxValue(numBits: 16).getSExtValue()) {
228 intermediateBitWidth = 16;
229 } else if (maxValue <= APInt::getSignedMaxValue(numBits: 32).getSExtValue()) {
230 intermediateBitWidth = 32;
231 } else if (maxValue <= APInt::getSignedMaxValue(numBits: 48).getSExtValue()) {
232 intermediateBitWidth = 48;
233 }
234
235 Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
236 Value zpAddValue = rewriter.create<arith::ConstantOp>(
237 loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
238
239 // The negation can be applied by doing:
240 // outputValue = inZp + outZp - inputValue
241 auto ext =
242 rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
243 auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
244
245 // Clamp to the negation range.
246 Value min = rewriter.create<arith::ConstantIntOp>(
247 location: loc, args: APInt::getSignedMinValue(numBits: inputBitWidth).getSExtValue(),
248 args&: intermediateType);
249 Value max = rewriter.create<arith::ConstantIntOp>(
250 location: loc, args: APInt::getSignedMaxValue(numBits: inputBitWidth).getSExtValue(),
251 args&: intermediateType);
252 auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);
253
254 // Truncate to the final value.
255 return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
256 }
257 }
258
259 // tosa::BitwiseAndOp
260 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
261 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
262
263 // tosa::BitwiseOrOp
264 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
265 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
266
267 // tosa::BitwiseNotOp
268 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
269 auto allOnesAttr = rewriter.getIntegerAttr(
270 elementTy, APInt::getAllOnes(numBits: elementTy.getIntOrFloatBitWidth()));
271 auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
272 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
273 }
274
275 // tosa::BitwiseXOrOp
276 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
277 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
278
279 // tosa::LogicalLeftShiftOp
280 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
281 return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
282
283 // tosa::LogicalRightShiftOp
284 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
285 return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
286
287 // tosa::ArithmeticRightShiftOp
288 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
289 auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
290 auto round = cast<BoolAttr>(Val: op->getAttr(name: "round")).getValue();
291 if (!round) {
292 return result;
293 }
294
295 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
296 auto one =
297 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
298 auto zero =
299 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
300 auto i1one =
301 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
302
303 // Checking that input2 != 0
304 auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
305 loc, arith::CmpIPredicate::sgt, args[1], zero);
306
307 // Checking for the last bit of input1 to be 1
308 auto subtract =
309 rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
310 auto shifted =
311 rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
312 ->getResults();
313 auto truncated =
314 rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
315 auto isInputOdd =
316 rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
317
318 auto shouldRound = rewriter.create<arith::AndIOp>(
319 loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
320 auto extended =
321 rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
322 return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
323 }
324
325 // tosa::ClzOp
326 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
327 return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
328 }
329
330 // tosa::LogicalAnd
331 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
332 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
333
334 // tosa::LogicalNot
335 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
336 auto one = rewriter.create<arith::ConstantOp>(
337 loc, rewriter.getIntegerAttr(elementTy, 1));
338 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
339 }
340
341 // tosa::LogicalOr
342 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
343 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
344
345 // tosa::LogicalXor
346 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
347 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
348
349 // tosa::PowOp
350 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
351 return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
352
353 // tosa::RsqrtOp
354 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
355 return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
356
357 // tosa::LogOp
358 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
359 return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
360
361 // tosa::ExpOp
362 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
363 return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
364
365 // tosa::SinOp
366 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
367 return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);
368
369 // tosa::CosOp
370 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
371 return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);
372
373 // tosa::TanhOp
374 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
375 return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
376
377 // tosa::ErfOp
378 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
379 return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
380
381 // tosa::GreaterOp
382 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
383 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
384 args[0], args[1]);
385
386 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
387 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
388 args[0], args[1]);
389
390 // tosa::GreaterEqualOp
391 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
392 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
393 args[0], args[1]);
394
395 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
396 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
397 args[0], args[1]);
398
399 // tosa::EqualOp
400 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
401 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
402 args[0], args[1]);
403
404 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
405 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
406 args[0], args[1]);
407
408 // tosa::SelectOp
409 if (isa<tosa::SelectOp>(op)) {
410 elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
411 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
412 return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
413 }
414
415 // tosa::MaximumOp
416 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
417 auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
418 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
419 rewriter, args[0], args[1], max);
420 }
421
422 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
423 return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
424 }
425
426 // tosa::MinimumOp
427 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
428 auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
429 return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
430 rewriter, args[0], args[1], min);
431 }
432
433 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
434 return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
435 }
436
437 // tosa::CeilOp
438 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
439 return rewriter.create<math::CeilOp>(loc, resultTypes, args);
440
441 // tosa::FloorOp
442 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
443 return rewriter.create<math::FloorOp>(loc, resultTypes, args);
444
445 // tosa::ClampOp
446 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
447 bool losesInfo = false;
448 APFloat minApf = cast<FloatAttr>(op->getAttr(name: "min_val")).getValue();
449 APFloat maxApf = cast<FloatAttr>(op->getAttr(name: "max_val")).getValue();
450 minApf.convert(ToSemantics: cast<FloatType>(elementTy).getFloatSemantics(),
451 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
452 maxApf.convert(ToSemantics: cast<FloatType>(elementTy).getFloatSemantics(),
453 RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo);
454 auto min = rewriter.create<arith::ConstantOp>(
455 loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
456 auto max = rewriter.create<arith::ConstantOp>(
457 loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
458 auto result = clampFloatHelper(loc, args[0], min, max, rewriter);
459
460 auto clampOp = llvm::cast<tosa::ClampOp>(op);
461 const auto nanMode = clampOp.getNanMode();
462
463 // NaN propagation has no meaning for non floating point types.
464 if (!isa<FloatType>(elementTy))
465 return result;
466
467 // In the case of "PROPAGATE" semantics no compare and selection is
468 // required.
469 if (nanMode == "PROPAGATE")
470 return result;
471
472 // In the case of "IGNORE" semantics materialize a comparison
473 // of the current operand to the reduction which will return true for a NaN
474 // argument and then selects between the initial reduction value and the
475 // calculated result based on whether the argument is NaN or not. In pseudo
476 // code:
477 //
478 // reduce<op>(x, init):
479 // result = op(init, x)
480 // return init if x == NaN else result
481
482 // Unordered comparison of NaN against itself will always return true.
483 Value isNaN = rewriter.create<arith::CmpFOp>(
484 op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
485 // TOSA specifies that in "ignore" NaN mode the result is "min" if the input
486 // is NaN.
487 return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
488 }
489
490 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
491 auto intTy = cast<IntegerType>(elementTy);
492 int64_t min =
493 cast<IntegerAttr>(op->getAttr(name: "min_val")).getValue().getSExtValue();
494 int64_t max =
495 cast<IntegerAttr>(op->getAttr(name: "max_val")).getValue().getSExtValue();
496
497 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
498 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
499 if (intTy.isUnsignedInteger()) {
500 minRepresentable = 0;
501 if (intTy.getIntOrFloatBitWidth() <= 63) {
502 maxRepresentable =
503 (int64_t)APInt::getMaxValue(numBits: intTy.getIntOrFloatBitWidth())
504 .getZExtValue();
505 }
506 } else if (intTy.getIntOrFloatBitWidth() <= 64) {
507 // Ensure that min & max fit into signed n-bit constants.
508 minRepresentable = APInt::getSignedMinValue(numBits: intTy.getIntOrFloatBitWidth())
509 .getSExtValue();
510 maxRepresentable = APInt::getSignedMaxValue(numBits: intTy.getIntOrFloatBitWidth())
511 .getSExtValue();
512 }
513 // Ensure that the bounds are representable as n-bit signed/unsigned
514 // integers.
515 min = std::max(a: min, b: minRepresentable);
516 max = std::max(a: max, b: minRepresentable);
517 min = std::min(a: min, b: maxRepresentable);
518 max = std::min(a: max, b: maxRepresentable);
519
520 auto minVal = rewriter.create<arith::ConstantIntOp>(
521 loc, min, intTy.getIntOrFloatBitWidth());
522 auto maxVal = rewriter.create<arith::ConstantIntOp>(
523 loc, max, intTy.getIntOrFloatBitWidth());
524 return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
525 intTy.isUnsignedInteger());
526 }
527
528 // tosa::SigmoidOp
529 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
530 auto one =
531 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
532 auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
533 auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
534 auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
535 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
536 }
537
538 // tosa::CastOp
539 if (isa<tosa::CastOp>(op)) {
540 Type srcTy = elementTy;
541 Type dstTy = resultTypes.front();
542 if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
543 (void)rewriter.notifyMatchFailure(arg&: op, msg: "unsupported type");
544 return nullptr;
545 }
546
547 bool bitExtend =
548 srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
549
550 if (srcTy == dstTy)
551 return args.front();
552
553 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
554 return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
555 std::nullopt);
556
557 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
558 return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
559 std::nullopt);
560
561 // 1-bit integers need to be treated as signless.
562 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
563 return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
564 std::nullopt);
565
566 if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
567 return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
568 std::nullopt);
569
570 // Unsigned integers need an unrealized cast so that they can be passed
571 // to UIToFP.
572 if (srcTy.isUnsignedInteger() && isa<FloatType>(Val: dstTy)) {
573 auto unrealizedCast =
574 rewriter
575 .create<UnrealizedConversionCastOp>(
576 loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
577 args[0])
578 .getResult(0);
579 return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
580 unrealizedCast);
581 }
582
583 // All other si-to-fp conversions should be handled by SIToFP.
584 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
585 return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
586 std::nullopt);
587
588 // Casting to boolean, floats need to only be checked as not-equal to zero.
589 if (isa<FloatType>(Val: srcTy) && dstTy.isInteger(width: 1)) {
590 Value zero = rewriter.create<arith::ConstantOp>(
591 loc, rewriter.getFloatAttr(srcTy, 0.0));
592 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
593 args.front(), zero);
594 }
595
596 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
597 auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
598
599 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
600 // Check whether neither int min nor int max can be represented in the
601 // input floating-point type due to too short exponent range.
602 if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
603 APFloat::semanticsMaxExponent(fltSemantics)) {
604 // Use cmp + select to replace infinites by int min / int max. Other
605 // integral values can be represented in the integer space.
606 auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
607 auto posInf = rewriter.create<arith::ConstantOp>(
608 loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
609 APFloat::getInf(fltSemantics)));
610 auto negInf = rewriter.create<arith::ConstantOp>(
611 loc, rewriter.getFloatAttr(
612 getElementTypeOrSelf(srcTy),
613 APFloat::getInf(fltSemantics, /*Negative=*/true)));
614 auto overflow = rewriter.create<arith::CmpFOp>(
615 loc, arith::CmpFPredicate::UEQ, rounded, posInf);
616 auto underflow = rewriter.create<arith::CmpFOp>(
617 loc, arith::CmpFPredicate::UEQ, rounded, negInf);
618 auto intMin = rewriter.create<arith::ConstantOp>(
619 loc, rewriter.getIntegerAttr(
620 getElementTypeOrSelf(dstTy),
621 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
622 auto intMax = rewriter.create<arith::ConstantOp>(
623 loc, rewriter.getIntegerAttr(
624 getElementTypeOrSelf(dstTy),
625 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
626 auto maxClamped =
627 rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
628 return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
629 maxClamped);
630 }
631
632 auto intMinFP = rewriter.create<arith::ConstantOp>(
633 loc, rewriter.getFloatAttr(
634 getElementTypeOrSelf(srcTy),
635 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
636 .getSExtValue()));
637
638 // Check whether the mantissa has enough bits to represent int max.
639 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
640 dstTy.getIntOrFloatBitWidth() - 1) {
641 // Int min can also be represented since it is a power of two and thus
642 // consists of a single leading bit. Therefore we can clamp the input
643 // in the floating-point domain.
644
645 auto intMaxFP = rewriter.create<arith::ConstantOp>(
646 loc, rewriter.getFloatAttr(
647 getElementTypeOrSelf(srcTy),
648 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
649 .getSExtValue()));
650
651 Value clamped =
652 clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
653 return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
654 }
655
656 // Due to earlier check we know exponant range is big enough to represent
657 // int min. We can therefore rely on int max + 1 being representable as
658 // well because it's just int min with a positive sign. So clamp the min
659 // value and compare against that to select the max int value if needed.
660 auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
661 loc, rewriter.getFloatAttr(
662 getElementTypeOrSelf(srcTy),
663 static_cast<double>(
664 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
665 .getSExtValue()) +
666 1.0f));
667
668 auto intMax = rewriter.create<arith::ConstantOp>(
669 loc, rewriter.getIntegerAttr(
670 getElementTypeOrSelf(dstTy),
671 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
672 auto minClampedFP =
673 rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
674 auto minClamped =
675 rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
676 auto overflow = rewriter.create<arith::CmpFOp>(
677 loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
678 return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
679 minClamped);
680 }
681
682 // Casting to boolean, integers need to only be checked as not-equal to
683 // zero.
684 if (isa<IntegerType>(Val: srcTy) && dstTy.isInteger(width: 1)) {
685 Value zero = rewriter.create<arith::ConstantIntOp>(
686 location: loc, args: 0, args: srcTy.getIntOrFloatBitWidth());
687 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
688 args.front(), zero);
689 }
690
691 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
692 return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
693 std::nullopt);
694
695 if (isa<IntegerType>(Val: srcTy) && isa<IntegerType>(Val: dstTy) && !bitExtend) {
696 return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
697 }
698 }
699
700 (void)rewriter.notifyMatchFailure(
701 arg&: op, msg: "unhandled op for linalg body calculation for elementwise op");
702 return nullptr;
703}
704
705using IndexPool = DenseMap<int64_t, Value>;
706
707// Emit an 'arith.constant' op for the given index if it has not been created
708// yet, or return an existing constant. This will prevent an excessive creation
709// of redundant constants, easing readability of emitted code for unit tests.
710static Value createIndex(PatternRewriter &rewriter, Location loc,
711 IndexPool &indexPool, int64_t index) {
712 auto [it, inserted] = indexPool.try_emplace(Key: index);
713 if (inserted)
714 it->second =
715 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
716 return it->second;
717}
718
719static Value getTensorDim(PatternRewriter &rewriter, Location loc,
720 IndexPool &indexPool, Value tensor, int64_t index) {
721 auto indexValue = createIndex(rewriter, loc, indexPool, index);
722 return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
723}
724
725static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
726 IndexPool &indexPool, Value tensor,
727 int64_t index) {
728 auto shapedType = dyn_cast<ShapedType>(tensor.getType());
729 assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
730 assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
731 if (shapedType.isDynamicDim(index))
732 return getTensorDim(rewriter, loc, indexPool, tensor, index);
733 return rewriter.getIndexAttr(value: shapedType.getDimSize(index));
734}
735
736static bool operandsAndResultsRanked(Operation *operation) {
737 auto isRanked = [](Value value) {
738 return isa<RankedTensorType>(Val: value.getType());
739 };
740 return llvm::all_of(Range: operation->getOperands(), P: isRanked) &&
741 llvm::all_of(Range: operation->getResults(), P: isRanked);
742}
743
744// Compute the runtime dimension size for dimension 'dim' of the output by
745// inspecting input 'operands', all of which are expected to have the same rank.
746// This function returns a pair {targetSize, masterOperand}.
747//
748// The runtime size of the output dimension is returned either as a statically
749// computed attribute or as a runtime SSA value.
750//
751// If the target size was inferred directly from one dominating operand, that
752// operand is returned in 'masterOperand'. If the target size is inferred from
753// multiple operands, 'masterOperand' is set to nullptr.
754static std::pair<OpFoldResult, Value>
755computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
756 ValueRange operands, int64_t dim) {
757 // If any input operand contains a static size greater than 1 for this
758 // dimension, that is the target size. An occurrence of an additional static
759 // dimension greater than 1 with a different value is undefined behavior.
760 for (auto operand : operands) {
761 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
762 if (!ShapedType::isDynamic(size) && size > 1)
763 return {rewriter.getIndexAttr(value: size), operand};
764 }
765
766 // Filter operands with dynamic dimension
767 auto operandsWithDynamicDim =
768 llvm::filter_to_vector(C&: operands, Pred: [&](Value operand) {
769 return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
770 });
771
772 // If no operand has a dynamic dimension, it means all sizes were 1
773 if (operandsWithDynamicDim.empty())
774 return {rewriter.getIndexAttr(1), operands.front()};
775
776 // Emit code that computes the runtime size for this dimension. If there is
777 // only one operand with a dynamic dimension, it is considered the master
778 // operand that determines the runtime size of the output dimension.
779 auto targetSize =
780 getTensorDim(rewriter, loc, indexPool, tensor: operandsWithDynamicDim[0], index: dim);
781 if (operandsWithDynamicDim.size() == 1)
782 return {targetSize, operandsWithDynamicDim[0]};
783
784 // Calculate maximum size among all dynamic dimensions
785 for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
786 auto nextSize =
787 getTensorDim(rewriter, loc, indexPool, tensor: operandsWithDynamicDim[i], index: dim);
788 targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
789 }
790 return {targetSize, nullptr};
791}
792
793// Compute the runtime output size for all dimensions. This function returns
794// a pair {targetShape, masterOperands}.
795static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
796computeTargetShape(PatternRewriter &rewriter, Location loc,
797 IndexPool &indexPool, ValueRange operands) {
798 assert(!operands.empty());
799 auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
800 SmallVector<OpFoldResult> targetShape;
801 SmallVector<Value> masterOperands;
802 for (auto dim : llvm::seq<int64_t>(0, rank)) {
803 auto [targetSize, masterOperand] =
804 computeTargetSize(rewriter, loc, indexPool, operands, dim);
805 targetShape.push_back(targetSize);
806 masterOperands.push_back(masterOperand);
807 }
808 return {targetShape, masterOperands};
809}
810
// Broadcast dimension 'dim' of 'operand' to 'targetSize' at runtime. Emits an
// 'scf.if' that tests whether the operand's runtime size is 1 and, if so,
// expands it with a 'linalg.generic' that replicates the single element;
// otherwise the operand is yielded unchanged. Returns the (possibly cast)
// result of the 'scf.if'.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op. The input map pins 'dim' to constant
  // 0 (reading the single element to replicate); the output map is identity.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary. A runtime size of 1 means the dimension
  // must be expanded to the target size.
  auto one = createIndex(rewriter, loc, indexPool, index: 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, tensor: operand, index: dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op. The broadcast dimension gets the target size;
    // all other dimensions keep the operand's (possibly dynamic) size.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op that copies the single input element across the
    // broadcast dimension.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary, so both 'scf.if' branches
    // yield the same type.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if': no broadcast needed, forward the operand.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
893
894static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
895 IndexPool &indexPool, Value operand,
896 ArrayRef<OpFoldResult> targetShape,
897 ArrayRef<Value> masterOperands) {
898 int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
899 assert((int64_t)targetShape.size() == rank);
900 assert((int64_t)masterOperands.size() == rank);
901 for (auto index : llvm::seq<int64_t>(0, rank))
902 operand =
903 broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
904 targetShape[index], masterOperands[index]);
905 return operand;
906}
907
908static SmallVector<Value>
909broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
910 IndexPool &indexPool, ValueRange operands,
911 ArrayRef<OpFoldResult> targetShape,
912 ArrayRef<Value> masterOperands) {
913 // No need to broadcast for unary operations
914 if (operands.size() == 1)
915 return operands;
916
917 // Broadcast dynamic dimensions operand by operand
918 return llvm::map_to_vector(C&: operands, F: [&](Value operand) {
919 return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
920 targetShape, masterOperands);
921 });
922}
923
// Emit the 'tensor.empty' / 'linalg.generic' pair implementing an elementwise
// TOSA operation whose operands have already been broadcast to 'targetShape'.
// The scalar body is produced by createLinalgBodyCalculationForElementwiseOp;
// failure to produce it aborts the match.
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(t: operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(arg&: operation, msg: "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  //
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no broadcasting
      // needed) because some transforms (like reshape folding)
      // do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank: rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        // Only the operand block arguments are forwarded; the trailing output
        // block argument is not an input to the scalar computation.
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        arg&: operation, msg: "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
985
986static ValueRange getBroadcastableOperands(Operation *operation,
987 ValueRange operands) {
988 // Shift cannot broadcast
989 if (isa<tosa::MulOp>(operation))
990 return operands.take_front(n: 2);
991 // Input1_zp and output_zp cannot broadcast
992 if (isa<tosa::NegateOp>(operation))
993 return operands.take_front(n: 1);
994 return operands;
995}
996
997static LogicalResult
998elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
999 ConversionPatternRewriter &rewriter,
1000 const TypeConverter &converter) {
1001
1002 // Collect op properties
1003 assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
1004 assert(operation->getNumOperands() >= 1 &&
1005 "elementwise op expects at least 1 operand");
1006 if (!operandsAndResultsRanked(operation))
1007 return rewriter.notifyMatchFailure(arg&: operation,
1008 msg: "Unranked tensors not supported");
1009
1010 // Lower operation
1011 IndexPool indexPool;
1012 auto loc = operation->getLoc();
1013 auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
1014 auto [targetShape, masterOperands] =
1015 computeTargetShape(rewriter, loc, indexPool, operands: operandsToBroadcast);
1016 auto broadcastOperands =
1017 broadcastDynamicDimensions(rewriter, loc, indexPool, operands: operandsToBroadcast,
1018 targetShape, masterOperands);
1019 return emitElementwiseComputation(rewriter, loc, operation, operands: broadcastOperands,
1020 targetShape, converter);
1021}
1022
1023// Returns the constant initial value for a given reduction operation. The
1024// attribute type varies depending on the element type required.
1025static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
1026 PatternRewriter &rewriter) {
1027 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1028 return rewriter.getFloatAttr(elementTy, 0.0);
1029
1030 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1031 return rewriter.getIntegerAttr(elementTy, 0);
1032
1033 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1034 return rewriter.getFloatAttr(elementTy, 1.0);
1035
1036 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1037 return rewriter.getIntegerAttr(elementTy, 1);
1038
1039 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1040 return rewriter.getFloatAttr(
1041 elementTy, APFloat::getLargest(
1042 Sem: cast<FloatType>(elementTy).getFloatSemantics(), Negative: false));
1043
1044 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1045 return rewriter.getIntegerAttr(
1046 elementTy, APInt::getSignedMaxValue(numBits: elementTy.getIntOrFloatBitWidth()));
1047
1048 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1049 return rewriter.getFloatAttr(
1050 elementTy, APFloat::getLargest(
1051 Sem: cast<FloatType>(elementTy).getFloatSemantics(), Negative: true));
1052
1053 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1054 return rewriter.getIntegerAttr(
1055 elementTy, APInt::getSignedMinValue(numBits: elementTy.getIntOrFloatBitWidth()));
1056
1057 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1058 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(numBits: 1));
1059
1060 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1061 return rewriter.getIntegerAttr(elementTy, APInt::getZero(numBits: 1));
1062
1063 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1064 return rewriter.getFloatAttr(
1065 elementTy, APFloat::getLargest(
1066 Sem: cast<FloatType>(elementTy).getFloatSemantics(), Negative: true));
1067
1068 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1069 return rewriter.getIntegerAttr(
1070 elementTy, APInt::getSignedMinValue(numBits: elementTy.getIntOrFloatBitWidth()));
1071
1072 return {};
1073}
1074
1075// Creates the body calculation for a reduction. The operations vary depending
1076// on the input type.
1077static Value createLinalgBodyCalculationForReduceOp(Operation *op,
1078 ValueRange args,
1079 Type elementTy,
1080 PatternRewriter &rewriter) {
1081 Location loc = op->getLoc();
1082 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1083 return rewriter.create<arith::AddFOp>(loc, args);
1084 }
1085
1086 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1087 return rewriter.create<arith::AddIOp>(loc, args);
1088 }
1089
1090 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1091 return rewriter.create<arith::MulFOp>(loc, args);
1092 }
1093
1094 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1095 return rewriter.create<arith::MulIOp>(loc, args);
1096 }
1097
1098 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1099 return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
1100 }
1101
1102 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1103 return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
1104 }
1105
1106 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1107 return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
1108 }
1109
1110 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1111 return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
1112 }
1113
1114 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
1115 return rewriter.create<arith::AndIOp>(loc, args);
1116
1117 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
1118 return rewriter.create<arith::OrIOp>(loc, args);
1119
1120 return {};
1121}
1122
1123// Performs the match and rewrite for reduction operations. This includes
1124// declaring a correctly sized initial value, and the linalg.generic operation
1125// that reduces across the specified axis.
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // The reduced output shape is the input shape with 'axis' dropped; also
  // collect the runtime sizes of the remaining dynamic dimensions.
  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(Elt: inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(Elt: input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();
  outputs.push_back(Elt: filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result be NaN iff all elements in
      // the reduction are NaN we can't simply perform a compare and select.
      // Additionally we have to keep track of whether we've seen any non-NaN
      // values and then do a final select based on this predicate.
      auto trueAttr = rewriter.getBoolAttr(value: true);
      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
      auto emptyBoolTensor =
          rewriter
              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
                                       dynDims)
              .getResult();
      auto allResultsNaNTensor =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{trueValue},
                                      ValueRange{emptyBoolTensor})
              .result();
      // Note that because the linalg::ReduceOp has two variadic arguments
      // (inputs and outputs) and it has the SameVariadicOperandSize trait we
      // need to have the same number of inputs and outputs.
      //
      // The second input isn't actually used anywhere since the value used to
      // update the NaN flag is calculated inside the body of the reduction and
      // then used to update an out value.
      // In order to satisfy type constraints we just pass another copy of the
      // input here.
      inputs.push_back(Elt: input);
      outputs.push_back(Elt: allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        // In NaN-ignore mode the accumulator is blockArgs[2] because of the
        // extra (unused) input; otherwise it is blockArgs[1].
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
            op, binaryArgs, elementTy, rewriter);
        // NOTE: despite its name, 'didEncounterError' records that the body
        // was successfully created; the match fails below if it stays false.
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself will always return true.
          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
          // If we've encountered a NaN, take the non-NaN value.
          auto selectOp = nestedBuilder.create<arith::SelectOp>(
              op->getLoc(), isNaN, initialValue, result);
          // Update the flag which keeps track of whether we have seen a non-NaN
          // value.
          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // Materialize a check to see whether we encountered any non-NaN values, if
    // we didn't we need to select a tensor of NaNs since the result will just
    // be the initial identity value propagated through all the compares and
    // selects inside the reduction.

    // Create a tensor full of NaNs.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(Sem: cast<FloatType>(elementTy).getFloatSemantics(), Negative: false));
    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
    auto emptyNanTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{nanValue},
                                    ValueRange{emptyNanTensor})
            .result();

    // Create an empty tensor, no need to fill this since it will be
    // overwritten by the select.
    auto finalEmptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();

    // Do a selection between the tensors akin to:
    // result = NaN if "all results NaN" else result.
    SmallVector<Value> ins, outs;
    ins.push_back(Elt: linalgOp->getOpResult(1));
    ins.push_back(Elt: nanFilledTensor);
    ins.push_back(Elt: linalgOp->getResult(0));
    outs.push_back(Elt: finalEmptyTensor);
    auto linalgSelect =
        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // Build the reassociation that re-expands the reduced rank-(n-1) result back
  // to rank n with a unit dimension at 'axis'.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(N: expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(Elt: rewriter.getAffineDimExpr(position: dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        Elt: rewriter.getAffineDimExpr(position: expandedDim + 1));
  }

  // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`,
  // since here we know which dimension to expand, and `tosa::ReshapeOp` would
  // not have access to such information. This matters when handling dynamically
  // sized tensors.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}
1311
1312namespace {
1313
1314template <typename SrcOp>
1315class PointwiseConverter : public OpConversionPattern<SrcOp> {
1316public:
1317 using OpConversionPattern<SrcOp>::OpConversionPattern;
1318 using typename OpConversionPattern<SrcOp>::OpAdaptor;
1319
1320 LogicalResult
1321 matchAndRewrite(SrcOp op, OpAdaptor operands,
1322 ConversionPatternRewriter &rewriter) const final {
1323 return elementwiseMatchAndRewriteHelper(
1324 op, operands.getOperands(), rewriter, *this->getTypeConverter());
1325 }
1326};
1327
// Lowers tosa.rescale to a linalg.generic whose body subtracts the input
// zero-point, applies the (possibly per-channel) scale via tosa.apply_scale,
// adds the output zero-point, and clamps to the output type's range. The
// multiplier and shift inputs must be compile-time constants.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration. terminate and log an error
    if (op.getRoundingMode() == "INEXACT_ROUND")
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    // Collect runtime sizes of the output's dynamic dimensions.
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(bind_value: &shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(bind_value: &multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // explicit cast is required here
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.

    bool doubleRound =
        op.getRoundingMode() == "DOUBLE_ROUND" &&
        llvm::any_of(Range&: shiftValues, P: [](int32_t v) { return v > 31; });
    StringAttr roundingMode = doubleRound
                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
                                  : rewriter.getStringAttr("SINGLE_ROUND");

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      // Per-channel: the multiplier tensor is indexed by the innermost
      // (channel) dimension only.
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(position: rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(Elt: AffineMap::get(/*dimCount=*/rank,
                                             /*symbolCount=*/0, results: multiplierExprs,
                                             context: rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      // Per-channel: the shift tensor is indexed by the innermost (channel)
      // dimension only.
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(position: rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(Elt: AffineMap::get(/*dimCount=*/rank,
                                             /*symbolCount=*/0, results: shiftExprs,
                                             context: rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend zeropoint for sub-32bits widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                                    *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          };

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                                    *maybeOZp));

          // Per-channel mode reads the multiplier/shift from extra block
          // arguments; per-tensor mode uses the hoisted constants.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              roundingMode);

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output value.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
1562
// Handle the resize case where the input is a 1x1 image. This case
// can entirely avoid having extract operations, which are much
// more difficult to optimize away.
1566class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
1567public:
1568 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1569
1570 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1571 PatternRewriter &rewriter) const final {
1572 Location loc = op.getLoc();
1573 ImplicitLocOpBuilder builder(loc, rewriter);
1574 auto input = op.getInput();
1575 auto inputTy = cast<RankedTensorType>(input.getType());
1576 auto resultTy = cast<RankedTensorType>(op.getType());
1577 const bool isBilinear = op.getMode() == "BILINEAR";
1578
1579 auto inputH = inputTy.getDimSize(1);
1580 auto inputW = inputTy.getDimSize(2);
1581 auto outputH = resultTy.getDimSize(1);
1582 auto outputW = resultTy.getDimSize(2);
1583
1584 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1585 return rewriter.notifyMatchFailure(
1586 op, "tosa.resize is not a pure 1x1->1x1 image operation");
1587
1588 // TODO(suderman): These string values should be declared the TOSA dialect.
1589 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1590 return rewriter.notifyMatchFailure(
1591 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1592
1593 if (inputTy == resultTy) {
1594 rewriter.replaceOp(op, input);
1595 return success();
1596 }
1597
1598 SmallVector<int64_t> scale;
1599 if (!tosa::getConstShapeValues(op: op.getScale().getDefiningOp(), result_shape&: scale)) {
1600 return failure();
1601 }
1602
1603 // Collapse the unit width and height away.
1604 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1605 reassociationMap[0].push_back(Elt: builder.getAffineDimExpr(position: 0));
1606 reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 1));
1607 reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 2));
1608 reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 3));
1609
1610 auto collapseTy =
1611 RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1612 inputTy.getElementType());
1613 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1614 reassociationMap);
1615
1616 // Get any dynamic shapes that appear in the input format.
1617 llvm::SmallVector<Value> outputDynSize;
1618 if (inputTy.isDynamicDim(0))
1619 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1620 if (inputTy.isDynamicDim(3))
1621 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1622
1623 // Generate the elementwise operation for casting scaling the input value.
1624 auto genericTy = collapseTy.clone(resultTy.getElementType());
1625 Value empty = builder.create<tensor::EmptyOp>(
1626 genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1627 auto genericMap = rewriter.getMultiDimIdentityMap(rank: genericTy.getRank());
1628 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1629 utils::IteratorType::parallel);
1630
1631 auto generic = builder.create<linalg::GenericOp>(
1632 genericTy, ValueRange{collapse}, ValueRange{empty},
1633 ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1634 [=](OpBuilder &b, Location loc, ValueRange args) {
1635 Value value = args[0];
1636 // This is the quantized case.
1637 if (inputTy.getElementType() != resultTy.getElementType()) {
1638 value =
1639 b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1640
1641 if (isBilinear && scale[0] != 0) {
1642 Value scaleY = b.create<arith::ConstantOp>(
1643 loc, b.getI32IntegerAttr(scale[0]));
1644 value = b.create<arith::MulIOp>(loc, value, scaleY);
1645 }
1646
1647 if (isBilinear && scale[2] != 0) {
1648 Value scaleX = b.create<arith::ConstantOp>(
1649 loc, b.getI32IntegerAttr(scale[2]));
1650 value = b.create<arith::MulIOp>(loc, value, scaleX);
1651 }
1652 }
1653
1654 b.create<linalg::YieldOp>(loc, value);
1655 });
1656
1657 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1658 op, resultTy, generic.getResults()[0], reassociationMap);
1659 return success();
1660 }
1661};
1662
1663// TOSA resize with width or height of 1 may be broadcasted to a wider
1664// dimension. This is done by materializing a new tosa.resize without
1665// the broadcasting behavior, and an explicit broadcast afterwards.
1666class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
1667public:
1668 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1669
1670 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1671 PatternRewriter &rewriter) const final {
1672 Location loc = op.getLoc();
1673 ImplicitLocOpBuilder builder(loc, rewriter);
1674 auto input = op.getInput();
1675 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1676 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1677
1678 if (!inputTy || !resultTy)
1679 return rewriter.notifyMatchFailure(op,
1680 "requires ranked input/output types");
1681
1682 auto batch = inputTy.getDimSize(0);
1683 auto channels = inputTy.getDimSize(3);
1684 auto inputH = inputTy.getDimSize(1);
1685 auto inputW = inputTy.getDimSize(2);
1686 auto outputH = resultTy.getDimSize(1);
1687 auto outputW = resultTy.getDimSize(2);
1688
1689 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1690 return rewriter.notifyMatchFailure(
1691 op, "tosa.resize has no broadcasting behavior");
1692
1693 // For any dimension that is broadcastable we generate a width of 1
1694 // on the output.
1695 llvm::SmallVector<int64_t> resizeShape;
1696 resizeShape.push_back(Elt: batch);
1697 resizeShape.push_back(Elt: inputH == 1 ? 1 : outputH);
1698 resizeShape.push_back(Elt: inputW == 1 ? 1 : outputW);
1699 resizeShape.push_back(Elt: channels);
1700
1701 auto resizeTy = resultTy.clone(resizeShape);
1702 auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
1703 op.getOffset(), op.getBorder(),
1704 op.getMode());
1705
1706 // Collapse an unit result dims.
1707 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1708 reassociationMap[0].push_back(Elt: builder.getAffineDimExpr(position: 0));
1709 reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 1));
1710 if (inputH != 1)
1711 reassociationMap.push_back(Elt: {});
1712 reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 2));
1713 if (inputW != 1)
1714 reassociationMap.push_back(Elt: {});
1715 reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 3));
1716
1717 llvm::SmallVector<int64_t> collapseShape = {batch};
1718 if (inputH != 1)
1719 collapseShape.push_back(Elt: outputH);
1720 if (inputW != 1)
1721 collapseShape.push_back(Elt: outputW);
1722 collapseShape.push_back(Elt: channels);
1723
1724 auto collapseTy = resultTy.clone(collapseShape);
1725 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1726 reassociationMap);
1727
1728 // Broadcast the collapsed shape to the output result.
1729 llvm::SmallVector<Value> outputDynSize;
1730 if (inputTy.isDynamicDim(0))
1731 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1732 if (inputTy.isDynamicDim(3))
1733 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1734
1735 SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
1736 utils::IteratorType::parallel);
1737 Value empty = builder.create<tensor::EmptyOp>(
1738 resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1739
1740 SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(position: 0)};
1741 if (inputH != 1)
1742 inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 1));
1743 if (inputW != 1)
1744 inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 2));
1745 inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 3));
1746
1747 auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
1748 inputExprs, rewriter.getContext());
1749
1750 auto outputMap = rewriter.getMultiDimIdentityMap(rank: resultTy.getRank());
1751 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1752 op, resultTy, ValueRange{collapse}, ValueRange{empty},
1753 ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1754 [=](OpBuilder &b, Location loc, ValueRange args) {
1755 Value value = args[0];
1756 b.create<linalg::YieldOp>(loc, value);
1757 });
1758
1759 return success();
1760 }
1761};
1762
1763class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1764public:
1765 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1766
1767 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1768 PatternRewriter &rewriter) const final {
1769 Location loc = op.getLoc();
1770 ImplicitLocOpBuilder b(loc, rewriter);
1771 auto input = op.getInput();
1772 auto inputTy = cast<ShapedType>(input.getType());
1773 auto resultTy = cast<ShapedType>(op.getType());
1774 auto resultETy = resultTy.getElementType();
1775
1776 bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1777 auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1778
1779 auto imageH = inputTy.getShape()[1];
1780 auto imageW = inputTy.getShape()[2];
1781
1782 auto dynamicDimsOr =
1783 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1784 if (!dynamicDimsOr.has_value())
1785 return rewriter.notifyMatchFailure(
1786 op, "unable to get dynamic dimensions of tosa.resize");
1787
1788 if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1789 return rewriter.notifyMatchFailure(
1790 op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1791
1792 SmallVector<AffineMap, 2> affineMaps = {
1793 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
1794 auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1795 *dynamicDimsOr);
1796 auto genericOp = b.create<linalg::GenericOp>(
1797 resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
1798 getNParallelLoopsAttrs(resultTy.getRank()));
1799 Value resize = genericOp.getResult(0);
1800
1801 {
1802 OpBuilder::InsertionGuard regionGuard(b);
1803 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1804 TypeRange({resultETy}), loc);
1805 Value batch = b.create<linalg::IndexOp>(0);
1806 Value y = b.create<linalg::IndexOp>(1);
1807 Value x = b.create<linalg::IndexOp>(2);
1808 Value channel = b.create<linalg::IndexOp>(3);
1809
1810 Value zeroI32 =
1811 b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1812 Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1813 Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1814 Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1815
1816 Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1817 Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1818
1819 SmallVector<int64_t> scale, offset, border;
1820 if (!tosa::getConstShapeValues(op: op.getScale().getDefiningOp(), result_shape&: scale) ||
1821 !tosa::getConstShapeValues(op: op.getOffset().getDefiningOp(), result_shape&: offset) ||
1822 !tosa::getConstShapeValues(op: op.getBorder().getDefiningOp(), result_shape&: border)) {
1823 return rewriter.notifyMatchFailure(
1824 op, "tosa.resize scale/offset/border should have compile time "
1825 "constant values.");
1826 }
1827
1828 Value yScaleN, yScaleD, xScaleN, xScaleD;
1829 yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1830 yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1831 xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1832 xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1833
1834 Value yOffset, xOffset, yBorder, xBorder;
1835 yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1836 xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1837 yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1838 xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1839
1840 // Compute the ix and dx values for both the X and Y dimensions.
1841 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
1842 Value scaleN, Value scaleD, Value offset,
1843 int size, ImplicitLocOpBuilder &b) {
1844 if (size == 1) {
1845 index = zeroI32;
1846 delta = zeroFp;
1847 return;
1848 }
1849 // x = x * scale_d + offset;
1850 // ix = floor(x / scale_n)
1851 Value val = b.create<arith::MulIOp>(in, scaleD);
1852 val = b.create<arith::AddIOp>(val, offset);
1853 index = b.create<arith::FloorDivSIOp>(val, scaleN);
1854
1855 // rx = x % scale_n
1856 // dx = rx / scale_n
1857 Value r = b.create<arith::RemSIOp>(val, scaleN);
1858 Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1859 Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1860 delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1861 };
1862
1863 // Compute the ix and dx values for the X and Y dimensions - int case.
1864 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
1865 Value scaleN, Value scaleD, Value offset,
1866 int size, ImplicitLocOpBuilder &b) {
1867 if (size == 1) {
1868 index = zeroI32;
1869 delta = zeroI32;
1870 return;
1871 }
1872 // x = x * scale_d + offset;
1873 // ix = floor(x / scale_n)
1874 // dx = x - ix * scale_n;
1875 Value val = b.create<arith::MulIOp>(in, scaleD);
1876 val = b.create<arith::AddIOp>(val, offset);
1877 index = b.create<arith::DivSIOp>(val, scaleN);
1878 delta = b.create<arith::MulIOp>(index, scaleN);
1879 delta = b.create<arith::SubIOp>(val, delta);
1880 };
1881
1882 Value ix, iy, dx, dy;
1883 if (floatingPointMode) {
1884 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1885 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1886 } else {
1887 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1888 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1889 }
1890
1891 if (op.getMode() == "NEAREST_NEIGHBOR") {
1892 auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1893
1894 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
1895 Value max, int size,
1896 ImplicitLocOpBuilder &b) -> Value {
1897 if (size == 1) {
1898 return b.create<arith::ConstantIndexOp>(0);
1899 }
1900
1901 Value pred;
1902 if (floatingPointMode) {
1903 auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1904 pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1905 } else {
1906 Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1907 pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1908 dvalDouble, scale);
1909 }
1910
1911 auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1912 val = b.create<arith::AddIOp>(val, offset);
1913 val = clampIntHelper(loc, arg: val, min: zeroI32, max, rewriter&: b, /*isUnsigned=*/false);
1914 return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1915 };
1916
1917 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1918 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1919
1920 Value result = b.create<tensor::ExtractOp>(
1921 input, ValueRange{batch, iy, ix, channel});
1922
1923 b.create<linalg::YieldOp>(result);
1924 } else {
1925 // The mode here must be BILINEAR.
1926 assert(op.getMode() == "BILINEAR");
1927
1928 auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1929
1930 auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
1931 Value max, ImplicitLocOpBuilder &b) {
1932 val0 = in;
1933 val1 = b.create<arith::AddIOp>(val0, oneVal);
1934 val0 =
1935 clampIntHelper(loc, arg: val0, min: zeroI32, max, rewriter&: b, /*isUnsigned=*/false);
1936 val1 =
1937 clampIntHelper(loc, arg: val1, min: zeroI32, max, rewriter&: b, /*isUnsigned=*/false);
1938 val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1939 val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1940 };
1941
1942 // Linalg equivalent to the section below:
1943 // int16_t iy0 = apply_max(iy, 0);
1944 // int16_t iy1 = apply_min(iy + 1, IH - 1);
1945 // int16_t ix0 = apply_max(ix, 0);
1946 // int16_t ix1 = apply_min(ix + 1, IW - 1);
1947 Value x0, x1, y0, y1;
1948 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1949 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1950
1951 Value y0x0 = b.create<tensor::ExtractOp>(
1952 input, ValueRange{batch, y0, x0, channel});
1953 Value y0x1 = b.create<tensor::ExtractOp>(
1954 input, ValueRange{batch, y0, x1, channel});
1955 Value y1x0 = b.create<tensor::ExtractOp>(
1956 input, ValueRange{batch, y1, x0, channel});
1957 Value y1x1 = b.create<tensor::ExtractOp>(
1958 input, ValueRange{batch, y1, x1, channel});
1959
1960 if (floatingPointMode) {
1961 auto oneVal =
1962 b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1963 auto interpolate = [&](Value val0, Value val1, Value delta,
1964 int inputSize,
1965 ImplicitLocOpBuilder &b) -> Value {
1966 if (inputSize == 1)
1967 return val0;
1968 Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1969 Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1970 Value mul1 = b.create<arith::MulFOp>(val1, delta);
1971 return b.create<arith::AddFOp>(mul0, mul1);
1972 };
1973
1974 // Linalg equivalent to the section below:
1975 // topAcc = v00 * (unit_x - dx);
1976 // topAcc += v01 * dx;
1977 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1978
1979 // Linalg equivalent to the section below:
1980 // bottomAcc = v10 * (unit_x - dx);
1981 // bottomAcc += v11 * dx;
1982 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1983
1984 // Linalg equivalent to the section below:
1985 // result = topAcc * (unit_y - dy) + bottomAcc * dy
1986 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1987 b.create<linalg::YieldOp>(result);
1988 } else {
1989 // Perform in quantized space.
1990 y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1991 y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1992 y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1993 y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1994
1995 const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
1996 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1997 dx = b.create<arith::ExtSIOp>(resultETy, dx);
1998 dy = b.create<arith::ExtSIOp>(resultETy, dy);
1999 }
2000
2001 Value yScaleNExt = yScaleN;
2002 Value xScaleNExt = xScaleN;
2003
2004 const int64_t scaleBitwidth =
2005 xScaleN.getType().getIntOrFloatBitWidth();
2006 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2007 yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
2008 xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
2009 }
2010
2011 auto interpolate = [](Value val0, Value val1, Value weight1,
2012 Value scale, int inputSize,
2013 ImplicitLocOpBuilder &b) -> Value {
2014 if (inputSize == 1)
2015 return b.create<arith::MulIOp>(val0, scale);
2016 Value weight0 = b.create<arith::SubIOp>(scale, weight1);
2017 Value mul0 = b.create<arith::MulIOp>(val0, weight0);
2018 Value mul1 = b.create<arith::MulIOp>(val1, weight1);
2019 return b.create<arith::AddIOp>(mul0, mul1);
2020 };
2021
2022 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2023 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2024 Value result =
2025 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2026 b.create<linalg::YieldOp>(result);
2027 }
2028 }
2029 }
2030
2031 rewriter.replaceOp(op, resize);
2032 return success();
2033 }
2034};
2035
2036// At the codegen level any identity operations should be removed. Any cases
2037// where identity is load-bearing (e.g. cross device computation) should be
2038// handled before lowering to codegen.
2039template <typename SrcOp>
2040class IdentityNConverter : public OpRewritePattern<SrcOp> {
2041public:
2042 using OpRewritePattern<SrcOp>::OpRewritePattern;
2043
2044 LogicalResult matchAndRewrite(SrcOp op,
2045 PatternRewriter &rewriter) const final {
2046 rewriter.replaceOp(op, op.getOperation()->getOperands());
2047 return success();
2048 }
2049};
2050
2051template <typename SrcOp>
2052class ReduceConverter : public OpRewritePattern<SrcOp> {
2053public:
2054 using OpRewritePattern<SrcOp>::OpRewritePattern;
2055
2056 LogicalResult matchAndRewrite(SrcOp reduceOp,
2057 PatternRewriter &rewriter) const final {
2058 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
2059 }
2060};
2061
2062class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
2063public:
2064 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2065
2066 LogicalResult matchAndRewrite(tosa::ReverseOp op,
2067 PatternRewriter &rewriter) const final {
2068 auto loc = op.getLoc();
2069 Value input = op.getInput1();
2070 auto inputTy = cast<ShapedType>(input.getType());
2071 auto resultTy = cast<ShapedType>(op.getType());
2072 auto axis = op.getAxis();
2073
2074 SmallVector<Value> dynDims;
2075 for (int i = 0; i < inputTy.getRank(); i++) {
2076 if (inputTy.isDynamicDim(i)) {
2077 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2078 }
2079 }
2080
2081 Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
2082
2083 // First fill the output buffer with the init value.
2084 auto emptyTensor = rewriter
2085 .create<tensor::EmptyOp>(loc, inputTy.getShape(),
2086 inputTy.getElementType(),
2087 ArrayRef<Value>({dynDims}))
2088 .getResult();
2089 SmallVector<AffineMap, 2> affineMaps = {
2090 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
2091
2092 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2093 op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
2094 getNParallelLoopsAttrs(resultTy.getRank()),
2095 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2096 llvm::SmallVector<Value> indices;
2097 for (unsigned int i = 0; i < inputTy.getRank(); i++) {
2098 Value index =
2099 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
2100 if (i == axis) {
2101 auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
2102 auto sizeMinusOne =
2103 rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
2104 index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
2105 index);
2106 }
2107
2108 indices.push_back(index);
2109 }
2110
2111 auto extract = nestedBuilder.create<tensor::ExtractOp>(
2112 nestedLoc, input, indices);
2113 nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
2114 extract.getResult());
2115 });
2116 return success();
2117 }
2118};
2119
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcasted to the appropriate
// multiple.
2124struct TileConverter : public OpConversionPattern<tosa::TileOp> {
2125 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2126
2127 LogicalResult
2128 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2129 ConversionPatternRewriter &rewriter) const override {
2130 auto loc = op.getLoc();
2131 auto input = op.getInput1();
2132 auto inputTy = cast<ShapedType>(input.getType());
2133 auto inputShape = inputTy.getShape();
2134 auto resultTy = cast<ShapedType>(op.getType());
2135 auto elementTy = inputTy.getElementType();
2136 int64_t rank = inputTy.getRank();
2137
2138 SmallVector<int64_t> multiples;
2139 if (failed(op.getConstantMultiples(multiples)))
2140 return failure();
2141
2142 // Broadcast the newly added dimensions to their appropriate multiple.
2143 SmallVector<int64_t, 2> genericShape;
2144 for (int i = 0; i < rank; i++) {
2145 int64_t dim = multiples[i];
2146 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2147 genericShape.push_back(Elt: inputShape[i]);
2148 }
2149
2150 SmallVector<Value> dynDims;
2151 for (int i = 0; i < inputTy.getRank(); i++) {
2152 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2153 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2154 }
2155 }
2156
2157 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
2158 op.getLoc(), genericShape, elementTy, dynDims);
2159
2160 // We needs to map the input shape to the non-broadcasted dimensions.
2161 SmallVector<AffineExpr, 4> dimExprs;
2162 dimExprs.reserve(N: rank);
2163 for (unsigned i = 0; i < rank; ++i)
2164 dimExprs.push_back(Elt: rewriter.getAffineDimExpr(position: i * 2 + 1));
2165
2166 auto readAffineMap =
2167 AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, results: dimExprs,
2168 context: rewriter.getContext());
2169
2170 SmallVector<AffineMap, 2> affineMaps = {
2171 readAffineMap, rewriter.getMultiDimIdentityMap(rank: genericShape.size())};
2172
2173 auto genericOp = rewriter.create<linalg::GenericOp>(
2174 loc, RankedTensorType::get(genericShape, elementTy), input,
2175 ValueRange{emptyTensor}, affineMaps,
2176 getNParallelLoopsAttrs(genericShape.size()),
2177 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
2178 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
2179 });
2180
2181 auto shapeValue = getTosaConstShape(
2182 rewriter, loc, mlir::tosa::convertFromMlirShape(shape: resultTy.getShape()));
2183 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2184 op, resultTy, genericOp.getResult(0), shapeValue);
2185 return success();
2186 }
2187};
2188
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
2202class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
2203public:
2204 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2205
2206 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2207 PatternRewriter &rewriter) const final {
2208 auto loc = argmaxOp.getLoc();
2209 Value input = argmaxOp.getInput();
2210 auto inputTy = cast<ShapedType>(input.getType());
2211 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2212 auto inElementTy = inputTy.getElementType();
2213 auto outElementTy = resultTy.getElementType();
2214 int axis = argmaxOp.getAxis();
2215 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2216
2217 if (!isa<IntegerType>(outElementTy))
2218 return rewriter.notifyMatchFailure(
2219 argmaxOp,
2220 "tosa.arg_max to linalg.* requires integer-like result type");
2221
2222 SmallVector<Value> dynDims;
2223 for (int i = 0; i < inputTy.getRank(); i++) {
2224 if (inputTy.isDynamicDim(i) && i != axis) {
2225 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2226 }
2227 }
2228
2229 // First fill the output buffer for the index.
2230 auto emptyTensorIdx = rewriter
2231 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2232 outElementTy, dynDims)
2233 .getResult();
2234 auto fillValueIdx = rewriter.create<arith::ConstantOp>(
2235 loc, rewriter.getIntegerAttr(outElementTy, 0));
2236 auto filledTensorIdx =
2237 rewriter
2238 .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
2239 ValueRange{emptyTensorIdx})
2240 .result();
2241
2242 // Second fill the output buffer for the running max.
2243 auto emptyTensorMax = rewriter
2244 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2245 inElementTy, dynDims)
2246 .getResult();
2247 auto fillValueMaxAttr =
2248 createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);
2249
2250 if (!fillValueMaxAttr)
2251 return rewriter.notifyMatchFailure(
2252 argmaxOp, "unsupported tosa.argmax element type");
2253
2254 auto fillValueMax =
2255 rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
2256 auto filledTensorMax =
2257 rewriter
2258 .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
2259 ValueRange{emptyTensorMax})
2260 .result();
2261
2262 // We need to reduce along the arg-max axis, with parallel operations along
2263 // the rest.
2264 SmallVector<utils::IteratorType, 4> iteratorTypes;
2265 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2266 iteratorTypes[axis] = utils::IteratorType::reduction;
2267
2268 SmallVector<AffineExpr, 2> srcExprs;
2269 SmallVector<AffineExpr, 2> dstExprs;
2270 for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2271 srcExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context: rewriter.getContext()));
2272 if (axis != i)
2273 dstExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context: rewriter.getContext()));
2274 }
2275
2276 bool didEncounterError = false;
2277 auto maps = AffineMap::inferFromExprList(exprsList: {srcExprs, dstExprs, dstExprs},
2278 context: rewriter.getContext());
2279 auto linalgOp = rewriter.create<linalg::GenericOp>(
2280 loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2281 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2282 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2283 ValueRange blockArgs) {
2284 auto newValue = blockArgs[0];
2285 auto oldIndex = blockArgs[1];
2286 auto oldValue = blockArgs[2];
2287
2288 Value newIndex = rewriter.create<arith::IndexCastOp>(
2289 nestedLoc, oldIndex.getType(),
2290 rewriter.create<linalg::IndexOp>(loc, axis));
2291
2292 Value predicate;
2293 if (isa<FloatType>(inElementTy)) {
2294 if (argmaxOp.getNanMode() == "IGNORE") {
2295 // Only update index & max value for non NaN values. If all
2296 // values are NaNs, the initial index will be return which is 0.
2297 predicate = rewriter.create<arith::CmpFOp>(
2298 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2299 } else {
2300 // Update max value if either of the following is true:
2301 // - new value is bigger
2302 // - cur max is not NaN and new value is NaN
2303 Value gt = rewriter.create<arith::CmpFOp>(
2304 nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
2305 Value oldNonNaN = rewriter.create<arith::CmpFOp>(
2306 nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
2307 predicate = rewriter.create<arith::AndIOp>(
2308 nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2309 }
2310 } else if (isa<IntegerType>(inElementTy)) {
2311 predicate = rewriter.create<arith::CmpIOp>(
2312 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2313 } else {
2314 didEncounterError = true;
2315 return;
2316 }
2317
2318 auto resultMax = rewriter.create<arith::SelectOp>(
2319 nestedLoc, predicate, newValue, oldValue);
2320 auto resultIndex = rewriter.create<arith::SelectOp>(
2321 nestedLoc, predicate, newIndex, oldIndex);
2322 nestedBuilder.create<linalg::YieldOp>(
2323 nestedLoc, ValueRange({resultIndex, resultMax}));
2324 });
2325
2326 if (didEncounterError)
2327 return rewriter.notifyMatchFailure(
2328 argmaxOp, "unsupported tosa.argmax element type");
2329
2330 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2331 return success();
2332 }
2333};
2334
2335class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2336public:
2337 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2338 LogicalResult
2339 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2340 ConversionPatternRewriter &rewriter) const final {
2341 auto input = adaptor.getOperands()[0];
2342 auto indices = adaptor.getOperands()[1];
2343
2344 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2345 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2346 if (!valuesTy || !resultTy)
2347 return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
2348
2349 auto dynamicDims = inferDynamicDimsForGather(
2350 builder&: rewriter, loc: op.getLoc(), values: adaptor.getValues(), indices: adaptor.getIndices());
2351
2352 auto resultElementTy = resultTy.getElementType();
2353
2354 auto loc = op.getLoc();
2355 auto emptyTensor =
2356 rewriter
2357 .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2358 dynamicDims)
2359 .getResult();
2360
2361 SmallVector<AffineMap, 2> affineMaps = {
2362 AffineMap::get(
2363 /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2364 {rewriter.getAffineDimExpr(position: 0), rewriter.getAffineDimExpr(position: 1)},
2365 rewriter.getContext()),
2366 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
2367
2368 auto genericOp = rewriter.create<linalg::GenericOp>(
2369 loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2370 ValueRange{emptyTensor}, affineMaps,
2371 getNParallelLoopsAttrs(resultTy.getRank()),
2372 [&](OpBuilder &b, Location loc, ValueRange args) {
2373 auto indexValue = args[0];
2374 auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2375 Value index1 = rewriter.create<arith::IndexCastOp>(
2376 loc, rewriter.getIndexType(), indexValue);
2377 auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2378 Value extract = rewriter.create<tensor::ExtractOp>(
2379 loc, input, ValueRange{index0, index1, index2});
2380 rewriter.create<linalg::YieldOp>(loc, extract);
2381 });
2382 rewriter.replaceOp(op, genericOp.getResult(0));
2383 return success();
2384 }
2385
2386 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2387 Location loc,
2388 Value values,
2389 Value indices) {
2390 llvm::SmallVector<Value> results;
2391
2392 auto addDynamicDimension = [&](Value source, int64_t dim) {
2393 auto sz = tensor::getMixedSize(builder, loc, value: source, dim);
2394 if (auto dimValue = llvm::dyn_cast_if_present<Value>(Val&: sz))
2395 results.push_back(Elt: dimValue);
2396 };
2397
2398 addDynamicDimension(values, 0);
2399 addDynamicDimension(indices, 1);
2400 addDynamicDimension(values, 2);
2401 return results;
2402 }
2403};
2404
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
2408class TableConverter : public OpRewritePattern<tosa::TableOp> {
2409public:
2410 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2411
2412 LogicalResult matchAndRewrite(tosa::TableOp op,
2413 PatternRewriter &rewriter) const final {
2414 auto loc = op.getLoc();
2415 Value input = op.getInput1();
2416 Value table = op.getTable();
2417 auto inputTy = cast<ShapedType>(input.getType());
2418 auto tableTy = cast<ShapedType>(table.getType());
2419 auto resultTy = cast<ShapedType>(op.getType());
2420
2421 auto inputElementTy = inputTy.getElementType();
2422 auto tableElementTy = tableTy.getElementType();
2423 auto resultElementTy = resultTy.getElementType();
2424
2425 SmallVector<Value> dynDims;
2426 for (int i = 0; i < resultTy.getRank(); ++i) {
2427 if (inputTy.isDynamicDim(i)) {
2428 dynDims.push_back(
2429 rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2430 }
2431 }
2432
2433 auto emptyTensor = rewriter
2434 .create<tensor::EmptyOp>(loc, resultTy.getShape(),
2435 resultElementTy, dynDims)
2436 .getResult();
2437
2438 SmallVector<AffineMap, 2> affineMaps = {
2439 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank()),
2440 rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())};
2441
2442 auto genericOp = rewriter.create<linalg::GenericOp>(
2443 loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
2444 getNParallelLoopsAttrs(resultTy.getRank()));
2445 rewriter.replaceOp(op, genericOp.getResult(0));
2446
2447 {
2448 OpBuilder::InsertionGuard regionGuard(rewriter);
2449 Block *block = rewriter.createBlock(
2450 &genericOp.getRegion(), genericOp.getRegion().end(),
2451 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2452
2453 auto inputValue = block->getArgument(i: 0);
2454 rewriter.setInsertionPointToStart(block);
2455 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2456 resultElementTy.isInteger(8)) {
2457 Value index = rewriter.create<arith::IndexCastOp>(
2458 loc, rewriter.getIndexType(), inputValue);
2459 Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2460 index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2461 index, offset);
2462 Value extract =
2463 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2464 rewriter.create<linalg::YieldOp>(loc, extract);
2465 return success();
2466 }
2467
2468 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2469 resultElementTy.isInteger(32)) {
2470 Value extend = rewriter.create<arith::ExtSIOp>(
2471 loc, rewriter.getI32Type(), inputValue);
2472
2473 auto offset = rewriter.create<arith::ConstantOp>(
2474 loc, rewriter.getI32IntegerAttr(32768));
2475 auto seven = rewriter.create<arith::ConstantOp>(
2476 loc, rewriter.getI32IntegerAttr(7));
2477 auto one = rewriter.create<arith::ConstantOp>(
2478 loc, rewriter.getI32IntegerAttr(1));
2479 auto b1111111 = rewriter.create<arith::ConstantOp>(
2480 loc, rewriter.getI32IntegerAttr(127));
2481
2482 // Compute the index and fractional part from the input value:
2483 // value = value + 32768
2484 // index = value >> 7;
2485 // fraction = 0x01111111 & value
2486 auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2487 Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2488 Value fraction =
2489 rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2490
2491 // Extract the base and next values from the table.
2492 // base = (int32_t) table[index];
2493 // next = (int32_t) table[index + 1];
2494 Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2495
2496 index = rewriter.create<arith::IndexCastOp>(
2497 loc, rewriter.getIndexType(), index);
2498 indexPlusOne = rewriter.create<arith::IndexCastOp>(
2499 loc, rewriter.getIndexType(), indexPlusOne);
2500
2501 Value base =
2502 rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2503 Value next = rewriter.create<tensor::ExtractOp>(
2504 loc, table, ValueRange{indexPlusOne});
2505
2506 base =
2507 rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2508 next =
2509 rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2510
2511 // Use the fractional part to interpolate between the input values:
2512 // result = (base << 7) + (next - base) * fraction
2513 Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2514 Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2515 Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2516 Value result =
2517 rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2518
2519 rewriter.create<linalg::YieldOp>(loc, result);
2520
2521 return success();
2522 }
2523 }
2524
2525 return rewriter.notifyMatchFailure(
2526 op, "unable to create body for tosa.table op");
2527 }
2528};
2529
2530struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
2531 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2532
2533 static bool isRankedTensor(Type type) { return isa<RankedTensorType>(Val: type); }
2534
2535 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2536 OpFoldResult ofr) {
2537 auto one = builder.create<arith::ConstantIndexOp>(location: loc, args: 1);
2538 auto two = builder.create<arith::ConstantIndexOp>(location: loc, args: 2);
2539
2540 auto value = getValueOrCreateConstantIndexOp(b&: builder, loc, ofr);
2541 auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
2542 auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
2543 return getAsOpFoldResult(plusOne);
2544 }
2545
2546 static RankedTensorType
2547 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2548 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2549 // Get [N, H, W]
2550 auto dims = tensor::getMixedSizes(builder, loc, value: input);
2551
2552 // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
2553 // output tensors.
2554 dims[2] = halfPlusOne(builder, loc, ofr: dims[2]);
2555
2556 llvm::SmallVector<int64_t, 3> staticSizes;
2557 dispatchIndexOpFoldResults(ofrs: dims, dynamicVec&: dynamicSizes, staticVec&: staticSizes);
2558
2559 auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
2560 return RankedTensorType::get(staticSizes, elementType);
2561 }
2562
2563 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2564 RankedTensorType type,
2565 llvm::ArrayRef<Value> dynamicSizes) {
2566 auto emptyTensor =
2567 rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
2568 auto fillValueAttr = rewriter.getZeroAttr(type: type.getElementType());
2569 auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
2570 auto filledTensor = rewriter
2571 .create<linalg::FillOp>(loc, ValueRange{fillValue},
2572 ValueRange{emptyTensor})
2573 .result();
2574 return filledTensor;
2575 }
2576
2577 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2578 FloatType type, Value value) {
2579 auto integerVal = builder.create<arith::IndexCastUIOp>(
2580 loc,
2581 type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
2582 : builder.getI32Type(),
2583 value);
2584
2585 return builder.create<arith::UIToFPOp>(loc, type, integerVal);
2586 }
2587
2588 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2589 FloatType type, int64_t index) {
2590 auto indexVal = builder.create<linalg::IndexOp>(loc, index);
2591 return castIndexToFloat(builder, loc, type: type, value: indexVal);
2592 }
2593
2594 template <typename... Args>
2595 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2596 Args... args) {
2597 return {builder.getAffineDimExpr(position: args)...};
2598 }
2599
2600 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2601 PatternRewriter &rewriter) const override {
2602 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2603 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2604 return rewriter.notifyMatchFailure(rfft2d,
2605 "only supports ranked tensors");
2606 }
2607
2608 auto loc = rfft2d.getLoc();
2609 auto input = rfft2d.getInputReal();
2610 auto elementType =
2611 dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
2612 if (!elementType)
2613 return rewriter.notifyMatchFailure(rfft2d,
2614 "only supports float element types");
2615
2616 // Compute the output type and set of dynamic sizes
2617 llvm::SmallVector<Value> dynamicSizes;
2618 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2619
2620 // Iterator types for the linalg.generic implementation
2621 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2622 utils::IteratorType::parallel, utils::IteratorType::parallel,
2623 utils::IteratorType::parallel, utils::IteratorType::reduction,
2624 utils::IteratorType::reduction};
2625
2626 // Inputs/outputs to the linalg.generic implementation
2627 llvm::SmallVector<Value> genericOpInputs = {input};
2628 llvm::SmallVector<Value> genericOpOutputs = {
2629 createZeroTensor(rewriter, loc: loc, type: outputType, dynamicSizes),
2630 createZeroTensor(rewriter, loc: loc, type: outputType, dynamicSizes)};
2631
2632 // Indexing maps for input and output tensors
2633 auto indexingMaps = AffineMap::inferFromExprList(
2634 exprsList: llvm::ArrayRef{affineDimsExpr(builder&: rewriter, args: 0, args: 3, args: 4),
2635 affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2),
2636 affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2)},
2637 context: rewriter.getContext());
2638
2639 // Width and height dimensions of the original input.
2640 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
2641 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);
2642
2643 // Constants and dimension sizes
2644 auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
2645 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2646 auto constH = castIndexToFloat(builder&: rewriter, loc: loc, type: elementType, value: dimH);
2647 auto constW = castIndexToFloat(builder&: rewriter, loc: loc, type: elementType, value: dimW);
2648
2649 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2650 Value valReal = args[0];
2651 Value sumReal = args[1];
2652 Value sumImag = args[2];
2653
2654 // Indices for angle computation
2655 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2656 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2657 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2658 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2659
2660 // Calculating angle without integer parts of components as sin/cos are
2661 // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
2662 // / W);
2663 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2664 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2665
2666 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2667 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2668
2669 auto iyRemFloat = castIndexToFloat(builder, loc, type: elementType, value: iyRem);
2670 auto ixRemFloat = castIndexToFloat(builder, loc, type: elementType, value: ixRem);
2671
2672 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2673 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2674 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2675 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2676
2677 // realComponent = valReal * cos(angle)
2678 // imagComponent = valReal * sin(angle)
2679 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2680 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2681 auto realComponent =
2682 builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2683 auto imagComponent =
2684 builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2685
2686 // outReal = sumReal + realComponent
2687 // outImag = sumImag - imagComponent
2688 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2689 auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
2690
2691 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2692 };
2693
2694 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2695 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2696 indexingMaps, iteratorTypes, buildBody);
2697
2698 return success();
2699 }
2700};
2701
2702struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
2703 using OpRewritePattern::OpRewritePattern;
2704
2705 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2706 PatternRewriter &rewriter) const override {
2707 if (!llvm::all_of(fft2d->getOperandTypes(),
2708 RFFT2dConverter::isRankedTensor) ||
2709 !llvm::all_of(fft2d->getResultTypes(),
2710 RFFT2dConverter::isRankedTensor)) {
2711 return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
2712 }
2713
2714 Location loc = fft2d.getLoc();
2715 Value input_real = fft2d.getInputReal();
2716 Value input_imag = fft2d.getInputImag();
2717 BoolAttr inverse = fft2d.getInverseAttr();
2718
2719 auto real_el_ty = cast<FloatType>(
2720 cast<ShapedType>(input_real.getType()).getElementType());
2721 [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
2722 cast<ShapedType>(input_imag.getType()).getElementType());
2723
2724 assert(real_el_ty == imag_el_ty);
2725
2726 // Compute the output type and set of dynamic sizes
2727 SmallVector<Value> dynamicSizes;
2728
2729 // Get [N, H, W]
2730 auto dims = tensor::getMixedSizes(builder&: rewriter, loc, value: input_real);
2731
2732 SmallVector<int64_t, 3> staticSizes;
2733 dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
2734
2735 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2736
2737 // Iterator types for the linalg.generic implementation
2738 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2739 utils::IteratorType::parallel, utils::IteratorType::parallel,
2740 utils::IteratorType::parallel, utils::IteratorType::reduction,
2741 utils::IteratorType::reduction};
2742
2743 // Inputs/outputs to the linalg.generic implementation
2744 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2745 SmallVector<Value> genericOpOutputs = {
2746 RFFT2dConverter::createZeroTensor(rewriter, loc, type: outputType,
2747 dynamicSizes),
2748 RFFT2dConverter::createZeroTensor(rewriter, loc, type: outputType,
2749 dynamicSizes)};
2750
2751 // Indexing maps for input and output tensors
2752 auto indexingMaps = AffineMap::inferFromExprList(
2753 exprsList: ArrayRef{RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 3, args: 4),
2754 RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 3, args: 4),
2755 RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2),
2756 RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2)},
2757 context: rewriter.getContext());
2758
2759 // Width and height dimensions of the original input.
2760 auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
2761 auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
2762
2763 // Constants and dimension sizes
2764 auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
2765 auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
2766 Value constH =
2767 RFFT2dConverter::castIndexToFloat(builder&: rewriter, loc, type: real_el_ty, value: dimH);
2768 Value constW =
2769 RFFT2dConverter::castIndexToFloat(builder&: rewriter, loc, type: real_el_ty, value: dimW);
2770
2771 auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
2772 Value valReal = args[0];
2773 Value valImag = args[1];
2774 Value sumReal = args[2];
2775 Value sumImag = args[3];
2776
2777 // Indices for angle computation
2778 Value oy = builder.create<linalg::IndexOp>(loc, 1);
2779 Value ox = builder.create<linalg::IndexOp>(loc, 2);
2780 Value iy = builder.create<linalg::IndexOp>(loc, 3);
2781 Value ix = builder.create<linalg::IndexOp>(loc, 4);
2782
2783 // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
2784 // ox) % W ) / W);
2785 auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
2786 auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
2787
2788 auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
2789 auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
2790
2791 auto iyRemFloat =
2792 RFFT2dConverter::castIndexToFloat(builder, loc, type: real_el_ty, value: iyRem);
2793 auto ixRemFloat =
2794 RFFT2dConverter::castIndexToFloat(builder, loc, type: real_el_ty, value: ixRem);
2795
2796 auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
2797 auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
2798
2799 auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
2800 auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
2801
2802 if (inverse.getValue()) {
2803 angle = builder.create<arith::MulFOp>(
2804 loc, angle,
2805 rewriter.create<arith::ConstantOp>(
2806 loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
2807 }
2808
2809 // realComponent = val_real * cos(a) + val_imag * sin(a);
2810 // imagComponent = -val_real * sin(a) + val_imag * cos(a);
2811 auto cosAngle = builder.create<math::CosOp>(loc, angle);
2812 auto sinAngle = builder.create<math::SinOp>(loc, angle);
2813
2814 auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
2815 auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
2816 auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
2817
2818 auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
2819 auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
2820
2821 auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
2822
2823 // outReal = sumReal + realComponent
2824 // outImag = sumImag - imagComponent
2825 auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
2826 auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
2827
2828 builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
2829 };
2830
2831 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2832 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2833 indexingMaps, iteratorTypes, buildBody);
2834
2835 return success();
2836 }
2837};
2838
2839} // namespace
2840
2841void mlir::tosa::populateTosaToLinalgConversionPatterns(
2842 const TypeConverter &converter, RewritePatternSet *patterns) {
2843
2844 // We have multiple resize coverters to handle degenerate cases.
2845 patterns->add<GenericResizeConverter>(arg: patterns->getContext(),
2846 /*benefit=*/args: 100);
2847 patterns->add<ResizeUnaryConverter>(arg: patterns->getContext(),
2848 /*benefit=*/args: 200);
2849 patterns->add<MaterializeResizeBroadcast>(arg: patterns->getContext(),
2850 /*benefit=*/args: 300);
2851
2852 patterns->add<
2853 // clang-format off
2854 PointwiseConverter<tosa::AddOp>,
2855 PointwiseConverter<tosa::SubOp>,
2856 PointwiseConverter<tosa::MulOp>,
2857 PointwiseConverter<tosa::IntDivOp>,
2858 PointwiseConverter<tosa::NegateOp>,
2859 PointwiseConverter<tosa::PowOp>,
2860 PointwiseConverter<tosa::ReciprocalOp>,
2861 PointwiseConverter<tosa::RsqrtOp>,
2862 PointwiseConverter<tosa::LogOp>,
2863 PointwiseConverter<tosa::ExpOp>,
2864 PointwiseConverter<tosa::AbsOp>,
2865 PointwiseConverter<tosa::SinOp>,
2866 PointwiseConverter<tosa::CosOp>,
2867 PointwiseConverter<tosa::TanhOp>,
2868 PointwiseConverter<tosa::ErfOp>,
2869 PointwiseConverter<tosa::BitwiseAndOp>,
2870 PointwiseConverter<tosa::BitwiseOrOp>,
2871 PointwiseConverter<tosa::BitwiseNotOp>,
2872 PointwiseConverter<tosa::BitwiseXorOp>,
2873 PointwiseConverter<tosa::LogicalAndOp>,
2874 PointwiseConverter<tosa::LogicalNotOp>,
2875 PointwiseConverter<tosa::LogicalOrOp>,
2876 PointwiseConverter<tosa::LogicalXorOp>,
2877 PointwiseConverter<tosa::CastOp>,
2878 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2879 PointwiseConverter<tosa::LogicalRightShiftOp>,
2880 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2881 PointwiseConverter<tosa::ClzOp>,
2882 PointwiseConverter<tosa::SelectOp>,
2883 PointwiseConverter<tosa::GreaterOp>,
2884 PointwiseConverter<tosa::GreaterEqualOp>,
2885 PointwiseConverter<tosa::EqualOp>,
2886 PointwiseConverter<tosa::MaximumOp>,
2887 PointwiseConverter<tosa::MinimumOp>,
2888 PointwiseConverter<tosa::CeilOp>,
2889 PointwiseConverter<tosa::FloorOp>,
2890 PointwiseConverter<tosa::ClampOp>,
2891 PointwiseConverter<tosa::SigmoidOp>
2892 >(converter, patterns->getContext());
2893
2894 patterns->add<
2895 IdentityNConverter<tosa::IdentityOp>,
2896 ReduceConverter<tosa::ReduceAllOp>,
2897 ReduceConverter<tosa::ReduceAnyOp>,
2898 ReduceConverter<tosa::ReduceMinOp>,
2899 ReduceConverter<tosa::ReduceMaxOp>,
2900 ReduceConverter<tosa::ReduceSumOp>,
2901 ReduceConverter<tosa::ReduceProductOp>,
2902 ArgMaxConverter,
2903 GatherConverter,
2904 RescaleConverter,
2905 ReverseConverter,
2906 RFFT2dConverter,
2907 FFT2dConverter,
2908 TableConverter,
2909 TileConverter>(patterns->getContext());
2910 // clang-format on
2911}
2912

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp