//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg named ops.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Math/IR/Math.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/Dialect/Tosa/IR/TosaOps.h"
19#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
20#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24#include <type_traits>
25
26using namespace mlir;
27using namespace mlir::tosa;
28
29static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
30 TypedAttr padAttr, OpBuilder &rewriter) {
31 // Input should be padded only if necessary.
32 if (llvm::all_of(Range&: pad, P: [](int64_t p) { return p == 0; }))
33 return input;
34
35 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
36 Type inputETy = inputTy.getElementType();
37 auto inputShape = inputTy.getShape();
38
39 assert((inputShape.size() * 2) == pad.size());
40
41 SmallVector<int64_t, 4> paddedShape;
42 SmallVector<OpFoldResult, 8> lowIndices;
43 SmallVector<OpFoldResult, 8> highIndices;
44 for (size_t i : llvm::seq(Size: inputShape.size())) {
45 auto lowPad = pad[i * 2];
46 auto highPad = pad[i * 2 + 1];
47 if (ShapedType::isDynamic(dValue: inputShape[i]))
48 paddedShape.push_back(Elt: inputShape[i]);
49 else
50 paddedShape.push_back(Elt: inputShape[i] + highPad + lowPad);
51 lowIndices.push_back(Elt: rewriter.getIndexAttr(value: lowPad));
52 highIndices.push_back(Elt: rewriter.getIndexAttr(value: highPad));
53 }
54
55 Value padValue = rewriter.create<arith::ConstantOp>(location: loc, args&: padAttr);
56
57 return rewriter.create<tensor::PadOp>(
58 location: loc, args: RankedTensorType::get(shape: paddedShape, elementType: inputETy), args&: input, args&: lowIndices,
59 args&: highIndices, args&: padValue);
60}
61
62static mlir::Value
63linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
64 Value conv, Value result,
65 ArrayRef<AffineMap> indexingMaps) {
66 ShapedType resultTy = cast<ShapedType>(Val: conv.getType());
67 return rewriter
68 .create<linalg::GenericOp>(
69 location: loc, args&: resultTy, args: ValueRange({bias, conv}), args&: result, args&: indexingMaps,
70 args: getNParallelLoopsAttrs(nParallelLoops: resultTy.getRank()),
71 args: [](OpBuilder &builder, Location loc, ValueRange args) {
72 Value biasVal = args[0];
73 Type resType = args[1].getType();
74 if (resType != biasVal.getType()) {
75 biasVal = builder.create<arith::ExtSIOp>(location: loc, args&: resType, args&: biasVal);
76 }
77 Value added = builder.create<arith::AddIOp>(location: loc, args&: biasVal, args: args[1]);
78 builder.create<linalg::YieldOp>(location: loc, args&: added);
79 })
80 .getResult(i: 0);
81}
82
83// Construct the affine map that a linalg generic would use to broadcast the
84// source tensor into the shape of the result tensor.
85static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
86 Value result) {
87 ShapedType resultTy = cast<ShapedType>(Val: result.getType());
88 ShapedType sourceTy = cast<ShapedType>(Val: source.getType());
89 const int64_t resultRank = resultTy.getRank();
90 const int64_t sourceRank = sourceTy.getRank();
91
92 // The source tensor is broadcast to all the outer dimensions of the
93 // result tensor.
94 SmallVector<AffineExpr> sourceDims;
95 // In the case of a rank one source tensor with a single element TOSA
96 // specifies that the value be broadcast meaning we need an edge case for a
97 // constant map.
98 assert(sourceTy.hasStaticShape() &&
99 "Dynamic broadcasting shapes not supported!");
100 if (sourceRank == 1 && sourceTy.getDimSize(idx: 0) == 1) {
101 sourceDims.push_back(Elt: rewriter.getAffineConstantExpr(constant: 0));
102 } else {
103 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: sourceRank)) {
104 auto expr = rewriter.getAffineDimExpr(position: dim + resultRank - sourceRank);
105 sourceDims.push_back(Elt: expr);
106 }
107 }
108
109 return AffineMap::get(/*dimCount=*/resultRank,
110 /*symbolCount=*/0, results: sourceDims, context: rewriter.getContext());
111}
112
113// Broadcast the source value to all the outer dimensions of the result value.
114// If required, the element type is expanded using an arith.extsi or arith.extf
115// operation as appropriate.
116static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
117 Location loc, Value source,
118 Value result) {
119 ShapedType resultTy = cast<ShapedType>(Val: result.getType());
120 const int64_t resultRank = resultTy.getRank();
121 // Creating maps for the input and output of the broacast-like generic op.
122 SmallVector<AffineMap, 2> indexingMaps;
123 indexingMaps.push_back(Elt: getBroadcastingMap(rewriter, source, result));
124 indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank: resultRank));
125
126 // Build the broadcast-like operation as a linalg.generic.
127 return rewriter
128 .create<linalg::GenericOp>(
129 location: loc, args&: resultTy, args: ValueRange({source}), args&: result, args&: indexingMaps,
130 args: getNParallelLoopsAttrs(nParallelLoops: resultTy.getRank()),
131 args: [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
132 Value biasVal = args[0];
133 Type resType = args[1].getType();
134 if (resType != biasVal.getType()) {
135 biasVal =
136 resultTy.getElementType().isFloat()
137 ? builder.create<arith::ExtFOp>(location: loc, args&: resType, args&: biasVal)
138 .getResult()
139 : builder.create<arith::ExtSIOp>(location: loc, args&: resType, args&: biasVal)
140 .getResult();
141 }
142 builder.create<linalg::YieldOp>(location: loc, args&: biasVal);
143 })
144 .getResult(i: 0);
145}
146
147static mlir::Value reifyConstantDim(int64_t attr,
148 ImplicitLocOpBuilder &builder) {
149 return builder.create<arith::ConstantIndexOp>(args&: attr);
150}
151
152// Calculating the output width/height using the formula:
153// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
154// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
155
156static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
157 int64_t padBeforeAttr,
158 int64_t padAfterAttr, Value kernelDim,
159 int64_t strideAttr,
160 int64_t dilationAttr,
161 OpBuilder &rewriter) {
162 ImplicitLocOpBuilder builder(loc, rewriter);
163 auto one = rewriter.create<arith::ConstantOp>(
164 location: loc, args: IntegerAttr::get(type: inputDim.getType(), value: 1));
165 Value padBefore = reifyConstantDim(attr: padBeforeAttr, builder);
166 Value paddedBefore = builder.create<arith::AddIOp>(args&: inputDim, args&: padBefore);
167 Value padAfter = reifyConstantDim(attr: padAfterAttr, builder);
168 Value paddedAfter = builder.create<arith::AddIOp>(args&: paddedBefore, args&: padAfter);
169
170 Value subOne = builder.create<arith::SubIOp>(args&: kernelDim, args&: one);
171 Value dilation = reifyConstantDim(attr: dilationAttr, builder);
172 Value dilated = builder.create<arith::MulIOp>(args&: dilation, args&: subOne);
173 Value addOne = builder.create<arith::AddIOp>(args&: dilated, args&: one);
174
175 Value subtract = builder.create<arith::SubIOp>(args&: paddedAfter, args&: addOne);
176 Value stride = reifyConstantDim(attr: strideAttr, builder);
177 Value divide = builder.create<arith::DivUIOp>(args&: subtract, args&: stride);
178 return builder.create<arith::AddIOp>(args&: divide, args&: one);
179}
180
181// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
182static SmallVector<Value> inferDynamicDimsForConv(
183 Location loc, Value input, Value weight, ShapedType resultTy,
184 ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
185 ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
186 ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
187 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
188 int64_t inputRank = inputTy.getRank();
189
190 SmallVector<Value> dynDims;
191 dynDims.resize(N: resultTy.getRank());
192
193 for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
194 int64_t inputDim = inputSizeDims[i];
195 int64_t kernelDim = kernelSizeDims[i];
196 if (resultTy.isDynamicDim(idx: inputDim)) {
197 auto padTop = padAttr[i * 2];
198 auto padBottom = padAttr[i * 2 + 1];
199 auto stride = strideAttr[i];
200 auto dilation = dilationAttr[i];
201 Value initDynDim = rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: inputDim);
202 Value kernelDynDim =
203 rewriter.create<tensor::DimOp>(location: loc, args&: weight, args&: kernelDim);
204 // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
205 dynDims[inputDim] =
206 getConvOrPoolOutputDim(loc, inputDim: initDynDim, padBeforeAttr: padTop, padAfterAttr: padBottom,
207 kernelDim: kernelDynDim, strideAttr: stride, dilationAttr: dilation, rewriter);
208 }
209 }
210
211 // Get the batch/channels dimensions.
212 for (int i = 0; i < inputRank; i++) {
213 if (resultTy.isDynamicDim(idx: i) && !dynDims[i])
214 dynDims[i] = rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: i);
215 }
216
217 SmallVector<Value> filteredDims = condenseValues(values: dynDims);
218 return filteredDims;
219}
220
221// Creates a map to collapse the last dimension of the Depthwise convolution op
222// due to a shape mismatch
223static void createDepthwiseConvCollapseMap(
224 int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
225 OpBuilder &rewriter) {
226 reassociationMap.resize(N: outputRank);
227 for (int i = 0; i < outputRank; i++) {
228 reassociationMap[i].push_back(Elt: rewriter.getAffineDimExpr(position: i));
229 }
230 reassociationMap[outputRank - 1].push_back(
231 Elt: rewriter.getAffineDimExpr(position: outputRank));
232}
233
234namespace {
235
236template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
237class ConvConverter : public OpConversionPattern<TosaConvOp> {
238public:
239 using OpConversionPattern<TosaConvOp>::OpConversionPattern;
240 LogicalResult
241 matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
242 ConversionPatternRewriter &rewriter) const final {
243 Location loc = op->getLoc();
244 Value input = op->getOperand(0);
245 Value weight = op->getOperand(1);
246 Value bias = op->getOperand(2);
247
248 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
249 ShapedType weightTy = cast<ShapedType>(Val: weight.getType());
250 ShapedType biasTy = cast<ShapedType>(Val: bias.getType());
251 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
252
253 Type inputETy = inputTy.getElementType();
254
255 DenseI64ArrayAttr padAttr = op.getPadAttr();
256 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
257 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
258
259 Type accETy = op.getAccType();
260 Type accTy = RankedTensorType::get(shape: resultTy.getShape(), elementType: accETy);
261
262 // Get and verify zero points.
263 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
264 if (failed(Result: maybeIZp))
265 return rewriter.notifyMatchFailure(
266 op, "input zero point cannot be statically determined");
267
268 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
269 if (failed(Result: maybeWZp))
270 return rewriter.notifyMatchFailure(
271 op, "weight zero point cannot be statically determined");
272
273 const int64_t inputZpVal = *maybeIZp;
274 const int64_t weightZpVal = *maybeWZp;
275
276 if (op.verifyInputZeroPoint(inputZpVal).failed())
277 return rewriter.notifyMatchFailure(
278 op, "input zero point must be zero for non-int8 integer types");
279
280 if (op.verifyWeightZeroPoint(weightZpVal).failed())
281 return rewriter.notifyMatchFailure(
282 op, "weight zero point must be zero for non-int8 integer types");
283
284 bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
285
286 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
287 return rewriter.notifyMatchFailure(
288 op, "tosa.conv ops require static shapes for weight and bias");
289
290 if (inputETy.isUnsignedInteger())
291 return rewriter.notifyMatchFailure(
292 op, "tosa.conv ops does not support unsigned integer input");
293
294 llvm::SmallVector<int64_t> inputSizeDims;
295 llvm::SmallVector<int64_t> kernelSizeDims;
296 for (int i = 1; i < resultTy.getRank() - 1; i++) {
297 inputSizeDims.push_back(Elt: i);
298 kernelSizeDims.push_back(Elt: i);
299 }
300
301 SmallVector<Value> filteredDims = inferDynamicDimsForConv(
302 loc, input, weight, resultTy, padAttr: padAttr.asArrayRef(),
303 strideAttr: strideTosaAttr.asArrayRef(), dilationAttr: dilationTosaAttr.asArrayRef(),
304 inputSizeDims, kernelSizeDims, rewriter);
305
306 auto weightShape = weightTy.getShape();
307
308 // Apply padding as necessary.
309 TypedAttr zeroAttr = rewriter.getZeroAttr(type: inputETy);
310 if (hasZp) {
311 int64_t intMin =
312 APInt::getSignedMinValue(numBits: inputETy.getIntOrFloatBitWidth())
313 .getSExtValue();
314 int64_t intMax =
315 APInt::getSignedMaxValue(numBits: inputETy.getIntOrFloatBitWidth())
316 .getSExtValue();
317
318 if (inputZpVal < intMin || inputZpVal > intMax)
319 return rewriter.notifyMatchFailure(
320 op, "tosa.conv op quantization has zp outside of input range");
321
322 zeroAttr = rewriter.getIntegerAttr(type: inputETy, value: inputZpVal);
323 }
324
325 llvm::SmallVector<int64_t> pad;
326 pad.resize(N: 2, NV: 0);
327 llvm::append_range(C&: pad, R: padAttr.asArrayRef());
328 pad.resize(N: pad.size() + 2, NV: 0);
329 input = applyPad(loc, input, pad, padAttr: zeroAttr, rewriter);
330
331 if (4 == inputTy.getRank()) {
332 // For 2D convolutions, we need to check if the target convolution op
333 // wants a HWCF kernel layout.
334 bool wantHwcf =
335 hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
336 : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
337 if (wantHwcf) {
338 // Transpose the kernel to match dimension ordering of the linalg
339 // convolution operation.
340 // TODO(suderman): See if this can be efficiently folded - check whether
341 // the input is used anywhere else, if not fold the constant.
342 SmallVector<int32_t> weightPerm;
343 for (int i = 1; i < resultTy.getRank(); i++)
344 weightPerm.push_back(Elt: i);
345 weightPerm.push_back(Elt: 0);
346
347 SmallVector<int64_t> newWeightShape;
348 for (auto dim : weightPerm)
349 newWeightShape.push_back(Elt: weightShape[dim]);
350 auto weightPermAttr = rewriter.getDenseI32ArrayAttr(values: weightPerm);
351 Type newWeightTy =
352 RankedTensorType::get(shape: newWeightShape, elementType: weightTy.getElementType());
353 weight = rewriter.create<tosa::TransposeOp>(location: loc, args&: newWeightTy, args&: weight,
354 args&: weightPermAttr);
355 }
356 }
357
358 // For Conv3D transpose the kernel to match dimension ordering of the linalg
359 // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
360 // map directly and then transpose later if desired.
361 if (5 == inputTy.getRank()) {
362 // TODO(suderman): See if this can be efficiently folded - check whether
363 // the input is used anywhere else, if not fold the constant.
364 SmallVector<int32_t> weightPerm;
365 for (int i = 1; i < resultTy.getRank(); i++)
366 weightPerm.push_back(Elt: i);
367 weightPerm.push_back(Elt: 0);
368
369 SmallVector<int64_t> newWeightShape;
370 for (auto dim : weightPerm)
371 newWeightShape.push_back(Elt: weightShape[dim]);
372 auto weightPermAttr = rewriter.getDenseI32ArrayAttr(values: weightPerm);
373 Type newWeightTy =
374 RankedTensorType::get(shape: newWeightShape, elementType: weightTy.getElementType());
375 weight = rewriter.create<tosa::TransposeOp>(location: loc, args&: newWeightTy, args&: weight,
376 args&: weightPermAttr);
377 }
378
379 // Extract the attributes for convolution.
380 ArrayRef<int64_t> stride = strideTosaAttr;
381 ArrayRef<int64_t> dilation = dilationTosaAttr;
382
383 // Create the convolution op.
384 auto strideAttr = rewriter.getI64TensorAttr(values: stride);
385 auto dilationAttr = rewriter.getI64TensorAttr(values: dilation);
386
387 Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
388 location: loc, args: resultTy.getShape(), args&: accETy, args&: filteredDims);
389
390 Value broadcastBias =
391 linalgBroadcastAndMaybeExt(rewriter, loc, source: bias, result: biasEmptyTensor);
392
393 if (hasZp) {
394 auto iZp = rewriter.getI32IntegerAttr(value: inputZpVal);
395 auto kZp = rewriter.getI32IntegerAttr(value: weightZpVal);
396
397 auto iZpVal = rewriter.create<arith::ConstantOp>(location: loc, args&: iZp);
398 auto kZpVal = rewriter.create<arith::ConstantOp>(location: loc, args&: kZp);
399
400 Value conv =
401 rewriter
402 .create<LinalgConvQOp>(
403 loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
404 ValueRange{broadcastBias}, strideAttr, dilationAttr)
405 ->getResult(0);
406
407 rewriter.replaceOp(op, conv);
408 return success();
409 }
410
411 Value conv = rewriter
412 .create<LinalgConvOp>(
413 loc, accTy, ValueRange{input, weight},
414 ValueRange{broadcastBias}, strideAttr, dilationAttr)
415 ->getResult(0);
416
417 // We may need to truncate back to the result type if the accumulator was
418 // wider than the result.
419 if (resultTy != accTy)
420 conv = rewriter.create<tosa::CastOp>(location: loc, args&: resultTy, args&: conv);
421
422 rewriter.replaceOp(op, conv);
423 return success();
424 }
425};
426
427class DepthwiseConvConverter
428 : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
429public:
430 using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
431 LogicalResult
432 matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
433 ConversionPatternRewriter &rewriter) const final {
434 Location loc = op->getLoc();
435 Value input = op->getOperand(idx: 0);
436 Value weight = op->getOperand(idx: 1);
437 Value bias = op->getOperand(idx: 2);
438
439 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
440 ShapedType weightTy = cast<ShapedType>(Val: weight.getType());
441 ShapedType biasTy = cast<ShapedType>(Val: bias.getType());
442 ShapedType resultTy = cast<ShapedType>(Val: op->getResult(idx: 0).getType());
443 int64_t resultRank = resultTy.getRank();
444
445 Type inputETy = inputTy.getElementType();
446 Type resultETy = resultTy.getElementType();
447
448 auto padAttr = cast<DenseI64ArrayAttr>(Val: op->getAttr(name: "pad"));
449 auto strideTosaAttr = cast<DenseI64ArrayAttr>(Val: op->getAttr(name: "stride"));
450 auto dilationTosaAttr = cast<DenseI64ArrayAttr>(Val: op->getAttr(name: "dilation"));
451
452 Type accETy = op.getAccType();
453
454 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
455 return rewriter.notifyMatchFailure(
456 arg&: op, msg: "tosa.depthwise_conv ops require static shapes");
457
458 // Compute output dynamic dims
459 SmallVector<Value> filteredDims = inferDynamicDimsForConv(
460 loc, input, weight, resultTy, padAttr: padAttr.asArrayRef(),
461 strideAttr: strideTosaAttr.asArrayRef(), dilationAttr: dilationTosaAttr.asArrayRef(),
462 /*inputSizeDims=*/{1, 2},
463 /*kernelSizeDims=*/{0, 1}, rewriter);
464
465 // Get and verify zero points.
466
467 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
468 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
469 if (failed(Result: maybeIZp))
470 return rewriter.notifyMatchFailure(
471 arg&: op, msg: "input zero point cannot be statically determined");
472 if (failed(Result: maybeWZp))
473 return rewriter.notifyMatchFailure(
474 arg&: op, msg: "weight zero point cannot be statically determined");
475
476 const int64_t inputZpVal = *maybeIZp;
477 const int64_t weightZpVal = *maybeWZp;
478
479 if (op.verifyInputZeroPoint(zp: inputZpVal).failed())
480 return rewriter.notifyMatchFailure(
481 arg&: op, msg: "input zero point must be zero for non-int8 integer types");
482
483 if (op.verifyWeightZeroPoint(zp: weightZpVal).failed())
484 return rewriter.notifyMatchFailure(
485 arg&: op, msg: "weight zero point must be zero for non-int8 integer types");
486
487 bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
488 auto weightShape = weightTy.getShape();
489 auto resultShape = resultTy.getShape();
490
491 // Apply padding as necessary.
492 TypedAttr zeroAttr = rewriter.getZeroAttr(type: inputETy);
493 if (!hasNullZps) {
494 int64_t intMin =
495 APInt::getSignedMinValue(numBits: inputETy.getIntOrFloatBitWidth())
496 .getSExtValue();
497 int64_t intMax =
498 APInt::getSignedMaxValue(numBits: inputETy.getIntOrFloatBitWidth())
499 .getSExtValue();
500
501 if (inputZpVal < intMin || inputZpVal > intMax)
502 return rewriter.notifyMatchFailure(
503 arg&: op, msg: "tosa.depthwise_conv op quantization has zp outside of input "
504 "range");
505
506 zeroAttr = rewriter.getIntegerAttr(type: inputETy, value: inputZpVal);
507 }
508
509 llvm::SmallVector<int64_t> pad;
510 pad.resize(N: 2, NV: 0);
511 llvm::append_range(C&: pad, R: padAttr.asArrayRef());
512 pad.resize(N: pad.size() + 2, NV: 0);
513
514 input = applyPad(loc, input, pad, padAttr: zeroAttr, rewriter);
515
516 // Extract the attributes for convolution.
517 ArrayRef<int64_t> stride = strideTosaAttr;
518 ArrayRef<int64_t> dilation = dilationTosaAttr;
519
520 // Create the convolution op.
521 auto strideAttr = rewriter.getI64TensorAttr(values: stride);
522 auto dilationAttr = rewriter.getI64TensorAttr(values: dilation);
523 ShapedType linalgConvTy =
524 RankedTensorType::get(shape: {resultShape[0], resultShape[1], resultShape[2],
525 weightShape[2], weightShape[3]},
526 elementType: accETy);
527
528 auto resultZeroAttr = rewriter.getZeroAttr(type: accETy);
529 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
530 location: loc, args: linalgConvTy.getShape(), args&: accETy, args&: filteredDims);
531 Value zero = rewriter.create<arith::ConstantOp>(location: loc, args&: resultZeroAttr);
532 Value zeroTensor = rewriter
533 .create<linalg::FillOp>(location: loc, args: ValueRange{zero},
534 args: ValueRange{emptyTensor})
535 .result();
536
537 Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
538 location: loc, args: resultTy.getShape(), args&: resultETy, args&: filteredDims);
539
540 // Broadcast the initial value to the output tensor before convolving.
541 SmallVector<AffineMap, 4> indexingMaps;
542 indexingMaps.push_back(Elt: getBroadcastingMap(rewriter, source: bias, result: biasEmptyTensor));
543 indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank: resultRank));
544 indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank: resultRank));
545
546 if (hasNullZps) {
547 Value conv = rewriter
548 .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
549 location: loc, args&: linalgConvTy, args: ValueRange{input, weight},
550 args: ValueRange{zeroTensor}, args&: strideAttr, args&: dilationAttr)
551 .getResult(i: 0);
552
553 // We may need to truncate back to the result type if the accumulator was
554 // wider than the result.
555 if (accETy != resultETy)
556 conv = rewriter.create<tosa::CastOp>(
557 location: loc,
558 args: RankedTensorType::get(shape: cast<ShapedType>(Val: conv.getType()).getShape(),
559 elementType: resultETy),
560 args&: conv);
561
562 SmallVector<ReassociationExprs, 4> reassociationMap;
563 createDepthwiseConvCollapseMap(outputRank: resultRank, reassociationMap, rewriter);
564 Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
565 location: loc, args&: resultTy, args&: conv, args&: reassociationMap);
566
567 Value result =
568 rewriter
569 .create<linalg::GenericOp>(
570 location: loc, args&: resultTy, args: ValueRange({bias, convReshape}),
571 args&: biasEmptyTensor, args&: indexingMaps,
572 args: getNParallelLoopsAttrs(nParallelLoops: resultRank),
573 args: [&](OpBuilder &nestedBuilder, Location nestedLoc,
574 ValueRange args) {
575 Value added;
576 if (llvm::isa<FloatType>(Val: inputETy))
577 added = nestedBuilder.create<arith::AddFOp>(location: loc, args: args[0],
578 args: args[1]);
579 else
580 added = nestedBuilder.create<arith::AddIOp>(location: loc, args: args[0],
581 args: args[1]);
582 nestedBuilder.create<linalg::YieldOp>(location: nestedLoc, args&: added);
583 })
584 .getResult(i: 0);
585 rewriter.replaceOp(op, newValues: result);
586 } else {
587 IntegerAttr iZp = rewriter.getI32IntegerAttr(value: inputZpVal);
588 IntegerAttr wZp = rewriter.getI32IntegerAttr(value: weightZpVal);
589 auto iZpVal = rewriter.create<arith::ConstantOp>(location: loc, args&: iZp);
590 auto kZpVal = rewriter.create<arith::ConstantOp>(location: loc, args&: wZp);
591 Value conv =
592 rewriter
593 .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
594 location: loc, args&: linalgConvTy, args: ValueRange{input, weight, iZpVal, kZpVal},
595 args: ValueRange{zeroTensor}, args&: strideAttr, args&: dilationAttr)
596 .getResult(i: 0);
597 SmallVector<ReassociationExprs, 4> reassociationMap;
598 createDepthwiseConvCollapseMap(outputRank: resultRank, reassociationMap, rewriter);
599 Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
600 location: loc, args&: resultTy, args&: conv, args&: reassociationMap);
601 Value result = linalgIntBroadcastExtSIAdd(
602 rewriter, loc, bias, conv: convReshape, result: biasEmptyTensor, indexingMaps);
603 rewriter.replaceOp(op, newValues: result);
604 }
605 return success();
606 }
607};
608
609class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
610public:
611 using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
612 LogicalResult
613 matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
614 ConversionPatternRewriter &rewriter) const final {
615 Location loc = op.getLoc();
616
617 auto outputTy = cast<ShapedType>(Val: op.getType());
618 auto outputElementTy = outputTy.getElementType();
619
620 SmallVector<Value> dynDims;
621 dynDims.resize(N: cast<ShapedType>(Val: op->getResult(idx: 0).getType()).getRank());
622
623 if (!outputTy.hasRank() || outputTy.isDynamicDim(idx: 0)) {
624 dynDims[0] = rewriter.create<tensor::DimOp>(location: loc, args: op->getOperand(idx: 0), args: 0);
625 }
626
627 if (!outputTy.hasRank() || outputTy.isDynamicDim(idx: 1)) {
628 dynDims[1] = rewriter.create<tensor::DimOp>(location: loc, args: op->getOperand(idx: 0), args: 1);
629 }
630
631 if (!outputTy.hasRank() || outputTy.isDynamicDim(idx: 2)) {
632 dynDims[2] = rewriter.create<tensor::DimOp>(location: loc, args: op->getOperand(idx: 1), args: 2);
633 }
634
635 SmallVector<Value> filteredDims = condenseValues(values: dynDims);
636
637 auto zeroAttr = rewriter.getZeroAttr(type: outputElementTy);
638 Value zero = rewriter.create<arith::ConstantOp>(location: loc, args&: zeroAttr);
639 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
640 location: loc, args: outputTy.getShape(), args: outputTy.getElementType(), args&: filteredDims);
641 Value zeroTensor = rewriter
642 .create<linalg::FillOp>(location: loc, args: ValueRange{zero},
643 args: ValueRange{emptyTensor})
644 .result();
645
646 FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
647 FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
648 if (failed(Result: maybeAZp))
649 return rewriter.notifyMatchFailure(
650 arg&: op, msg: "input a zero point cannot be statically determined");
651 if (failed(Result: maybeBZp))
652 return rewriter.notifyMatchFailure(
653 arg&: op, msg: "input b zero point cannot be statically determined");
654
655 const int64_t aZpVal = *maybeAZp;
656 const int64_t bZpVal = *maybeBZp;
657
658 if (op.verifyAZeroPoint(zp: aZpVal).failed())
659 return rewriter.notifyMatchFailure(
660 arg&: op, msg: "input a zero point must be zero for non-int8 integer types");
661
662 if (op.verifyBZeroPoint(zp: bZpVal).failed())
663 return rewriter.notifyMatchFailure(
664 arg&: op, msg: "input b zero point must be zero for non-int8 integer types");
665
666 if (aZpVal == 0 && bZpVal == 0) {
667 rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
668 op, args: TypeRange{op.getType()},
669 args: ValueRange{adaptor.getA(), adaptor.getB()}, args: ValueRange{zeroTensor});
670 return success();
671 }
672
673 auto aZp = rewriter.create<arith::ConstantOp>(
674 location: loc, args: rewriter.getI32IntegerAttr(value: aZpVal));
675 auto bZp = rewriter.create<arith::ConstantOp>(
676 location: loc, args: rewriter.getI32IntegerAttr(value: bZpVal));
677 rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
678 op, args: TypeRange{op.getType()},
679 args: ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, args&: zeroTensor);
680
681 return success();
682 }
683};
684
685class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
686public:
687 using OpConversionPattern::OpConversionPattern;
688
689 // Compute the dynamic output sizes of the maxpool operation.
690 static SmallVector<Value>
691 computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
692 ConversionPatternRewriter &rewriter) {
693 TensorType resultTy = op.getType();
694 Location loc = op.getLoc();
695
696 Value input = adaptor.getInput();
697 ArrayRef<int64_t> kernel = op.getKernel();
698 ArrayRef<int64_t> pad = op.getPad();
699 ArrayRef<int64_t> stride = op.getStride();
700
701 SmallVector<Value> dynamicDims;
702
703 // Batch dimension
704 if (resultTy.isDynamicDim(idx: 0))
705 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 0));
706
707 // Height/width dimensions
708 for (int64_t dim : {1, 2}) {
709 if (!resultTy.isDynamicDim(idx: dim))
710 continue;
711
712 // Index into the attribute arrays
713 int64_t index = dim - 1;
714
715 // Input height/width
716 Value ihw = rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: dim);
717
718 // Kernel height/width
719 Value khw = rewriter.create<arith::ConstantIndexOp>(location: loc, args: kernel[index]);
720
721 // Output height/width
722 Value ohw = getConvOrPoolOutputDim(loc, inputDim: ihw, padBeforeAttr: pad[index * 2],
723 padAfterAttr: pad[index * 2 + 1], kernelDim: khw, strideAttr: stride[index],
724 /*dilationAttr=*/1, rewriter);
725 dynamicDims.push_back(Elt: ohw);
726 }
727
728 // Channel dimension
729 if (resultTy.isDynamicDim(idx: 3))
730 dynamicDims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: input, args: 3));
731
732 return dynamicDims;
733 }
734
  // Lowers tosa.max_pool_2d to a linalg pooling named op:
  //   * linalg.pooling_nhwc_max_unsigned for unsigned-integer results,
  //   * linalg.pooling_nhwc_max for everything else.
  // The output accumulator is pre-filled with the smallest value of the
  // element type (-largest float, INT_MIN, or 0 for unsigned) so that the
  // running maximum starts from an identity element. For float inputs with
  // nan_mode == "IGNORE", the named op is afterwards rewritten into a
  // linalg.generic that keeps the previous pool value whenever the incoming
  // element is NaN.
  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    // Dynamic output extents (batch, spatial dims, channel) computed from the
    // input shape plus pad/kernel/stride; helper defined above in this class.
    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));

    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary. Build the NHWC low/high pad list expected
    // by applyPad (two entries per dimension): zero padding for the batch
    // dim, the op's H/W pad values in the middle, zero padding for the
    // channel dim. Padding uses the identity value so padded cells can never
    // win the max.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    // Shape-only tensor carrying the pooling window extents; its values are
    // never read by the pooling op.
    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
        op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    // Read nan_mode before replaceOp erases `op`; any later IR goes after
    // the (now replaced) max-pool result.
    rewriter.setInsertionPointAfter(op);
    StringRef nanMode = op.getNanMode();
    rewriter.replaceOp(op, resultOp);

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
      return success();

    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
    // compare and select materialization is required.
    //
    // In the case of "IGNORE" we need to insert a compare and select. Since
    // we've already produced a named op we will just take its body and modify
    // it to include the appropriate checks. If the current value is NaN the
    // old value of pool will be taken otherwise we use the result.
    if (nanMode == "IGNORE") {
      auto genericOp = rewriter.create<linalg::GenericOp>(
          loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(),
          resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            // Clone the named op's max operation into the new body, then
            // guard it: `cmpf uno x, x` is true exactly when x is NaN.
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN = opBuilder.create<arith::CmpFOp>(
                loc, arith::CmpFPredicate::UNO, blockArgs.front(),
                blockArgs.front());
            // On NaN keep the accumulator (blockArgs.back()), otherwise take
            // the freshly computed max.
            auto selectOp = opBuilder.create<arith::SelectOp>(
                loc, isNaN, blockArgs.back(), newOp->getResult(0));
            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
          });
      rewriter.replaceOp(resultOp, genericOp);
    }

    return success();
  }
848};
849
// Lowers tosa.avg_pool_2d in two steps:
//   1. linalg.pooling_nhwc_sum accumulates the window sums in the op's
//      accumulator type (acc_type attribute).
//   2. a linalg.generic normalizes each summed value by the number of
//      *valid* (non-padding) input elements covered by the kernel at that
//      output position. Floats use a plain division; integers reproduce the
//      TOSA reference behaviour via tosa.apply_scale with a CLZ-derived
//      multiplier/shift, input/output zero-point adjustment, and a final
//      clamp to the result type's signed range.
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    // Intermediate sums are kept in the accumulator type, which may be wider
    // than the result element type.
    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    // Only a dynamic batch dimension is supported; bail out otherwise.
    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;

    // Apply padding as necessary: two entries per NHWC dimension, zeros for
    // batch/channel surrounding the op's H/W pad values. Padding is with
    // zero; the normalization below compensates by only counting valid
    // elements.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    // Shape-only tensor carrying the pooling window extents.
    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    // iH/iW become (outputDim - 1): the index of the last output position,
    // used below to measure distance from the right/bottom edge.
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determines what the portion of valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            // offset is <= 0: the number of kernel taps that fall into the
            // padding on this side, subtracted from the valid count.
            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          // For spatial dim i (1 = H, 2 = W), compute how many kernel taps
          // land on real input at the current output index, clamped to >= 1.
          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just
          // a div however for quantized values input normalization had
          // to be applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            // Narrow back down from the accumulator type if necessary.
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (inputZpVal != 0) {
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplier, shift,
                        rewriter.getStringAttr("SINGLE_ROUND"))
                    .getResult();

            // If we have quantization information we need to apply output
            // zeropoint.
            if (outputZpVal != 0) {
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip: clamp to the signed range of the result element
            // type while still in the (wider) accumulator type.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, accETy,
                APInt::getSignedMinValue(outBitwidth).getSExtValue());
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, accETy,
                APInt::getSignedMaxValue(outBitwidth).getSExtValue());
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
1092
1093class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1094public:
1095 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1096
1097 LogicalResult matchAndRewrite(tosa::TransposeOp op,
1098 PatternRewriter &rewriter) const final {
1099 const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
1100
1101 Location loc = op.getLoc();
1102 // The verifier should have made sure we have a valid TOSA permutation
1103 // tensor. isPermutationVector doesn't actually check the TOSA perms we
1104 // expect.
1105 SmallVector<OpFoldResult> inputSizes =
1106 tensor::getMixedSizes(builder&: rewriter, loc, value: op.getInput1());
1107 auto permutedSizes =
1108 applyTOSAPermutation<OpFoldResult>(input: inputSizes, perms: constantPerms);
1109
1110 auto permutedInit = rewriter.create<tensor::EmptyOp>(
1111 location: loc, args&: permutedSizes, args: op.getInput1().getType().getElementType());
1112 rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
1113 op, args: op.getInput1(), args&: permutedInit,
1114 args: llvm::to_vector(Range: llvm::map_range(
1115 C: constantPerms, F: [](int32_t v) -> int64_t { return v; })));
1116 return success();
1117 }
1118};
1119} // namespace
1120
1121void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1122 const TypeConverter &converter, RewritePatternSet *patterns,
1123 const TosaToLinalgNamedOptions &options) {
1124 if (options.preferConv2DKernelLayoutHWCF) {
1125 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
1126 linalg::Conv2DNhwcHwcfQOp>>(
1127 arg: patterns->getContext());
1128 } else {
1129 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
1130 linalg::Conv2DNhwcFhwcQOp>>(
1131 arg: patterns->getContext());
1132 }
1133 patterns->add<
1134 // clang-format off
1135 ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
1136 DepthwiseConvConverter,
1137 MatMulConverter,
1138 AvgPool2dConverter,
1139 TransposeConverter
1140 >(arg: patterns->getContext());
1141
1142 patterns->add<
1143 MaxPool2dConverter
1144 >(arg: converter, args: patterns->getContext());
1145 // clang-format on
1146}
1147

// Source: mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp