1 | //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // These rewriters lower from the Tosa to the Linalg dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
16 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
17 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
18 | #include "mlir/Dialect/Math/IR/Math.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
21 | #include "mlir/Dialect/Tensor/Utils/Utils.h" |
22 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
23 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
24 | #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
25 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
26 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
27 | #include "mlir/IR/Matchers.h" |
28 | #include "mlir/IR/OpDefinition.h" |
29 | #include "mlir/IR/PatternMatch.h" |
30 | #include "mlir/Transforms/DialectConversion.h" |
31 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
32 | #include "llvm/ADT/STLExtras.h" |
33 | #include "llvm/ADT/Sequence.h" |
34 | |
35 | #include <numeric> |
36 | |
37 | using namespace mlir; |
38 | using namespace mlir::tosa; |
39 | |
40 | template <typename T> |
41 | static arith::ConstantOp |
42 | createConstFromIntAttribute(Operation *op, const std::string &attrName, |
43 | Type requiredAttrType, OpBuilder &rewriter) { |
44 | auto castedN = static_cast<T>( |
45 | cast<IntegerAttr>(op->getAttr(name: attrName)).getValue().getSExtValue()); |
46 | return rewriter.create<arith::ConstantOp>( |
47 | op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); |
48 | } |
49 | |
50 | static Value |
51 | createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, |
52 | ArrayRef<Type> resultTypes, |
53 | PatternRewriter &rewriter) { |
54 | Location loc = op->getLoc(); |
55 | auto elementTy = |
56 | cast<ShapedType>(op->getOperand(0).getType()).getElementType(); |
57 | |
58 | // tosa::AbsOp |
59 | if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy)) |
60 | return rewriter.create<math::AbsFOp>(loc, resultTypes, args); |
61 | |
62 | if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) { |
63 | auto zero = rewriter.create<arith::ConstantOp>( |
64 | loc, rewriter.getZeroAttr(elementTy)); |
65 | auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]); |
66 | return rewriter.create<arith::MaxSIOp>(loc, args[0], neg); |
67 | } |
68 | |
69 | // tosa::AddOp |
70 | if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy)) |
71 | return rewriter.create<arith::AddFOp>(loc, resultTypes, args); |
72 | |
73 | if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy)) |
74 | return rewriter.create<arith::AddIOp>(loc, resultTypes, args); |
75 | |
76 | // tosa::SubOp |
77 | if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy)) |
78 | return rewriter.create<arith::SubFOp>(loc, resultTypes, args); |
79 | |
80 | if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy)) |
81 | return rewriter.create<arith::SubIOp>(loc, resultTypes, args); |
82 | |
83 | // tosa::MulOp |
84 | if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) { |
85 | if (dyn_cast<tosa::MulOp>(op).getShift() != 0) { |
86 | (void)rewriter.notifyMatchFailure(arg&: op, |
87 | msg: "Cannot have shift value for float" ); |
88 | return nullptr; |
89 | } |
90 | return rewriter.create<arith::MulFOp>(loc, resultTypes, args); |
91 | } |
92 | |
93 | // tosa::DivOp |
94 | if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy)) |
95 | return rewriter.create<arith::DivSIOp>(loc, resultTypes, args); |
96 | |
97 | // tosa::ReciprocalOp |
98 | if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) { |
99 | auto one = |
100 | rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1)); |
101 | return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]); |
102 | } |
103 | |
104 | if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) { |
105 | Value a = args[0]; |
106 | Value b = args[1]; |
107 | auto shift = |
108 | cast<IntegerAttr>(op->getAttr(name: "shift" )).getValue().getSExtValue(); |
109 | if (shift > 0) { |
110 | auto shiftConst = |
111 | rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8); |
112 | if (!a.getType().isInteger(32)) |
113 | a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a); |
114 | |
115 | if (!b.getType().isInteger(32)) |
116 | b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b); |
117 | |
118 | auto result = rewriter.create<tosa::ApplyScaleOp>( |
119 | loc, rewriter.getI32Type(), a, b, shiftConst, |
120 | rewriter.getBoolAttr(false)); |
121 | |
122 | if (elementTy.isInteger(32)) |
123 | return result; |
124 | |
125 | return rewriter.create<arith::TruncIOp>(loc, elementTy, result); |
126 | } |
127 | |
128 | int aWidth = a.getType().getIntOrFloatBitWidth(); |
129 | int bWidth = b.getType().getIntOrFloatBitWidth(); |
130 | int cWidth = resultTypes[0].getIntOrFloatBitWidth(); |
131 | |
132 | if (aWidth < cWidth) |
133 | a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a); |
134 | if (bWidth < cWidth) |
135 | b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b); |
136 | |
137 | return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b); |
138 | } |
139 | |
140 | // tosa::NegateOp |
141 | if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy)) |
142 | return rewriter.create<arith::NegFOp>(loc, resultTypes, args); |
143 | |
144 | if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) && |
145 | !cast<tosa::NegateOp>(op).getQuantizationInfo()) { |
146 | auto constant = |
147 | rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); |
148 | return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]); |
149 | } |
150 | |
151 | if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) && |
152 | cast<tosa::NegateOp>(op).getQuantizationInfo()) { |
153 | auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo(); |
154 | int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); |
155 | int64_t inZp = quantizationInfo.value().getInputZp(); |
156 | int64_t outZp = quantizationInfo.value().getOutputZp(); |
157 | |
158 | // Compute the maximum value that can occur in the intermediate buffer. |
159 | int64_t zpAdd = inZp + outZp; |
160 | int64_t maxValue = APInt::getSignedMaxValue(numBits: inputBitWidth).getSExtValue() + |
161 | std::abs(i: zpAdd) + 1; |
162 | |
163 | // Convert that maximum value into the maximum bitwidth needed to represent |
164 | // it. We assume 48-bit numbers may be supported further in the pipeline. |
165 | int intermediateBitWidth = 64; |
166 | if (maxValue <= APInt::getSignedMaxValue(numBits: 16).getSExtValue()) { |
167 | intermediateBitWidth = 16; |
168 | } else if (maxValue <= APInt::getSignedMaxValue(numBits: 32).getSExtValue()) { |
169 | intermediateBitWidth = 32; |
170 | } else if (maxValue <= APInt::getSignedMaxValue(numBits: 48).getSExtValue()) { |
171 | intermediateBitWidth = 48; |
172 | } |
173 | |
174 | Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); |
175 | Value zpAddValue = rewriter.create<arith::ConstantOp>( |
176 | loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); |
177 | |
178 | // The negation can be applied by doing: |
179 | // outputValue = inZp + outZp - inputValue |
180 | auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]); |
181 | auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext); |
182 | |
183 | // Clamp to the negation range. |
184 | Value min = rewriter.create<arith::ConstantIntOp>( |
185 | location: loc, args: APInt::getSignedMinValue(numBits: inputBitWidth).getSExtValue(), |
186 | args&: intermediateType); |
187 | Value max = rewriter.create<arith::ConstantIntOp>( |
188 | location: loc, args: APInt::getSignedMaxValue(numBits: inputBitWidth).getSExtValue(), |
189 | args&: intermediateType); |
190 | auto clamp = clampIntHelper(loc, sub, min, max, rewriter); |
191 | |
192 | // Truncate to the final value. |
193 | return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp); |
194 | } |
195 | |
196 | // tosa::BitwiseAndOp |
197 | if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy)) |
198 | return rewriter.create<arith::AndIOp>(loc, resultTypes, args); |
199 | |
200 | // tosa::BitwiseOrOp |
201 | if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy)) |
202 | return rewriter.create<arith::OrIOp>(loc, resultTypes, args); |
203 | |
204 | // tosa::BitwiseNotOp |
205 | if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) { |
206 | auto allOnesAttr = rewriter.getIntegerAttr( |
207 | elementTy, APInt::getAllOnes(numBits: elementTy.getIntOrFloatBitWidth())); |
208 | auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr); |
209 | return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes); |
210 | } |
211 | |
212 | // tosa::BitwiseXOrOp |
213 | if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy)) |
214 | return rewriter.create<arith::XOrIOp>(loc, resultTypes, args); |
215 | |
216 | // tosa::LogicalLeftShiftOp |
217 | if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy)) |
218 | return rewriter.create<arith::ShLIOp>(loc, resultTypes, args); |
219 | |
220 | // tosa::LogicalRightShiftOp |
221 | if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy)) |
222 | return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args); |
223 | |
224 | // tosa::ArithmeticRightShiftOp |
225 | if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) { |
226 | auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args); |
227 | auto round = cast<BoolAttr>(Val: op->getAttr(name: "round" )).getValue(); |
228 | if (!round) { |
229 | return result; |
230 | } |
231 | |
232 | Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); |
233 | auto one = |
234 | rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1)); |
235 | auto zero = |
236 | rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); |
237 | auto i1one = |
238 | rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1)); |
239 | |
240 | // Checking that input2 != 0 |
241 | auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>( |
242 | loc, arith::CmpIPredicate::sgt, args[1], zero); |
243 | |
244 | // Checking for the last bit of input1 to be 1 |
245 | auto subtract = |
246 | rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one); |
247 | auto shifted = |
248 | rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract) |
249 | ->getResults(); |
250 | auto truncated = |
251 | rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt); |
252 | auto isInputOdd = |
253 | rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one); |
254 | |
255 | auto shouldRound = rewriter.create<arith::AndIOp>( |
256 | loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); |
257 | auto extended = |
258 | rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound); |
259 | return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended); |
260 | } |
261 | |
262 | // tosa::ClzOp |
263 | if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) { |
264 | return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]); |
265 | } |
266 | |
267 | // tosa::LogicalAnd |
268 | if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1)) |
269 | return rewriter.create<arith::AndIOp>(loc, resultTypes, args); |
270 | |
271 | // tosa::LogicalNot |
272 | if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) { |
273 | auto one = rewriter.create<arith::ConstantOp>( |
274 | loc, rewriter.getIntegerAttr(elementTy, 1)); |
275 | return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one); |
276 | } |
277 | |
278 | // tosa::LogicalOr |
279 | if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1)) |
280 | return rewriter.create<arith::OrIOp>(loc, resultTypes, args); |
281 | |
282 | // tosa::LogicalXor |
283 | if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1)) |
284 | return rewriter.create<arith::XOrIOp>(loc, resultTypes, args); |
285 | |
286 | // tosa::PowOp |
287 | if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy)) |
288 | return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args); |
289 | |
290 | // tosa::RsqrtOp |
291 | if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy)) |
292 | return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args); |
293 | |
294 | // tosa::LogOp |
295 | if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy)) |
296 | return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args); |
297 | |
298 | // tosa::ExpOp |
299 | if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy)) |
300 | return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args); |
301 | |
302 | // tosa::TanhOp |
303 | if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy)) |
304 | return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args); |
305 | |
306 | // tosa::ErfOp |
307 | if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy)) |
308 | return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args); |
309 | |
310 | // tosa::GreaterOp |
311 | if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy)) |
312 | return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, |
313 | args[0], args[1]); |
314 | |
315 | if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger()) |
316 | return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, |
317 | args[0], args[1]); |
318 | |
319 | // tosa::GreaterEqualOp |
320 | if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy)) |
321 | return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, |
322 | args[0], args[1]); |
323 | |
324 | if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger()) |
325 | return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, |
326 | args[0], args[1]); |
327 | |
328 | // tosa::EqualOp |
329 | if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy)) |
330 | return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ, |
331 | args[0], args[1]); |
332 | |
333 | if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger()) |
334 | return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
335 | args[0], args[1]); |
336 | |
337 | // tosa::SelectOp |
338 | if (isa<tosa::SelectOp>(op)) { |
339 | elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType(); |
340 | if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy)) |
341 | return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]); |
342 | } |
343 | |
344 | // tosa::MaximumOp |
345 | if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) { |
346 | return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]); |
347 | } |
348 | |
349 | if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) { |
350 | return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]); |
351 | } |
352 | |
353 | // tosa::MinimumOp |
354 | if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) { |
355 | return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]); |
356 | } |
357 | |
358 | if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) { |
359 | return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]); |
360 | } |
361 | |
362 | // tosa::CeilOp |
363 | if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy)) |
364 | return rewriter.create<math::CeilOp>(loc, resultTypes, args); |
365 | |
366 | // tosa::FloorOp |
367 | if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy)) |
368 | return rewriter.create<math::FloorOp>(loc, resultTypes, args); |
369 | |
370 | // tosa::ClampOp |
371 | if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) { |
372 | bool losesInfo = false; |
373 | APFloat minApf = cast<FloatAttr>(op->getAttr(name: "min_fp" )).getValue(); |
374 | APFloat maxApf = cast<FloatAttr>(op->getAttr(name: "max_fp" )).getValue(); |
375 | minApf.convert(ToSemantics: cast<FloatType>(elementTy).getFloatSemantics(), |
376 | RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo); |
377 | maxApf.convert(ToSemantics: cast<FloatType>(elementTy).getFloatSemantics(), |
378 | RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo); |
379 | auto min = rewriter.create<arith::ConstantOp>( |
380 | loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); |
381 | auto max = rewriter.create<arith::ConstantOp>( |
382 | loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); |
383 | return clampFloatHelper(loc, args[0], min, max, rewriter); |
384 | } |
385 | |
386 | if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) { |
387 | auto intTy = cast<IntegerType>(elementTy); |
388 | int64_t min = |
389 | cast<IntegerAttr>(op->getAttr(name: "min_int" )).getValue().getSExtValue(); |
390 | int64_t max = |
391 | cast<IntegerAttr>(op->getAttr(name: "max_int" )).getValue().getSExtValue(); |
392 | |
393 | if (intTy.isUnsignedInteger()) { |
394 | min = std::max(a: min, b: (int64_t)0); |
395 | max = std::min( |
396 | max, |
397 | APInt::getMaxValue(numBits: intTy.getIntOrFloatBitWidth()).getSExtValue()); |
398 | } else { |
399 | min = |
400 | std::max(min, APInt::getSignedMinValue(numBits: intTy.getIntOrFloatBitWidth()) |
401 | .getSExtValue()); |
402 | max = |
403 | std::min(max, APInt::getSignedMaxValue(numBits: intTy.getIntOrFloatBitWidth()) |
404 | .getSExtValue()); |
405 | } |
406 | |
407 | auto minVal = rewriter.create<arith::ConstantIntOp>( |
408 | loc, min, intTy.getIntOrFloatBitWidth()); |
409 | auto maxVal = rewriter.create<arith::ConstantIntOp>( |
410 | loc, max, intTy.getIntOrFloatBitWidth()); |
411 | return clampIntHelper(loc, args[0], minVal, maxVal, rewriter); |
412 | } |
413 | |
414 | // tosa::SigmoidOp |
415 | if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) { |
416 | auto one = |
417 | rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1)); |
418 | auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]); |
419 | auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate); |
420 | auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one); |
421 | return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added); |
422 | } |
423 | |
424 | // tosa::CastOp |
425 | if (isa<tosa::CastOp>(op)) { |
426 | Type srcTy = elementTy; |
427 | Type dstTy = resultTypes.front(); |
428 | bool bitExtend = |
429 | srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); |
430 | |
431 | if (srcTy == dstTy) |
432 | return args.front(); |
433 | |
434 | if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend) |
435 | return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, |
436 | std::nullopt); |
437 | |
438 | if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend) |
439 | return rewriter.create<arith::TruncFOp>(loc, resultTypes, args, |
440 | std::nullopt); |
441 | |
442 | // 1-bit integers need to be treated as signless. |
443 | if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) |
444 | return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args, |
445 | std::nullopt); |
446 | |
447 | if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend) |
448 | return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args, |
449 | std::nullopt); |
450 | |
451 | // Unsigned integers need an unrealized cast so that they can be passed |
452 | // to UIToFP. |
453 | if (srcTy.isUnsignedInteger() && isa<FloatType>(Val: dstTy)) { |
454 | auto unrealizedCast = |
455 | rewriter |
456 | .create<UnrealizedConversionCastOp>( |
457 | loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), |
458 | args[0]) |
459 | .getResult(0); |
460 | return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0], |
461 | unrealizedCast); |
462 | } |
463 | |
464 | // All other si-to-fp conversions should be handled by SIToFP. |
465 | if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) |
466 | return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args, |
467 | std::nullopt); |
468 | |
469 | // Casting to boolean, floats need to only be checked as not-equal to zero. |
470 | if (isa<FloatType>(Val: srcTy) && dstTy.isInteger(width: 1)) { |
471 | Value zero = rewriter.create<arith::ConstantOp>( |
472 | loc, rewriter.getFloatAttr(srcTy, 0.0)); |
473 | return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, |
474 | args.front(), zero); |
475 | } |
476 | |
477 | if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { |
478 | auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]); |
479 | |
480 | const auto &fltSemantics = cast<FloatType>(Val&: srcTy).getFloatSemantics(); |
481 | // Check whether neither int min nor int max can be represented in the |
482 | // input floating-point type due to too short exponent range. |
483 | if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 > |
484 | APFloat::semanticsMaxExponent(fltSemantics)) { |
485 | // Use cmp + select to replace infinites by int min / int max. Other |
486 | // integral values can be represented in the integer space. |
487 | auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded); |
488 | auto posInf = rewriter.create<arith::ConstantOp>( |
489 | loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy), |
490 | APFloat::getInf(fltSemantics))); |
491 | auto negInf = rewriter.create<arith::ConstantOp>( |
492 | loc, rewriter.getFloatAttr( |
493 | getElementTypeOrSelf(srcTy), |
494 | APFloat::getInf(fltSemantics, /*Negative=*/true))); |
495 | auto overflow = rewriter.create<arith::CmpFOp>( |
496 | loc, arith::CmpFPredicate::UEQ, rounded, posInf); |
497 | auto underflow = rewriter.create<arith::CmpFOp>( |
498 | loc, arith::CmpFPredicate::UEQ, rounded, negInf); |
499 | auto intMin = rewriter.create<arith::ConstantOp>( |
500 | loc, rewriter.getIntegerAttr( |
501 | getElementTypeOrSelf(dstTy), |
502 | APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()))); |
503 | auto intMax = rewriter.create<arith::ConstantOp>( |
504 | loc, rewriter.getIntegerAttr( |
505 | getElementTypeOrSelf(dstTy), |
506 | APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); |
507 | auto maxClamped = |
508 | rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv); |
509 | return rewriter.create<arith::SelectOp>(loc, underflow, intMin, |
510 | maxClamped); |
511 | } |
512 | |
513 | auto intMinFP = rewriter.create<arith::ConstantOp>( |
514 | loc, rewriter.getFloatAttr( |
515 | getElementTypeOrSelf(srcTy), |
516 | APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) |
517 | .getSExtValue())); |
518 | |
519 | // Check whether the mantissa has enough bits to represent int max. |
520 | if (cast<FloatType>(Val&: srcTy).getFPMantissaWidth() >= |
521 | dstTy.getIntOrFloatBitWidth() - 1) { |
522 | // Int min can also be represented since it is a power of two and thus |
523 | // consists of a single leading bit. Therefore we can clamp the input |
524 | // in the floating-point domain. |
525 | |
526 | auto intMaxFP = rewriter.create<arith::ConstantOp>( |
527 | loc, rewriter.getFloatAttr( |
528 | getElementTypeOrSelf(srcTy), |
529 | APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) |
530 | .getSExtValue())); |
531 | |
532 | Value clamped = |
533 | clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter); |
534 | return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped); |
535 | } |
536 | |
537 | // Due to earlier check we know exponant range is big enough to represent |
538 | // int min. We can therefore rely on int max + 1 being representable as |
539 | // well because it's just int min with a positive sign. So clamp the min |
540 | // value and compare against that to select the max int value if needed. |
541 | auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>( |
542 | loc, rewriter.getFloatAttr( |
543 | getElementTypeOrSelf(srcTy), |
544 | APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) |
545 | .getSExtValue() + |
546 | 1)); |
547 | |
548 | auto intMax = rewriter.create<arith::ConstantOp>( |
549 | loc, rewriter.getIntegerAttr( |
550 | getElementTypeOrSelf(dstTy), |
551 | APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()))); |
552 | auto minClampedFP = |
553 | rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP); |
554 | auto minClamped = |
555 | rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP); |
556 | auto overflow = rewriter.create<arith::CmpFOp>( |
557 | loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP); |
558 | return rewriter.create<arith::SelectOp>(loc, overflow, intMax, |
559 | minClamped); |
560 | } |
561 | |
562 | // Casting to boolean, integers need to only be checked as not-equal to |
563 | // zero. |
564 | if (isa<IntegerType>(Val: srcTy) && dstTy.isInteger(width: 1)) { |
565 | Value zero = rewriter.create<arith::ConstantIntOp>( |
566 | location: loc, args: 0, args: srcTy.getIntOrFloatBitWidth()); |
567 | return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, |
568 | args.front(), zero); |
569 | } |
570 | |
571 | if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend) |
572 | return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args, |
573 | std::nullopt); |
574 | |
575 | if (isa<IntegerType>(Val: srcTy) && isa<IntegerType>(Val: dstTy) && !bitExtend) { |
576 | return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]); |
577 | } |
578 | } |
579 | |
580 | (void)rewriter.notifyMatchFailure( |
581 | arg&: op, msg: "unhandled op for linalg body calculation for elementwise op" ); |
582 | return nullptr; |
583 | } |
584 | |
585 | static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, |
586 | int64_t rank) { |
587 | // No need to expand if we are already at the desired rank |
588 | auto shapedType = dyn_cast<ShapedType>(tensor.getType()); |
589 | assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type" ); |
590 | int64_t = rank - shapedType.getRank(); |
591 | assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank" ); |
592 | if (!numExtraDims) |
593 | return tensor; |
594 | |
595 | // Compute reassociation indices |
596 | SmallVector<SmallVector<int64_t, 2>> reassociationIndices( |
597 | shapedType.getRank()); |
598 | int64_t index = 0; |
599 | for (index = 0; index <= numExtraDims; index++) |
600 | reassociationIndices[0].push_back(Elt: index); |
601 | for (size_t position = 1; position < reassociationIndices.size(); position++) |
602 | reassociationIndices[position].push_back(Elt: index++); |
603 | |
604 | // Compute result type |
605 | SmallVector<int64_t> resultShape; |
606 | for (index = 0; index < numExtraDims; index++) |
607 | resultShape.push_back(Elt: 1); |
608 | for (auto size : shapedType.getShape()) |
609 | resultShape.push_back(size); |
610 | auto resultType = |
611 | RankedTensorType::get(resultShape, shapedType.getElementType()); |
612 | |
613 | // Emit 'tensor.expand_shape' op |
614 | return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor, |
615 | reassociationIndices); |
616 | } |
617 | |
618 | static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter, |
619 | Location loc, Operation *operation) { |
620 | auto rank = |
621 | cast<RankedTensorType>(operation->getResultTypes().front()).getRank(); |
622 | return llvm::map_to_vector(operation->getOperands(), [&](Value operand) { |
623 | return expandRank(rewriter, loc, operand, rank); |
624 | }); |
625 | } |
626 | |
627 | using IndexPool = DenseMap<int64_t, Value>; |
628 | |
629 | // Emit an 'arith.constant' op for the given index if it has not been created |
630 | // yet, or return an existing constant. This will prevent an excessive creation |
631 | // of redundant constants, easing readability of emitted code for unit tests. |
632 | static Value createIndex(PatternRewriter &rewriter, Location loc, |
633 | IndexPool &indexPool, int64_t index) { |
634 | auto [it, inserted] = indexPool.try_emplace(Key: index); |
635 | if (inserted) |
636 | it->second = |
637 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index)); |
638 | return it->second; |
639 | } |
640 | |
641 | static Value getTensorDim(PatternRewriter &rewriter, Location loc, |
642 | IndexPool &indexPool, Value tensor, int64_t index) { |
643 | auto indexValue = createIndex(rewriter, loc, indexPool, index); |
644 | return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult(); |
645 | } |
646 | |
647 | static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, |
648 | IndexPool &indexPool, Value tensor, |
649 | int64_t index) { |
650 | auto shapedType = dyn_cast<ShapedType>(tensor.getType()); |
651 | assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type" ); |
652 | assert(index >= 0 && index < shapedType.getRank() && "index out of bounds" ); |
653 | if (shapedType.isDynamicDim(index)) |
654 | return getTensorDim(rewriter, loc, indexPool, tensor, index); |
655 | return rewriter.getIndexAttr(value: shapedType.getDimSize(index)); |
656 | } |
657 | |
658 | static bool operandsAndResultsRanked(Operation *operation) { |
659 | auto isRanked = [](Value value) { |
660 | return isa<RankedTensorType>(Val: value.getType()); |
661 | }; |
662 | return llvm::all_of(Range: operation->getOperands(), P: isRanked) && |
663 | llvm::all_of(Range: operation->getResults(), P: isRanked); |
664 | } |
665 | |
// Compute the runtime dimension size for dimension 'dim' of the output by
// inspecting input 'operands', all of which are expected to have the same rank.
// This function returns a pair {targetSize, masterOperand}.
//
// The runtime size of the output dimension is returned either as a statically
// computed attribute or as a runtime SSA value.
//
// If the target size was inferred directly from one dominating operand, that
// operand is returned in 'masterOperand'. If the target size is inferred from
// multiple operands, 'masterOperand' is set to nullptr.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any input operand contains a static size greater than 1 for this
  // dimension, that is the target size. An occurrence of an additional static
  // dimension greater than 1 with a different value is undefined behavior.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(value: size), operand};
  }

  // Filter operands with dynamic dimension
  auto operandsWithDynamicDim =
      llvm::to_vector(Range: llvm::make_filter_range(Range&: operands, Pred: [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      }));

  // If no operand has a dynamic dimension, it means all sizes were 1
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension. If there is
  // only one operand with a dynamic dimension, it is considered the master
  // operand that determines the runtime size of the output dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, tensor: operandsWithDynamicDim[0], index: dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Calculate maximum size among all dynamic dimensions. Since several
  // operands contribute, no single "master operand" exists: return nullptr.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, tensor: operandsWithDynamicDim[i], index: dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
714 | |
715 | // Compute the runtime output size for all dimensions. This function returns |
716 | // a pair {targetShape, masterOperands}. |
717 | static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>> |
718 | computeTargetShape(PatternRewriter &rewriter, Location loc, |
719 | IndexPool &indexPool, ValueRange operands) { |
720 | assert(!operands.empty()); |
721 | auto rank = cast<RankedTensorType>(operands.front().getType()).getRank(); |
722 | SmallVector<OpFoldResult> targetShape; |
723 | SmallVector<Value> masterOperands; |
724 | for (auto dim : llvm::seq<int64_t>(0, rank)) { |
725 | auto [targetSize, masterOperand] = |
726 | computeTargetSize(rewriter, loc, indexPool, operands, dim); |
727 | targetShape.push_back(targetSize); |
728 | masterOperands.push_back(masterOperand); |
729 | } |
730 | return {targetShape, masterOperands}; |
731 | } |
732 | |
// Broadcast dimension 'dim' of 'operand' up to 'targetSize' when required.
//
// Static dimensions and operands that themselves determined the target size
// ('masterOperand') are returned unchanged. Otherwise this emits an 'scf.if'
// guarded by a runtime check "size(operand, dim) == 1": the 'then' region
// materializes the broadcast via 'tensor.empty' + 'linalg.generic', and the
// 'else' region passes the operand through untouched.
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly inferred by only taking
  // this operand into account, there is no need to broadcast. This is an
  // optimization that will prevent redundant control flow, and constitutes the
  // main motivation for tracking "master operands".
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op. The input map pins dimension 'dim'
  // to constant 0 (reading the single source element), the output map is the
  // identity.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank: rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary
  auto one = createIndex(rewriter, loc, indexPool, index: 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, tensor: operand, index: dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    // New constants could potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op. The body simply yields the (broadcast) input
    // element.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  // Emit 'linalg.yield' op
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to original operand type if necessary, so both scf.if branches
    // yield the same type.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if': broadcast not needed, forward the operand.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
815 | |
816 | static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, |
817 | IndexPool &indexPool, Value operand, |
818 | ArrayRef<OpFoldResult> targetShape, |
819 | ArrayRef<Value> masterOperands) { |
820 | int64_t rank = cast<RankedTensorType>(operand.getType()).getRank(); |
821 | assert((int64_t)targetShape.size() == rank); |
822 | assert((int64_t)masterOperands.size() == rank); |
823 | for (auto index : llvm::seq<int64_t>(0, rank)) |
824 | operand = |
825 | broadcastDynamicDimension(rewriter, loc, indexPool, operand, index, |
826 | targetShape[index], masterOperands[index]); |
827 | return operand; |
828 | } |
829 | |
830 | static SmallVector<Value> |
831 | broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, |
832 | IndexPool &indexPool, ValueRange operands, |
833 | ArrayRef<OpFoldResult> targetShape, |
834 | ArrayRef<Value> masterOperands) { |
835 | // No need to broadcast for unary operations |
836 | if (operands.size() == 1) |
837 | return operands; |
838 | |
839 | // Broadcast dynamic dimensions operand by operand |
840 | return llvm::map_to_vector(C&: operands, F: [&](Value operand) { |
841 | return broadcastDynamicDimensions(rewriter, loc, indexPool, operand, |
842 | targetShape, masterOperands); |
843 | }); |
844 | } |
845 | |
// Emit the 'linalg.generic' that performs the elementwise computation for
// 'operation' on the (already rank-expanded and broadcast) 'operands', then
// replace 'operation' with the (possibly casted) result. Fails if no scalar
// body can be generated for this op/element-type combination.
static LogicalResult
emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape) {
  // Generate output tensor
  auto resultType = cast<RankedTensorType>(operation->getResultTypes().front());
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  //
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Size-1 dimensions are read at constant index 0 (broadcast).
      auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
                                        : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank: rank));

  // Emit 'linalg.generic' op
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          // Record the failure; the match failure is reported after the
          // builder callback returns.
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        arg&: operation, msg: "unable to create linalg.generic body for elementwise op" );

  // Cast 'linalg.generic' result into original result type if needed
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
896 | |
// Shared driver for lowering TOSA elementwise ops to 'linalg.generic':
// 1. expand all operands to the result rank,
// 2. compute the (possibly runtime) broadcast target shape,
// 3. broadcast dynamic dimensions,
// 4. emit the elementwise computation.
// Requires ranked operands/results and exactly one result.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {

  // Collect op properties
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result" );
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand" );
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(arg&: operation,
                                       msg: "Unranked tensors not supported" );

  // Lower operation
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operands: expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, operands: expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation, operands: broadcastOperands,
                                    targetShape);
}
920 | |
921 | // Returns the constant initial value for a given reduction operation. The |
922 | // attribute type varies depending on the element type required. |
923 | static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, |
924 | PatternRewriter &rewriter) { |
925 | if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) |
926 | return rewriter.getFloatAttr(elementTy, 0.0); |
927 | |
928 | if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) |
929 | return rewriter.getIntegerAttr(elementTy, 0); |
930 | |
931 | if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) |
932 | return rewriter.getFloatAttr(elementTy, 1.0); |
933 | |
934 | if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) |
935 | return rewriter.getIntegerAttr(elementTy, 1); |
936 | |
937 | if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) |
938 | return rewriter.getFloatAttr( |
939 | elementTy, APFloat::getLargest( |
940 | Sem: cast<FloatType>(Val&: elementTy).getFloatSemantics(), Negative: false)); |
941 | |
942 | if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) |
943 | return rewriter.getIntegerAttr( |
944 | elementTy, APInt::getSignedMaxValue(numBits: elementTy.getIntOrFloatBitWidth())); |
945 | |
946 | if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) |
947 | return rewriter.getFloatAttr( |
948 | elementTy, APFloat::getLargest( |
949 | Sem: cast<FloatType>(Val&: elementTy).getFloatSemantics(), Negative: true)); |
950 | |
951 | if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) |
952 | return rewriter.getIntegerAttr( |
953 | elementTy, APInt::getSignedMinValue(numBits: elementTy.getIntOrFloatBitWidth())); |
954 | |
955 | if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1)) |
956 | return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(numBits: 1)); |
957 | |
958 | if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1)) |
959 | return rewriter.getIntegerAttr(elementTy, APInt::getZero(numBits: 1)); |
960 | |
961 | if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy)) |
962 | return rewriter.getFloatAttr( |
963 | elementTy, APFloat::getLargest( |
964 | Sem: cast<FloatType>(Val&: elementTy).getFloatSemantics(), Negative: true)); |
965 | |
966 | if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy)) |
967 | return rewriter.getIntegerAttr( |
968 | elementTy, APInt::getSignedMinValue(numBits: elementTy.getIntOrFloatBitWidth())); |
969 | |
970 | return {}; |
971 | } |
972 | |
973 | // Creates the body calculation for a reduction. The operations vary depending |
974 | // on the input type. |
975 | static Value createLinalgBodyCalculationForReduceOp(Operation *op, |
976 | ValueRange args, |
977 | Type elementTy, |
978 | PatternRewriter &rewriter) { |
979 | Location loc = op->getLoc(); |
980 | if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) { |
981 | return rewriter.create<arith::AddFOp>(loc, args); |
982 | } |
983 | |
984 | if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) { |
985 | return rewriter.create<arith::AddIOp>(loc, args); |
986 | } |
987 | |
988 | if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) { |
989 | return rewriter.create<arith::MulFOp>(loc, args); |
990 | } |
991 | |
992 | if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) { |
993 | return rewriter.create<arith::MulIOp>(loc, args); |
994 | } |
995 | |
996 | if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) { |
997 | return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]); |
998 | } |
999 | |
1000 | if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) { |
1001 | return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]); |
1002 | } |
1003 | |
1004 | if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) { |
1005 | return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]); |
1006 | } |
1007 | |
1008 | if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) { |
1009 | return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]); |
1010 | } |
1011 | |
1012 | if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1)) |
1013 | return rewriter.create<arith::AndIOp>(loc, args); |
1014 | |
1015 | if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1)) |
1016 | return rewriter.create<arith::OrIOp>(loc, args); |
1017 | |
1018 | return {}; |
1019 | } |
1020 | |
1021 | // Performs the match and rewrite for reduction operations. This includes |
1022 | // declaring a correctly sized initial value, and the linalg.generic operation |
1023 | // that reduces across the specified axis. |
1024 | static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, |
1025 | PatternRewriter &rewriter) { |
1026 | auto loc = op->getLoc(); |
1027 | auto inputTy = cast<ShapedType>(op->getOperand(0).getType()); |
1028 | auto resultTy = cast<ShapedType>(op->getResult(0).getType()); |
1029 | auto elementTy = resultTy.getElementType(); |
1030 | Value input = op->getOperand(idx: 0); |
1031 | |
1032 | SmallVector<int64_t> reduceShape; |
1033 | SmallVector<Value> dynDims; |
1034 | for (unsigned i = 0; i < inputTy.getRank(); i++) { |
1035 | if (axis != i) { |
1036 | reduceShape.push_back(Elt: inputTy.getDimSize(i)); |
1037 | if (inputTy.isDynamicDim(i)) |
1038 | dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); |
1039 | } |
1040 | } |
1041 | |
1042 | // First fill the output buffer with the init value. |
1043 | auto emptyTensor = |
1044 | rewriter |
1045 | .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(), |
1046 | dynDims) |
1047 | .getResult(); |
1048 | |
1049 | auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); |
1050 | if (!fillValueAttr) |
1051 | return rewriter.notifyMatchFailure( |
1052 | arg&: op, msg: "No initial value found for reduction operation" ); |
1053 | |
1054 | auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr); |
1055 | auto filledTensor = rewriter |
1056 | .create<linalg::FillOp>(loc, ValueRange{fillValue}, |
1057 | ValueRange{emptyTensor}) |
1058 | .result(); |
1059 | |
1060 | bool didEncounterError = false; |
1061 | auto linalgOp = rewriter.create<linalg::ReduceOp>( |
1062 | loc, input, filledTensor, axis, |
1063 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { |
1064 | auto result = createLinalgBodyCalculationForReduceOp( |
1065 | op, blockArgs, elementTy, rewriter); |
1066 | if (result) |
1067 | didEncounterError = true; |
1068 | |
1069 | nestedBuilder.create<linalg::YieldOp>(loc, result); |
1070 | }); |
1071 | |
1072 | if (!didEncounterError) |
1073 | return rewriter.notifyMatchFailure( |
1074 | arg&: op, msg: "unable to create linalg.generic body for reduce op" ); |
1075 | |
1076 | SmallVector<ReassociationExprs, 4> reassociationMap; |
1077 | uint64_t expandInputRank = |
1078 | cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank(); |
1079 | reassociationMap.resize(N: expandInputRank); |
1080 | |
1081 | for (uint64_t i = 0; i < expandInputRank; i++) { |
1082 | int32_t dimToPush = i > axis ? i + 1 : i; |
1083 | reassociationMap[i].push_back(Elt: rewriter.getAffineDimExpr(position: dimToPush)); |
1084 | } |
1085 | |
1086 | if (expandInputRank != 0) { |
1087 | int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1; |
1088 | reassociationMap[expandedDim].push_back( |
1089 | Elt: rewriter.getAffineDimExpr(position: expandedDim + 1)); |
1090 | } |
1091 | |
1092 | // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`, |
1093 | // since here we know which dimension to expand, and `tosa::ReshapeOp` would |
1094 | // not have access to such information. This matters when handling dynamically |
1095 | // sized tensors. |
1096 | rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( |
1097 | op, resultTy, linalgOp.getResults()[0], reassociationMap); |
1098 | return success(); |
1099 | } |
1100 | |
1101 | namespace { |
1102 | |
// Generic pattern that lowers any elementwise TOSA op 'SrcOp' by delegating
// to the shared elementwise helper (rank expansion + broadcast +
// linalg.generic emission).
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};
1113 | |
// Lowers 'tosa.rescale' to a 'linalg.generic' whose body performs the
// zero-point adjustment, scale application ('tosa.apply_scale'), and
// saturation to the output integer range. Per-channel multipliers/shifts are
// materialized as constant tensor inputs to the generic op.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration. terminate and log an error
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true" );

    // Collect dynamic output sizes (queried from the input, which shares the
    // shape with the output).
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(Range&: shiftValues, P: [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer (indexed by the innermost dimension); otherwise a
    // single scalar constant suffices.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(position: rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(Elt: AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, results: multiplierExprs,
                                            context: rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(position: rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(Elt: AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, results: shiftExprs,
                                            context: rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we do all of our math in 64-bit. This is not optimal but
          // should be correct for now, consider computing correct bit depth
          // later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp" , nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp" , nestedBuilder.getI32Type(), nestedBuilder);

          // Pick the scalar constant or the per-channel block argument.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Extend sub-32-bit inputs to i32 before the zero-point math.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output value range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder);

          // Truncate back to the declared output width, re-wrapping unsigned
          // results through an unrealized conversion cast.
          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
1305 | |
// Handle the resize case where the input is a 1x1 image. This case can
// entirely avoid having extract operations, which are much more difficult
// to optimize away. The spatial dims are collapsed away, the values are
// (optionally) scaled elementwise, and the result is expanded back.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR" ;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    // Only handle the trivial 1x1 -> 1x1 spatial case here.
    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation" );

    // TODO(suderman): These string values should be declared the TOSA dialect.
    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR" )
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR" );

    // Identical types: the resize is a no-op.
    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away: NxHxWxC -> Nx(H*W*C).
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(Elt: builder.getAffineDimExpr(position: 0));
    reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 1));
    reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 2));
    reassociationMap[1].push_back(Elt: builder.getAffineDimExpr(position: 3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting scaling the input value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(rank: genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case: widen to the result element type and
          // pre-multiply by the bilinear scale numerators when present.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    // Expand back to the original NxHxWxC layout.
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
1402 | |
1403 | // TOSA resize with width or height of 1 may be broadcasted to a wider |
1404 | // dimension. This is done by materializing a new tosa.resize without |
1405 | // the broadcasting behavior, and an explicit broadcast afterwards. |
1406 | class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> { |
1407 | public: |
1408 | using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern; |
1409 | |
1410 | LogicalResult matchAndRewrite(tosa::ResizeOp op, |
1411 | PatternRewriter &rewriter) const final { |
1412 | Location loc = op.getLoc(); |
1413 | ImplicitLocOpBuilder builder(loc, rewriter); |
1414 | auto input = op.getInput(); |
1415 | auto inputTy = dyn_cast<RankedTensorType>(input.getType()); |
1416 | auto resultTy = dyn_cast<RankedTensorType>(op.getType()); |
1417 | |
1418 | if (!inputTy || !resultTy) |
1419 | return rewriter.notifyMatchFailure(op, |
1420 | "requires ranked input/output types" ); |
1421 | |
1422 | auto batch = inputTy.getDimSize(0); |
1423 | auto channels = inputTy.getDimSize(3); |
1424 | auto inputH = inputTy.getDimSize(1); |
1425 | auto inputW = inputTy.getDimSize(2); |
1426 | auto outputH = resultTy.getDimSize(1); |
1427 | auto outputW = resultTy.getDimSize(2); |
1428 | |
1429 | if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1)) |
1430 | return rewriter.notifyMatchFailure( |
1431 | op, "tosa.resize has no broadcasting behavior" ); |
1432 | |
1433 | // For any dimension that is broadcastable we generate a width of 1 |
1434 | // on the output. |
1435 | llvm::SmallVector<int64_t> resizeShape; |
1436 | resizeShape.push_back(Elt: batch); |
1437 | resizeShape.push_back(Elt: inputH == 1 ? 1 : outputH); |
1438 | resizeShape.push_back(Elt: inputW == 1 ? 1 : outputW); |
1439 | resizeShape.push_back(Elt: channels); |
1440 | |
1441 | auto resizeTy = resultTy.clone(resizeShape); |
1442 | auto resize = |
1443 | builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs()); |
1444 | |
1445 | // Collapse an unit result dims. |
1446 | SmallVector<ReassociationExprs, 4> reassociationMap(2); |
1447 | reassociationMap[0].push_back(Elt: builder.getAffineDimExpr(position: 0)); |
1448 | reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 1)); |
1449 | if (inputH != 1) |
1450 | reassociationMap.push_back(Elt: {}); |
1451 | reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 2)); |
1452 | if (inputW != 1) |
1453 | reassociationMap.push_back(Elt: {}); |
1454 | reassociationMap.back().push_back(Elt: builder.getAffineDimExpr(position: 3)); |
1455 | |
1456 | llvm::SmallVector<int64_t> collapseShape{batch}; |
1457 | if (inputH != 1) |
1458 | collapseShape.push_back(Elt: outputH); |
1459 | if (inputW != 1) |
1460 | collapseShape.push_back(Elt: outputW); |
1461 | collapseShape.push_back(Elt: channels); |
1462 | |
1463 | auto collapseTy = resultTy.clone(collapseShape); |
1464 | Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize, |
1465 | reassociationMap); |
1466 | |
1467 | // Broadcast the collapsed shape to the output result. |
1468 | llvm::SmallVector<Value> outputDynSize; |
1469 | if (inputTy.isDynamicDim(0)) |
1470 | outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0)); |
1471 | if (inputTy.isDynamicDim(3)) |
1472 | outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3)); |
1473 | |
1474 | SmallVector<utils::IteratorType> iterators(resultTy.getRank(), |
1475 | utils::IteratorType::parallel); |
1476 | Value empty = builder.create<tensor::EmptyOp>( |
1477 | resultTy.getShape(), resultTy.getElementType(), outputDynSize); |
1478 | |
1479 | SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(position: 0)}; |
1480 | if (inputH != 1) |
1481 | inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 1)); |
1482 | if (inputW != 1) |
1483 | inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 2)); |
1484 | inputExprs.push_back(Elt: rewriter.getAffineDimExpr(position: 3)); |
1485 | |
1486 | auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, |
1487 | inputExprs, rewriter.getContext()); |
1488 | |
1489 | auto outputMap = rewriter.getMultiDimIdentityMap(rank: resultTy.getRank()); |
1490 | rewriter.replaceOpWithNewOp<linalg::GenericOp>( |
1491 | op, resultTy, ValueRange{collapse}, ValueRange{empty}, |
1492 | ArrayRef<AffineMap>{inputMap, outputMap}, iterators, |
1493 | [=](OpBuilder &b, Location loc, ValueRange args) { |
1494 | Value value = args[0]; |
1495 | b.create<linalg::YieldOp>(loc, value); |
1496 | }); |
1497 | |
1498 | return success(); |
1499 | } |
1500 | }; |
1501 | |
1502 | class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> { |
1503 | public: |
1504 | using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern; |
1505 | |
1506 | LogicalResult matchAndRewrite(tosa::ResizeOp op, |
1507 | PatternRewriter &rewriter) const final { |
1508 | Location loc = op.getLoc(); |
1509 | ImplicitLocOpBuilder b(loc, rewriter); |
1510 | auto input = op.getInput(); |
1511 | auto inputTy = cast<ShapedType>(input.getType()); |
1512 | auto resultTy = cast<ShapedType>(op.getType()); |
1513 | auto resultETy = resultTy.getElementType(); |
1514 | |
1515 | bool floatingPointMode = resultETy.isF16() || resultETy.isF32(); |
1516 | auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type(); |
1517 | |
1518 | auto imageH = inputTy.getShape()[1]; |
1519 | auto imageW = inputTy.getShape()[2]; |
1520 | |
1521 | auto dynamicDimsOr = |
1522 | checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()}); |
1523 | if (!dynamicDimsOr.has_value()) |
1524 | return rewriter.notifyMatchFailure( |
1525 | op, "unable to get dynamic dimensions of tosa.resize" ); |
1526 | |
1527 | if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR" ) |
1528 | return rewriter.notifyMatchFailure( |
1529 | op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR" ); |
1530 | |
1531 | SmallVector<AffineMap, 2> affineMaps = { |
1532 | rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())}; |
1533 | auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy, |
1534 | *dynamicDimsOr); |
1535 | auto genericOp = b.create<linalg::GenericOp>( |
1536 | resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, |
1537 | getNParallelLoopsAttrs(resultTy.getRank())); |
1538 | Value resize = genericOp.getResult(0); |
1539 | |
1540 | { |
1541 | OpBuilder::InsertionGuard regionGuard(b); |
1542 | b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), |
1543 | TypeRange({resultETy}), loc); |
1544 | Value batch = b.create<linalg::IndexOp>(0); |
1545 | Value y = b.create<linalg::IndexOp>(1); |
1546 | Value x = b.create<linalg::IndexOp>(2); |
1547 | Value channel = b.create<linalg::IndexOp>(3); |
1548 | |
1549 | Value zeroI32 = |
1550 | b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type())); |
1551 | Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy)); |
1552 | Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1)); |
1553 | Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1)); |
1554 | |
1555 | Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y); |
1556 | Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x); |
1557 | |
1558 | ArrayRef<int64_t> offset = op.getOffset(); |
1559 | ArrayRef<int64_t> border = op.getBorder(); |
1560 | ArrayRef<int64_t> scale = op.getScale(); |
1561 | |
1562 | Value yScaleN, yScaleD, xScaleN, xScaleD; |
1563 | yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0])); |
1564 | yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1])); |
1565 | xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2])); |
1566 | xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3])); |
1567 | |
1568 | Value yOffset, xOffset, yBorder, xBorder; |
1569 | yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0])); |
1570 | xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1])); |
1571 | yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0])); |
1572 | xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1])); |
1573 | |
1574 | // Compute the ix and dx values for both the X and Y dimensions. |
1575 | auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, |
1576 | Value scaleN, Value scaleD, Value offset, |
1577 | int size, ImplicitLocOpBuilder &b) { |
1578 | if (size == 1) { |
1579 | index = zeroI32; |
1580 | delta = zeroFp; |
1581 | return; |
1582 | } |
1583 | // x = x * scale_d + offset; |
1584 | // ix = floor(x / scale_n) |
1585 | Value val = b.create<arith::MulIOp>(in, scaleD); |
1586 | val = b.create<arith::AddIOp>(val, offset); |
1587 | index = b.create<arith::FloorDivSIOp>(val, scaleN); |
1588 | |
1589 | // rx = x % scale_n |
1590 | // dx = rx / scale_n |
1591 | Value r = b.create<arith::RemSIOp>(val, scaleN); |
1592 | Value rFp = b.create<arith::SIToFPOp>(floatTy, r); |
1593 | Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN); |
1594 | delta = b.create<arith::DivFOp>(rFp, scaleNfp); |
1595 | }; |
1596 | |
1597 | // Compute the ix and dx values for the X and Y dimensions - int case. |
1598 | auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in, |
1599 | Value scaleN, Value scaleD, Value offset, |
1600 | int size, ImplicitLocOpBuilder &b) { |
1601 | if (size == 1) { |
1602 | index = zeroI32; |
1603 | delta = zeroI32; |
1604 | return; |
1605 | } |
1606 | // x = x * scale_d + offset; |
1607 | // ix = floor(x / scale_n) |
1608 | // dx = x - ix * scale_n; |
1609 | Value val = b.create<arith::MulIOp>(in, scaleD); |
1610 | val = b.create<arith::AddIOp>(val, offset); |
1611 | index = b.create<arith::DivSIOp>(val, scaleN); |
1612 | delta = b.create<arith::MulIOp>(index, scaleN); |
1613 | delta = b.create<arith::SubIOp>(val, delta); |
1614 | }; |
1615 | |
1616 | Value ix, iy, dx, dy; |
1617 | if (floatingPointMode) { |
1618 | getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); |
1619 | getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); |
1620 | } else { |
1621 | getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); |
1622 | getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); |
1623 | } |
1624 | |
1625 | if (op.getMode() == "NEAREST_NEIGHBOR" ) { |
1626 | auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1)); |
1627 | |
1628 | auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, |
1629 | Value max, int size, |
1630 | ImplicitLocOpBuilder &b) -> Value { |
1631 | if (size == 1) { |
1632 | return b.create<arith::ConstantIndexOp>(0); |
1633 | } |
1634 | |
1635 | Value pred; |
1636 | if (floatingPointMode) { |
1637 | auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f)); |
1638 | pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h); |
1639 | } else { |
1640 | Value dvalDouble = b.create<arith::ShLIOp>(dval, one); |
1641 | pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, |
1642 | dvalDouble, scale); |
1643 | } |
1644 | |
1645 | auto offset = b.create<arith::SelectOp>(pred, one, zeroI32); |
1646 | val = b.create<arith::AddIOp>(val, offset); |
1647 | val = clampIntHelper(loc, arg: val, min: zeroI32, max, rewriter&: b); |
1648 | return b.create<arith::IndexCastOp>(b.getIndexType(), val); |
1649 | }; |
1650 | |
1651 | iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b); |
1652 | ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); |
1653 | |
1654 | Value result = b.create<tensor::ExtractOp>( |
1655 | input, ValueRange{batch, iy, ix, channel}); |
1656 | |
1657 | b.create<linalg::YieldOp>(result); |
1658 | } else { |
1659 | // The mode here must be BILINEAR. |
1660 | assert(op.getMode() == "BILINEAR" ); |
1661 | |
1662 | auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1)); |
1663 | |
1664 | auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, |
1665 | Value max, ImplicitLocOpBuilder &b) { |
1666 | val0 = in; |
1667 | val1 = b.create<arith::AddIOp>(val0, oneVal); |
1668 | val0 = clampIntHelper(loc, arg: val0, min: zeroI32, max, rewriter&: b); |
1669 | val1 = clampIntHelper(loc, arg: val1, min: zeroI32, max, rewriter&: b); |
1670 | val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0); |
1671 | val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1); |
1672 | }; |
1673 | |
1674 | // Linalg equivalent to the section below: |
1675 | // int16_t iy0 = apply_max(iy, 0); |
1676 | // int16_t iy1 = apply_min(iy + 1, IH - 1); |
1677 | // int16_t ix0 = apply_max(ix, 0); |
1678 | // int16_t ix1 = apply_min(ix + 1, IW - 1); |
1679 | Value x0, x1, y0, y1; |
1680 | getClampedIdxs(y0, y1, imageH, iy, hMax, b); |
1681 | getClampedIdxs(x0, x1, imageW, ix, wMax, b); |
1682 | |
1683 | Value y0x0 = b.create<tensor::ExtractOp>( |
1684 | input, ValueRange{batch, y0, x0, channel}); |
1685 | Value y0x1 = b.create<tensor::ExtractOp>( |
1686 | input, ValueRange{batch, y0, x1, channel}); |
1687 | Value y1x0 = b.create<tensor::ExtractOp>( |
1688 | input, ValueRange{batch, y1, x0, channel}); |
1689 | Value y1x1 = b.create<tensor::ExtractOp>( |
1690 | input, ValueRange{batch, y1, x1, channel}); |
1691 | |
1692 | if (floatingPointMode) { |
1693 | auto oneVal = |
1694 | b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f)); |
1695 | auto interpolate = [&](Value val0, Value val1, Value delta, |
1696 | int inputSize, |
1697 | ImplicitLocOpBuilder &b) -> Value { |
1698 | if (inputSize == 1) |
1699 | return val0; |
1700 | Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta); |
1701 | Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta); |
1702 | Value mul1 = b.create<arith::MulFOp>(val1, delta); |
1703 | return b.create<arith::AddFOp>(mul0, mul1); |
1704 | }; |
1705 | |
1706 | // Linalg equivalent to the section below: |
1707 | // topAcc = v00 * (unit_x - dx); |
1708 | // topAcc += v01 * dx; |
1709 | Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b); |
1710 | |
1711 | // Linalg equivalent to the section below: |
1712 | // bottomAcc = v10 * (unit_x - dx); |
1713 | // bottomAcc += v11 * dx; |
1714 | Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b); |
1715 | |
1716 | // Linalg equivalent to the section below: |
1717 | // result = topAcc * (unit_y - dy) + bottomAcc * dy |
1718 | Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); |
1719 | b.create<linalg::YieldOp>(result); |
1720 | } else { |
1721 | // Perform in quantized space. |
1722 | y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0); |
1723 | y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1); |
1724 | y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0); |
1725 | y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1); |
1726 | |
1727 | const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); |
1728 | if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { |
1729 | dx = b.create<arith::ExtSIOp>(resultETy, dx); |
1730 | dy = b.create<arith::ExtSIOp>(resultETy, dy); |
1731 | } |
1732 | |
1733 | Value yScaleNExt = yScaleN; |
1734 | Value xScaleNExt = xScaleN; |
1735 | |
1736 | const int64_t scaleBitwidth = |
1737 | xScaleN.getType().getIntOrFloatBitWidth(); |
1738 | if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { |
1739 | yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN); |
1740 | xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN); |
1741 | } |
1742 | |
1743 | auto interpolate = [](Value val0, Value val1, Value weight1, |
1744 | Value scale, int inputSize, |
1745 | ImplicitLocOpBuilder &b) -> Value { |
1746 | if (inputSize == 1) |
1747 | return b.create<arith::MulIOp>(val0, scale); |
1748 | Value weight0 = b.create<arith::SubIOp>(scale, weight1); |
1749 | Value mul0 = b.create<arith::MulIOp>(val0, weight0); |
1750 | Value mul1 = b.create<arith::MulIOp>(val1, weight1); |
1751 | return b.create<arith::AddIOp>(mul0, mul1); |
1752 | }; |
1753 | |
1754 | Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); |
1755 | Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); |
1756 | Value result = |
1757 | interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); |
1758 | b.create<linalg::YieldOp>(result); |
1759 | } |
1760 | } |
1761 | } |
1762 | |
1763 | rewriter.replaceOp(op, resize); |
1764 | return success(); |
1765 | } |
1766 | }; |
1767 | |
1768 | // At the codegen level any identity operations should be removed. Any cases |
1769 | // where identity is load-bearing (e.g. cross device computation) should be |
1770 | // handled before lowering to codegen. |
1771 | template <typename SrcOp> |
1772 | class IdentityNConverter : public OpRewritePattern<SrcOp> { |
1773 | public: |
1774 | using OpRewritePattern<SrcOp>::OpRewritePattern; |
1775 | |
1776 | LogicalResult matchAndRewrite(SrcOp op, |
1777 | PatternRewriter &rewriter) const final { |
1778 | rewriter.replaceOp(op, op.getOperation()->getOperands()); |
1779 | return success(); |
1780 | } |
1781 | }; |
1782 | |
1783 | template <typename SrcOp> |
1784 | class ReduceConverter : public OpRewritePattern<SrcOp> { |
1785 | public: |
1786 | using OpRewritePattern<SrcOp>::OpRewritePattern; |
1787 | |
1788 | LogicalResult matchAndRewrite(SrcOp reduceOp, |
1789 | PatternRewriter &rewriter) const final { |
1790 | return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter); |
1791 | } |
1792 | }; |
1793 | |
1794 | class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> { |
1795 | public: |
1796 | using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern; |
1797 | |
1798 | LogicalResult matchAndRewrite(tosa::ReverseOp op, |
1799 | PatternRewriter &rewriter) const final { |
1800 | auto loc = op.getLoc(); |
1801 | Value input = op.getInput(); |
1802 | auto inputTy = cast<ShapedType>(input.getType()); |
1803 | auto resultTy = cast<ShapedType>(op.getType()); |
1804 | auto axis = op.getAxis(); |
1805 | |
1806 | SmallVector<Value> dynDims; |
1807 | for (int i = 0; i < inputTy.getRank(); i++) { |
1808 | if (inputTy.isDynamicDim(i)) { |
1809 | dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); |
1810 | } |
1811 | } |
1812 | |
1813 | Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis); |
1814 | |
1815 | // First fill the output buffer with the init value. |
1816 | auto emptyTensor = rewriter |
1817 | .create<tensor::EmptyOp>(loc, inputTy.getShape(), |
1818 | inputTy.getElementType(), |
1819 | ArrayRef<Value>({dynDims})) |
1820 | .getResult(); |
1821 | SmallVector<AffineMap, 2> affineMaps = { |
1822 | rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())}; |
1823 | |
1824 | rewriter.replaceOpWithNewOp<linalg::GenericOp>( |
1825 | op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps, |
1826 | getNParallelLoopsAttrs(resultTy.getRank()), |
1827 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
1828 | llvm::SmallVector<Value> indices; |
1829 | for (unsigned int i = 0; i < inputTy.getRank(); i++) { |
1830 | Value index = |
1831 | rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult(); |
1832 | if (i == axis) { |
1833 | auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1); |
1834 | auto sizeMinusOne = |
1835 | rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one); |
1836 | index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne, |
1837 | index); |
1838 | } |
1839 | |
1840 | indices.push_back(index); |
1841 | } |
1842 | |
1843 | auto extract = nestedBuilder.create<tensor::ExtractOp>( |
1844 | nestedLoc, input, indices); |
1845 | nestedBuilder.create<linalg::YieldOp>(op.getLoc(), |
1846 | extract.getResult()); |
1847 | }); |
1848 | return success(); |
1849 | } |
1850 | }; |
1851 | |
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcasted to the appropriate
// multiple.
1856 | struct TileConverter : public OpConversionPattern<tosa::TileOp> { |
1857 | using OpConversionPattern<tosa::TileOp>::OpConversionPattern; |
1858 | |
1859 | LogicalResult |
1860 | matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor, |
1861 | ConversionPatternRewriter &rewriter) const override { |
1862 | auto loc = op.getLoc(); |
1863 | auto input = op.getInput1(); |
1864 | auto inputTy = cast<ShapedType>(input.getType()); |
1865 | auto inputShape = inputTy.getShape(); |
1866 | auto resultTy = cast<ShapedType>(op.getType()); |
1867 | auto elementTy = inputTy.getElementType(); |
1868 | int64_t rank = inputTy.getRank(); |
1869 | |
1870 | ArrayRef<int64_t> multiples = op.getMultiples(); |
1871 | |
1872 | // Broadcast the newly added dimensions to their appropriate multiple. |
1873 | SmallVector<int64_t, 2> genericShape; |
1874 | for (int i = 0; i < rank; i++) { |
1875 | int64_t dim = multiples[i]; |
1876 | genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim); |
1877 | genericShape.push_back(Elt: inputShape[i]); |
1878 | } |
1879 | |
1880 | SmallVector<Value> dynDims; |
1881 | for (int i = 0; i < inputTy.getRank(); i++) { |
1882 | if (inputTy.isDynamicDim(i) || multiples[i] == -1) { |
1883 | dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); |
1884 | } |
1885 | } |
1886 | |
1887 | auto emptyTensor = rewriter.create<tensor::EmptyOp>( |
1888 | op.getLoc(), genericShape, elementTy, dynDims); |
1889 | |
1890 | // We needs to map the input shape to the non-broadcasted dimensions. |
1891 | SmallVector<AffineExpr, 4> dimExprs; |
1892 | dimExprs.reserve(N: rank); |
1893 | for (unsigned i = 0; i < rank; ++i) |
1894 | dimExprs.push_back(Elt: rewriter.getAffineDimExpr(position: i * 2 + 1)); |
1895 | |
1896 | auto readAffineMap = |
1897 | AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, results: dimExprs, |
1898 | context: rewriter.getContext()); |
1899 | |
1900 | SmallVector<AffineMap, 2> affineMaps = { |
1901 | readAffineMap, rewriter.getMultiDimIdentityMap(rank: genericShape.size())}; |
1902 | |
1903 | auto genericOp = rewriter.create<linalg::GenericOp>( |
1904 | loc, RankedTensorType::get(genericShape, elementTy), input, |
1905 | ValueRange{emptyTensor}, affineMaps, |
1906 | getNParallelLoopsAttrs(genericShape.size()), |
1907 | [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
1908 | nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin()); |
1909 | }); |
1910 | |
1911 | rewriter.replaceOpWithNewOp<tosa::ReshapeOp>( |
1912 | op, resultTy, genericOp.getResult(0), |
1913 | rewriter.getDenseI64ArrayAttr(resultTy.getShape())); |
1914 | return success(); |
1915 | } |
1916 | }; |
1917 | |
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
1931 | class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> { |
1932 | public: |
1933 | using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern; |
1934 | |
1935 | LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp, |
1936 | PatternRewriter &rewriter) const final { |
1937 | auto loc = argmaxOp.getLoc(); |
1938 | Value input = argmaxOp.getInput(); |
1939 | auto inputTy = cast<ShapedType>(input.getType()); |
1940 | auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType()); |
1941 | auto inElementTy = inputTy.getElementType(); |
1942 | auto outElementTy = resultTy.getElementType(); |
1943 | int axis = argmaxOp.getAxis(); |
1944 | auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy); |
1945 | |
1946 | if (!isa<IntegerType>(outElementTy)) |
1947 | return rewriter.notifyMatchFailure( |
1948 | argmaxOp, |
1949 | "tosa.arg_max to linalg.* requires integer-like result type" ); |
1950 | |
1951 | SmallVector<Value> dynDims; |
1952 | for (int i = 0; i < inputTy.getRank(); i++) { |
1953 | if (inputTy.isDynamicDim(i) && i != axis) { |
1954 | dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); |
1955 | } |
1956 | } |
1957 | |
1958 | // First fill the output buffer for the index. |
1959 | auto emptyTensorIdx = rewriter |
1960 | .create<tensor::EmptyOp>(loc, resultTy.getShape(), |
1961 | outElementTy, dynDims) |
1962 | .getResult(); |
1963 | auto fillValueIdx = rewriter.create<arith::ConstantOp>( |
1964 | loc, rewriter.getIntegerAttr(outElementTy, 0)); |
1965 | auto filledTensorIdx = |
1966 | rewriter |
1967 | .create<linalg::FillOp>(loc, ValueRange{fillValueIdx}, |
1968 | ValueRange{emptyTensorIdx}) |
1969 | .result(); |
1970 | |
1971 | // Second fill the output buffer for the running max. |
1972 | auto emptyTensorMax = rewriter |
1973 | .create<tensor::EmptyOp>(loc, resultTy.getShape(), |
1974 | inElementTy, dynDims) |
1975 | .getResult(); |
1976 | auto fillValueMaxAttr = |
1977 | createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); |
1978 | |
1979 | if (!fillValueMaxAttr) |
1980 | return rewriter.notifyMatchFailure( |
1981 | argmaxOp, "unsupported tosa.argmax element type" ); |
1982 | |
1983 | auto fillValueMax = |
1984 | rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr); |
1985 | auto filledTensorMax = |
1986 | rewriter |
1987 | .create<linalg::FillOp>(loc, ValueRange{fillValueMax}, |
1988 | ValueRange{emptyTensorMax}) |
1989 | .result(); |
1990 | |
1991 | // We need to reduce along the arg-max axis, with parallel operations along |
1992 | // the rest. |
1993 | SmallVector<utils::IteratorType, 4> iteratorTypes; |
1994 | iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel); |
1995 | iteratorTypes[axis] = utils::IteratorType::reduction; |
1996 | |
1997 | SmallVector<AffineExpr, 2> srcExprs; |
1998 | SmallVector<AffineExpr, 2> dstExprs; |
1999 | for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) { |
2000 | srcExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context: rewriter.getContext())); |
2001 | if (axis != i) |
2002 | dstExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context: rewriter.getContext())); |
2003 | } |
2004 | |
2005 | bool didEncounterError = false; |
2006 | auto maps = AffineMap::inferFromExprList(exprsList: {srcExprs, dstExprs, dstExprs}, |
2007 | context: rewriter.getContext()); |
2008 | auto linalgOp = rewriter.create<linalg::GenericOp>( |
2009 | loc, ArrayRef<Type>({resultTy, resultMaxTy}), input, |
2010 | ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes, |
2011 | [&](OpBuilder &nestedBuilder, Location nestedLoc, |
2012 | ValueRange blockArgs) { |
2013 | auto newValue = blockArgs[0]; |
2014 | auto oldIndex = blockArgs[1]; |
2015 | auto oldValue = blockArgs[2]; |
2016 | |
2017 | Value newIndex = rewriter.create<arith::IndexCastOp>( |
2018 | nestedLoc, oldIndex.getType(), |
2019 | rewriter.create<linalg::IndexOp>(loc, axis)); |
2020 | |
2021 | Value predicate; |
2022 | if (isa<FloatType>(inElementTy)) { |
2023 | predicate = rewriter.create<arith::CmpFOp>( |
2024 | nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); |
2025 | } else if (isa<IntegerType>(inElementTy)) { |
2026 | predicate = rewriter.create<arith::CmpIOp>( |
2027 | nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); |
2028 | } else { |
2029 | didEncounterError = true; |
2030 | return; |
2031 | } |
2032 | |
2033 | auto resultMax = rewriter.create<arith::SelectOp>( |
2034 | nestedLoc, predicate, newValue, oldValue); |
2035 | auto resultIndex = rewriter.create<arith::SelectOp>( |
2036 | nestedLoc, predicate, newIndex, oldIndex); |
2037 | nestedBuilder.create<linalg::YieldOp>( |
2038 | nestedLoc, ValueRange({resultIndex, resultMax})); |
2039 | }); |
2040 | |
2041 | if (didEncounterError) |
2042 | return rewriter.notifyMatchFailure( |
2043 | argmaxOp, "unsupported tosa.argmax element type" ); |
2044 | |
2045 | rewriter.replaceOp(argmaxOp, linalgOp.getResult(0)); |
2046 | return success(); |
2047 | } |
2048 | }; |
2049 | |
2050 | class GatherConverter : public OpConversionPattern<tosa::GatherOp> { |
2051 | public: |
2052 | using OpConversionPattern<tosa::GatherOp>::OpConversionPattern; |
2053 | LogicalResult |
2054 | matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor, |
2055 | ConversionPatternRewriter &rewriter) const final { |
2056 | auto input = adaptor.getOperands()[0]; |
2057 | auto indices = adaptor.getOperands()[1]; |
2058 | |
2059 | auto valuesTy = |
2060 | dyn_cast_or_null<RankedTensorType>(op.getValues().getType()); |
2061 | auto resultTy = cast<ShapedType>(op.getType()); |
2062 | |
2063 | if (!valuesTy) |
2064 | return rewriter.notifyMatchFailure(op, "unranked tensors not supported" ); |
2065 | |
2066 | auto dynamicDims = inferDynamicDimsForGather( |
2067 | builder&: rewriter, loc: op.getLoc(), values: adaptor.getValues(), indices: adaptor.getIndices()); |
2068 | |
2069 | auto resultElementTy = resultTy.getElementType(); |
2070 | |
2071 | auto loc = op.getLoc(); |
2072 | auto emptyTensor = |
2073 | rewriter |
2074 | .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, |
2075 | dynamicDims) |
2076 | .getResult(); |
2077 | |
2078 | SmallVector<AffineMap, 2> affineMaps = { |
2079 | AffineMap::get( |
2080 | /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, |
2081 | {rewriter.getAffineDimExpr(position: 0), rewriter.getAffineDimExpr(position: 1)}, |
2082 | rewriter.getContext()), |
2083 | rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())}; |
2084 | |
2085 | auto genericOp = rewriter.create<linalg::GenericOp>( |
2086 | loc, ArrayRef<Type>({resultTy}), ValueRange{indices}, |
2087 | ValueRange{emptyTensor}, affineMaps, |
2088 | getNParallelLoopsAttrs(resultTy.getRank()), |
2089 | [&](OpBuilder &b, Location loc, ValueRange args) { |
2090 | auto indexValue = args[0]; |
2091 | auto index0 = rewriter.create<linalg::IndexOp>(loc, 0); |
2092 | Value index1 = rewriter.create<arith::IndexCastOp>( |
2093 | loc, rewriter.getIndexType(), indexValue); |
2094 | auto index2 = rewriter.create<linalg::IndexOp>(loc, 2); |
2095 | Value extract = rewriter.create<tensor::ExtractOp>( |
2096 | loc, input, ValueRange{index0, index1, index2}); |
2097 | rewriter.create<linalg::YieldOp>(loc, extract); |
2098 | }); |
2099 | rewriter.replaceOp(op, genericOp.getResult(0)); |
2100 | return success(); |
2101 | } |
2102 | |
2103 | static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder, |
2104 | Location loc, |
2105 | Value values, |
2106 | Value indices) { |
2107 | llvm::SmallVector<Value> results; |
2108 | |
2109 | auto addDynamicDimension = [&](Value source, int64_t dim) { |
2110 | auto sz = tensor::getMixedSize(builder, loc, value: source, dim); |
2111 | if (auto dimValue = llvm::dyn_cast_if_present<Value>(Val&: sz)) |
2112 | results.push_back(Elt: dimValue); |
2113 | }; |
2114 | |
2115 | addDynamicDimension(values, 0); |
2116 | addDynamicDimension(indices, 1); |
2117 | addDynamicDimension(values, 2); |
2118 | return results; |
2119 | } |
2120 | }; |
2121 | |
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
2125 | class TableConverter : public OpRewritePattern<tosa::TableOp> { |
2126 | public: |
2127 | using OpRewritePattern<tosa::TableOp>::OpRewritePattern; |
2128 | |
2129 | LogicalResult matchAndRewrite(tosa::TableOp op, |
2130 | PatternRewriter &rewriter) const final { |
2131 | auto loc = op.getLoc(); |
2132 | Value input = op.getInput(); |
2133 | Value table = op.getTable(); |
2134 | auto inputTy = cast<ShapedType>(input.getType()); |
2135 | auto tableTy = cast<ShapedType>(table.getType()); |
2136 | auto resultTy = cast<ShapedType>(op.getType()); |
2137 | |
2138 | auto inputElementTy = inputTy.getElementType(); |
2139 | auto tableElementTy = tableTy.getElementType(); |
2140 | auto resultElementTy = resultTy.getElementType(); |
2141 | |
2142 | SmallVector<Value> dynDims; |
2143 | for (int i = 0; i < resultTy.getRank(); ++i) { |
2144 | if (inputTy.isDynamicDim(i)) { |
2145 | dynDims.push_back( |
2146 | rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i)); |
2147 | } |
2148 | } |
2149 | |
2150 | auto emptyTensor = rewriter |
2151 | .create<tensor::EmptyOp>(loc, resultTy.getShape(), |
2152 | resultElementTy, dynDims) |
2153 | .getResult(); |
2154 | |
2155 | SmallVector<AffineMap, 2> affineMaps = { |
2156 | rewriter.getMultiDimIdentityMap(rank: resultTy.getRank()), |
2157 | rewriter.getMultiDimIdentityMap(rank: resultTy.getRank())}; |
2158 | |
2159 | auto genericOp = rewriter.create<linalg::GenericOp>( |
2160 | loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, |
2161 | getNParallelLoopsAttrs(resultTy.getRank())); |
2162 | rewriter.replaceOp(op, genericOp.getResult(0)); |
2163 | |
2164 | { |
2165 | OpBuilder::InsertionGuard regionGuard(rewriter); |
2166 | Block *block = rewriter.createBlock( |
2167 | &genericOp.getRegion(), genericOp.getRegion().end(), |
2168 | TypeRange({inputElementTy, resultElementTy}), {loc, loc}); |
2169 | |
2170 | auto inputValue = block->getArgument(i: 0); |
2171 | rewriter.setInsertionPointToStart(block); |
2172 | if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && |
2173 | resultElementTy.isInteger(8)) { |
2174 | Value index = rewriter.create<arith::IndexCastOp>( |
2175 | loc, rewriter.getIndexType(), inputValue); |
2176 | Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128); |
2177 | index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(), |
2178 | index, offset); |
2179 | Value = |
2180 | rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index}); |
2181 | rewriter.create<linalg::YieldOp>(loc, extract); |
2182 | return success(); |
2183 | } |
2184 | |
2185 | if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && |
2186 | resultElementTy.isInteger(32)) { |
2187 | Value extend = rewriter.create<arith::ExtSIOp>( |
2188 | loc, rewriter.getI32Type(), inputValue); |
2189 | |
2190 | auto offset = rewriter.create<arith::ConstantOp>( |
2191 | loc, rewriter.getI32IntegerAttr(32768)); |
2192 | auto seven = rewriter.create<arith::ConstantOp>( |
2193 | loc, rewriter.getI32IntegerAttr(7)); |
2194 | auto one = rewriter.create<arith::ConstantOp>( |
2195 | loc, rewriter.getI32IntegerAttr(1)); |
2196 | auto b1111111 = rewriter.create<arith::ConstantOp>( |
2197 | loc, rewriter.getI32IntegerAttr(127)); |
2198 | |
2199 | // Compute the index and fractional part from the input value: |
2200 | // value = value + 32768 |
2201 | // index = value >> 7; |
2202 | // fraction = 0x01111111 & value |
2203 | auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset); |
2204 | Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven); |
2205 | Value fraction = |
2206 | rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111); |
2207 | |
2208 | // Extract the base and next values from the table. |
2209 | // base = (int32_t) table[index]; |
2210 | // next = (int32_t) table[index + 1]; |
2211 | Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one); |
2212 | |
2213 | index = rewriter.create<arith::IndexCastOp>( |
2214 | loc, rewriter.getIndexType(), index); |
2215 | indexPlusOne = rewriter.create<arith::IndexCastOp>( |
2216 | loc, rewriter.getIndexType(), indexPlusOne); |
2217 | |
2218 | Value base = |
2219 | rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index}); |
2220 | Value next = rewriter.create<tensor::ExtractOp>( |
2221 | loc, table, ValueRange{indexPlusOne}); |
2222 | |
2223 | base = |
2224 | rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base); |
2225 | next = |
2226 | rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next); |
2227 | |
2228 | // Use the fractional part to interpolate between the input values: |
2229 | // result = (base << 7) + (next - base) * fraction |
2230 | Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven); |
2231 | Value diff = rewriter.create<arith::SubIOp>(loc, next, base); |
2232 | Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction); |
2233 | Value result = |
2234 | rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled); |
2235 | |
2236 | rewriter.create<linalg::YieldOp>(loc, result); |
2237 | |
2238 | return success(); |
2239 | } |
2240 | } |
2241 | |
2242 | return rewriter.notifyMatchFailure( |
2243 | op, "unable to create body for tosa.table op" ); |
2244 | } |
2245 | }; |
2246 | |
// Lowers tosa.rfft2d to a linalg.generic that evaluates the real-input 2D DFT
// directly: the real/imag outputs accumulate valReal * cos(angle) and
// -valReal * sin(angle) over two reduction dimensions (iy, ix).
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  // True iff `type` is a ranked tensor; the lowering requires known ranks.
  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(Val: type); }

  // Computes (ofr / 2) + 1 as an index OpFoldResult; used to size the output
  // W dimension, which holds only the non-redundant half of the spectrum.
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(location: loc, args: 1);
    auto two = builder.create<arith::ConstantIndexOp>(location: loc, args: 2);

    auto value = getValueOrCreateConstantIndexOp(b&: builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  // Derives the [N, H, W/2 + 1] output tensor type from `input`, appending
  // any dynamic extents to `dynamicSizes`.
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder, loc, value: input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, ofr: dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(ofrs: dims, dynamicVec&: dynamicSizes, staticVec&: staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  // Materializes a zero-filled tensor of `type` (via tensor.empty +
  // linalg.fill) to serve as the accumulator init operand.
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type: type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  // Converts an index-typed value to float `type`, going through an unsigned
  // integer (i64 for floats wider than 32 bits, i32 otherwise).
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  // Emits linalg.index for loop dimension `index` and casts it to `type`.
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, value: indexVal);
  }

  // Builds a list of affine dim expressions, one per integer argument.
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(position: args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors" );
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType =
        cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());

    // Compute the output type and set of dynamic sizes
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc: loc, type: outputType, dynamicSizes),
        createZeroTensor(rewriter, loc: loc, type: outputType, dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        exprsList: llvm::ArrayRef{affineDimsExpr(builder&: rewriter, args: 0, args: 3, args: 4),
                       affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2),
                       affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2)},
        context: rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(builder&: rewriter, loc: loc, type: elementType, value: dimH);
    auto constW = castIndexToFloat(builder&: rewriter, loc: loc, type: elementType, value: dimW);

    // Region body: accumulate one (iy, ix) contribution into the running
    // real/imag sums carried by the output operands.
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculating angle without integer parts of components as sin/cos are
      // periodic: angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W )
      // / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, type: elementType, value: iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, type: elementType, value: ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2415 | |
// Lowers tosa.fft2d (complex input as separate real/imag tensors) to a
// linalg.generic computing the 2D DFT directly; the inverse transform is
// handled by negating the angle.
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors" );
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    // Real and imaginary operands must share an element type.
    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W]
    auto dims = tensor::getMixedSizes(builder&: rewriter, loc, value: input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, type: outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, type: outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors
    auto indexingMaps = AffineMap::inferFromExprList(
        exprsList: ArrayRef{RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 3, args: 4),
                 RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 3, args: 4),
                 RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2),
                 RFFT2dConverter::affineDimsExpr(builder&: rewriter, args: 0, args: 1, args: 2)},
        context: rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(builder&: rewriter, loc, type: real_el_ty, value: dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(builder&: rewriter, loc, type: real_el_ty, value: dimW);

    // Region body: accumulate one (iy, ix) contribution into the running
    // real/imag sums carried by the output operands.
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
      // ox) % W ) / W);
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, type: real_el_ty, value: iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, type: real_el_ty, value: ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);

      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // Inverse transform: negate the angle (conjugate twiddle factors).
      if (inverse.getValue()) {
        angle = builder.create<arith::MulFOp>(
            loc, angle,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = val_real * cos(a) + val_imag * sin(a);
      // imagComponent = -val_real * sin(a) + val_imag * cos(a);
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
2552 | |
2553 | } // namespace |
2554 | |
// Registers all Tosa-to-Linalg rewrite patterns defined in this file into
// `patterns`.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    RewritePatternSet *patterns) {

  // We have multiple resize converters to handle degenerate cases.
  // Higher-benefit patterns are preferred by the driver, so the
  // degenerate-case handlers are tried before the generic lowering.
  patterns->add<GenericResizeConverter>(arg: patterns->getContext(),
                                        /*benefit=*/args: 100);
  patterns->add<ResizeUnaryConverter>(arg: patterns->getContext(),
                                      /*benefit=*/args: 200);
  patterns->add<MaterializeResizeBroadcast>(arg: patterns->getContext(),
                                            /*benefit=*/args: 300);

  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
2621 | |