1//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Linalg named ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Math/IR/Math.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/Tensor/IR/Tensor.h"
19#include "mlir/Dialect/Tensor/Utils/Utils.h"
20#include "mlir/Dialect/Tosa/IR/TosaOps.h"
21#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
22#include "mlir/Dialect/Utils/IndexingUtils.h"
23#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28
29#include "mlir/Interfaces/InferTypeOpInterface.h"
30
31#include <numeric>
32#include <type_traits>
33
34using namespace mlir;
35using namespace mlir::tosa;
36
37static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
38 TypedAttr padAttr, OpBuilder &rewriter) {
39 // Input should be padded only if necessary.
40 if (llvm::all_of(Range&: pad, P: [](int64_t p) { return p == 0; }))
41 return input;
42
43 ShapedType inputTy = cast<ShapedType>(input.getType());
44 Type inputETy = inputTy.getElementType();
45 auto inputShape = inputTy.getShape();
46
47 assert((inputShape.size() * 2) == pad.size());
48
49 SmallVector<int64_t, 4> paddedShape;
50 SmallVector<OpFoldResult, 8> lowIndices;
51 SmallVector<OpFoldResult, 8> highIndices;
52 for (size_t i : llvm::seq(inputShape.size())) {
53 auto lowPad = pad[i * 2];
54 auto highPad = pad[i * 2 + 1];
55 if (ShapedType::isDynamic(inputShape[i]))
56 paddedShape.push_back(inputShape[i]);
57 else
58 paddedShape.push_back(inputShape[i] + highPad + lowPad);
59 lowIndices.push_back(rewriter.getIndexAttr(lowPad));
60 highIndices.push_back(rewriter.getIndexAttr(highPad));
61 }
62
63 Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);
64
65 return rewriter.create<tensor::PadOp>(
66 loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
67 highIndices, padValue);
68}
69
70static mlir::Value
71linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
72 Value conv, Value result,
73 ArrayRef<AffineMap> indexingMaps) {
74 ShapedType resultTy = cast<ShapedType>(conv.getType());
75 return rewriter
76 .create<linalg::GenericOp>(
77 loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
78 getNParallelLoopsAttrs(resultTy.getRank()),
79 [](OpBuilder &builder, Location loc, ValueRange args) {
80 Value biasVal = args[0];
81 Type resType = args[1].getType();
82 if (resType != biasVal.getType()) {
83 biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
84 }
85 Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
86 builder.create<linalg::YieldOp>(loc, added);
87 })
88 .getResult(0);
89}
90
91// Construct the affine map that a linalg generic would use to broadcast the
92// source tensor into the shape of the result tensor.
93static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
94 Value result) {
95 ShapedType resultTy = cast<ShapedType>(result.getType());
96 ShapedType sourceTy = cast<ShapedType>(source.getType());
97 const int64_t resultRank = resultTy.getRank();
98 const int64_t sourceRank = sourceTy.getRank();
99
100 // The source tensor is broadcast to all the outer dimensions of the
101 // result tensor.
102 SmallVector<AffineExpr> sourceDims;
103 // In the case of a rank one source tensor with a single element TOSA
104 // specifies that the value be broadcast meaning we need an edge case for a
105 // constant map.
106 assert(sourceTy.hasStaticShape() &&
107 "Dynamic broadcasting shapes not supported!");
108 if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
109 sourceDims.push_back(Elt: rewriter.getAffineConstantExpr(constant: 0));
110 } else {
111 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
112 auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
113 sourceDims.push_back(expr);
114 }
115 }
116
117 return AffineMap::get(/*dimCount=*/resultRank,
118 /*symbolCount=*/0, results: sourceDims, context: rewriter.getContext());
119}
120
121// Broadcast the source value to all the outer dimensions of the result value.
122// If required, the element type is expanded using an arith.extsi or arith.extf
123// operation as appropriate.
124static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
125 Location loc, Value source,
126 Value result) {
127 ShapedType resultTy = cast<ShapedType>(result.getType());
128 const int64_t resultRank = resultTy.getRank();
129 // Creating maps for the input and output of the broacast-like generic op.
130 SmallVector<AffineMap, 2> indexingMaps;
131 indexingMaps.push_back(Elt: getBroadcastingMap(rewriter, source, result));
132 indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank: resultRank));
133
134 // Build the broadcast-like operation as a linalg.generic.
135 return rewriter
136 .create<linalg::GenericOp>(
137 loc, resultTy, ValueRange({source}), result, indexingMaps,
138 getNParallelLoopsAttrs(resultTy.getRank()),
139 [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
140 Value biasVal = args[0];
141 Type resType = args[1].getType();
142 if (resType != biasVal.getType()) {
143 biasVal =
144 resultTy.getElementType().isFloat()
145 ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
146 .getResult()
147 : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
148 .getResult();
149 }
150 builder.create<linalg::YieldOp>(loc, biasVal);
151 })
152 .getResult(0);
153}
154
155static mlir::Value reifyConstantDim(int64_t attr,
156 ImplicitLocOpBuilder &builder) {
157 return builder.create<arith::ConstantIndexOp>(args&: attr);
158}
159
160// Calculating the output width/height using the formula:
161// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
162// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
163
164static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
165 int64_t padBeforeAttr,
166 int64_t padAfterAttr, Value kernelDim,
167 int64_t strideAttr,
168 int64_t dilationAttr,
169 OpBuilder &rewriter) {
170 ImplicitLocOpBuilder builder(loc, rewriter);
171 auto one = rewriter.create<arith::ConstantOp>(
172 loc, IntegerAttr::get(inputDim.getType(), 1));
173 Value padBefore = reifyConstantDim(attr: padBeforeAttr, builder);
174 Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
175 Value padAfter = reifyConstantDim(attr: padAfterAttr, builder);
176 Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
177
178 Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
179 Value dilation = reifyConstantDim(attr: dilationAttr, builder);
180 Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
181 Value addOne = builder.create<arith::AddIOp>(dilated, one);
182
183 Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
184 Value stride = reifyConstantDim(attr: strideAttr, builder);
185 Value divide = builder.create<arith::DivUIOp>(subtract, stride);
186 return builder.create<arith::AddIOp>(divide, one);
187}
188
// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D.
//
// For every result dimension listed in `inputSizeDims` (the spatial dims of
// the result) that is dynamic, the output extent is computed from the runtime
// input/kernel extents via getConvOrPoolOutputDim. `kernelSizeDims` gives the
// corresponding dimension index into `weight` for each spatial dim. Remaining
// dynamic result dims (batch/channels) are read straight from the input.
// Returns only the dynamic dims, condensed in order, as required by
// tensor.empty.
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  // Indexed by result dimension; entries stay null for static dims.
  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      // Pad attribute stores (before, after) pairs per spatial dimension.
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, inputDim: initDynDim, padBeforeAttr: padTop, padAfterAttr: padBottom,
                                 kernelDim: kernelDynDim, strideAttr: stride, dilationAttr: dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions. These pass through unchanged, so the
  // output extent is simply the input extent.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  // Drop the null entries for static dims.
  SmallVector<Value> filteredDims = condenseValues(values: dynDims);
  return filteredDims;
}
228
229// Creates a map to collapse the last dimension of the Depthwise convolution op
230// due to a shape mismatch
231static void createDepthwiseConvCollapseMap(
232 int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
233 OpBuilder &rewriter) {
234 reassociationMap.resize(N: outputRank);
235 for (int i = 0; i < outputRank; i++) {
236 reassociationMap[i].push_back(Elt: rewriter.getAffineDimExpr(position: i));
237 }
238 reassociationMap[outputRank - 1].push_back(
239 Elt: rewriter.getAffineDimExpr(position: outputRank));
240}
241
242namespace {
243
// Lowers a TOSA convolution (conv2d/conv3d) to a linalg named convolution.
// LinalgConvOp is used for the zero-zero-point case and LinalgConvQOp when a
// non-zero input or weight zero point requires the quantized variant. The
// TOSA bias is broadcast into the convolution's output operand beforehand so
// the named op accumulates on top of it.
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

    // The accumulator may be wider than the result element type.
    Type accETy = op.getAccType();
    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(Result: maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(Result: maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    // Any non-zero zero point selects the quantized lowering path below.
    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops does not support unsigned integer input");

    // The spatial dims are every result dim except the first (batch) and
    // last (channels); for conv the same indices apply to input and kernel.
    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(Elt: i);
      kernelSizeDims.push_back(Elt: i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary. When quantized, pad with the input zero
    // point instead of 0 so padded elements are neutral after zp subtraction.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      int64_t intMin =
          APInt::getSignedMinValue(numBits: inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(numBits: inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    // Build a per-dim (low, high) pad list: no padding on the first (batch)
    // and last (channel) dims, the TOSA pad attribute for the spatial dims.
    llvm::SmallVector<int64_t> pad;
    pad.resize(N: 2, NV: 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(N: pad.size() + 2, NV: 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants a HWCF kernel layout.
      bool wantHwcf =
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(Elt: i);
        weightPerm.push_back(Elt: 0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(Elt: weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermAttr);
      }
    }

    // For Conv3D transpose the kernel to match dimension ordering of the linalg
    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
    // map directly and then transpose later if desired.
    if (5 == inputTy.getRank()) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(Elt: i);
      weightPerm.push_back(Elt: 0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(Elt: weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermAttr);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(values: stride);
    auto dilationAttr = rewriter.getI64TensorAttr(values: dilation);

    // Seed the output operand with the broadcast bias so the named conv op
    // accumulates on top of it.
    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, source: bias, result: biasEmptyTensor);

    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, accTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // We may need to truncate back to the result type if the accumulator was
    // wider than the result.
    if (resultTy != accTy)
      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
};
434
// Lowers tosa.depthwise_conv2d to linalg.depthwise_conv_2d_nhwc_hwcm (or its
// quantized variant). The linalg op produces a 5-D result with separate
// channel and multiplier dims, which is collapsed back to the 4-D TOSA result
// shape before the bias is added via a broadcasting linalg.generic.
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    // The accumulator may be wider than the result element type.
    Type accETy = op.getAccType();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims. For NHWC input the spatial dims are 1/2;
    // the kernel layout is HWCM so its spatial dims are 0/1.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    // Get and verify zero points.

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(Result: maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(Result: maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    // Zero zero points allow the non-quantized linalg op.
    bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary. When quantized, pad with the input zero
    // point so padded elements are neutral after zp subtraction.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      int64_t intMin =
          APInt::getSignedMinValue(numBits: inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(numBits: inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    // Per-dim (low, high) pad list: batch and channel dims are not padded.
    llvm::SmallVector<int64_t> pad;
    pad.resize(N: 2, NV: 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(N: pad.size() + 2, NV: 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op. The linalg result keeps channel (weight dim
    // 2) and multiplier (weight dim 3) as separate trailing dimensions.
    auto strideAttr = rewriter.getI64TensorAttr(values: stride);
    auto dilationAttr = rewriter.getI64TensorAttr(values: dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);

    // Accumulation starts from zero; the bias is added after the conv.
    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(Elt: getBroadcastingMap(rewriter, source: bias, result: biasEmptyTensor));
    indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank: resultRank));
    indexingMaps.push_back(Elt: rewriter.getMultiDimIdentityMap(rank: resultRank));

    if (hasNullZps) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // We may need to truncate back to the result type if the accumulator was
      // wider than the result.
      if (accETy != resultETy)
        conv = rewriter.create<tosa::CastOp>(
            loc,
            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
                                  resultETy),
            conv);

      // Collapse the trailing (channel, multiplier) pair into one dim so the
      // shape matches the TOSA result.
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(outputRank: resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      // Elementwise bias addition (float or integer by element type).
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added;
                    if (llvm::isa<FloatType>(inputETy))
                      added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
                                                                  args[1]);
                    else
                      added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
                                                                  args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      // Quantized path: pass the zero points as extra scalar operands.
      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(outputRank: resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, conv: convReshape, result: biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};
616
// Lowers tosa.matmul (batched) to linalg.batch_matmul, or to
// linalg.quantized_batch_matmul when either input has a non-zero zero point.
// The output operand is a zero-filled tensor so the matmul accumulates from
// zero.
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    // Result dims are (batch, M, N): batch and M come from operand A
    // (dims 0/1), N comes from operand B (dim 2).
    // NOTE(review): when outputTy is unranked, dynDims was sized by getRank()
    // above yet dynDims[0..2] are written — confirm unranked results cannot
    // reach this pattern.
    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(values: dynDims);

    // Zero-initialize the accumulator tensor.
    auto zeroAttr = rewriter.getZeroAttr(type: outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    // Get and verify zero points for both operands.
    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(Result: maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(Result: maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");

    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");

    // Zero zero points: plain batch matmul suffices.
    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    // Quantized path: zero points are passed as i32 scalar operands.
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};
692
// Lowers tosa.max_pool2d to linalg.pooling_nhwc_max (or the unsigned
// variant). The input is padded with the pooling identity (smallest
// representable value) so padded elements never win the max. For
// nan_mode == "IGNORE" on float types, the named op is rewritten into a
// linalg.generic with an explicit NaN check.
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays
      int64_t index = dim - 1;

      // Input height/width
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);

      // Kernel height/width
      Value khw = rewriter.create<arith::ConstantIndexOp>(location: loc, args: kernel[index]);

      // Output height/width
      Value ohw = getConvOrPoolOutputDim(loc, inputDim: ihw, padBeforeAttr: pad[index * 2],
                                         padAfterAttr: pad[index * 2 + 1], kernelDim: khw, strideAttr: stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(Elt: ohw);
    }

    // Channel dimension
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }

  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    // This is the identity of max: -largest for floats, 0 for unsigned ints,
    // and the signed minimum for signed ints.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         Sem: cast<FloatType>(resultETy).getFloatSemantics(), Negative: true));

    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(numBits: resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(Val: resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(numBits: resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary. Batch and channel dims get no padding.
    llvm::SmallVector<int64_t> pad;
    pad.resize(N: 2, NV: 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(N: pad.size() + 2, NV: 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(values: stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr(values: {1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    // Pooling ops take the kernel as a shape-only tensor operand.
    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
        op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    rewriter.replaceOp(op, resultOp);

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
      return success();

    // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
    // compare and select materialization is required.
    //
    // In the case of "IGNORE" we need to insert a compare and select. Since
    // we've already produced a named op we will just take its body and modify
    // it to include the appropriate checks. If the current value is NaN the
    // old value of pool will be taken otherwise we use the result.
    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
      auto genericOp = rewriter.create<linalg::GenericOp>(
          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            // Clone the named op's max computation into the new body.
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            // UNO compare: true iff the incoming element is NaN.
            Value isNaN = opBuilder.create<arith::CmpFOp>(
                op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
                blockArgs.front());
            // Keep the running value when NaN, otherwise take the new max.
            auto selectOp = opBuilder.create<arith::SelectOp>(
                op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
          });
      rewriter.replaceOp(resultOp, genericOp);
    }

    return success();
  }
};
856
/// Lowers tosa.avg_pool2d in two steps: a linalg.pooling_nhwc_sum that
/// accumulates the window sum in the op's accumulator type, followed by a
/// linalg.generic that divides each sum by the number of *valid*
/// (non-padding) input elements covered by the kernel at that output
/// position. Floats divide directly; the integer path applies input/output
/// zero points and divides via a fixed-point reciprocal fed to
/// tosa.apply_scale.
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    // Accumulate in the op's accumulator type, which may be wider than the
    // result element type; accTy is the result shape with that element type.
    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    // Only the batch dimension may be dynamic; bail out otherwise.
    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    // Zero points must be compile-time constants for this lowering.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;

    // Apply padding as necessary. The TOSA pad attribute covers only the
    // spatial dims; surround it with zero padding for the batch and channel
    // dims of the NHWC input (applyPad expects 2 entries per dimension).
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    // linalg pooling ops take the window shape as a shape-only tensor
    // operand whose contents are never read.
    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    // Last valid output index along each spatial dimension.
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determines what the portion of valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            // offset is <= 0; it shrinks the valid count by however much of
            // the window falls into the padded region on this side.
            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          // Number of valid (non-padding) input elements the kernel covers
          // along spatial dimension i (1 = height, 2 = width).
          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            // Clamp to at least one element so the later divide is safe.
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just
          // a div however for quantized values input normalization had
          // to be applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            // Narrow back to the result type if accumulation was done in a
            // wider float type.
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (inputZpVal != 0) {
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // The integer divide is done as a fixed-point multiply by the
            // reciprocal of `count` via tosa.apply_scale.
            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplier, shift,
                        rewriter.getStringAttr("SINGLE_ROUND"))
                    .getResult();

            // If we have quantization information we need to apply output
            // zeropoint.
            if (outputZpVal != 0) {
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip: clamp the i32 scaled value to the signed range of
            // the result element type before truncating.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
1099
1100class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1101public:
1102 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1103
1104 LogicalResult matchAndRewrite(tosa::TransposeOp op,
1105 PatternRewriter &rewriter) const final {
1106 const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
1107
1108 Location loc = op.getLoc();
1109 // The verifier should have made sure we have a valid TOSA permutation
1110 // tensor. isPermutationVector doesn't actually check the TOSA perms we
1111 // expect.
1112 SmallVector<OpFoldResult> inputSizes =
1113 tensor::getMixedSizes(builder&: rewriter, loc, value: op.getInput1());
1114 auto permutedSizes =
1115 applyTOSAPermutation<OpFoldResult>(input: inputSizes, perms: constantPerms);
1116
1117 auto permutedInit = rewriter.create<tensor::EmptyOp>(
1118 loc, permutedSizes, op.getInput1().getType().getElementType());
1119 rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
1120 op, op.getInput1(), permutedInit,
1121 llvm::to_vector(llvm::map_range(
1122 constantPerms, [](int32_t v) -> int64_t { return v; })));
1123 return success();
1124 }
1125};
1126} // namespace
1127
1128void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1129 const TypeConverter &converter, RewritePatternSet *patterns,
1130 const TosaToLinalgNamedOptions &options) {
1131 if (options.preferConv2DKernelLayoutHWCF) {
1132 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
1133 linalg::Conv2DNhwcHwcfQOp>>(
1134 patterns->getContext());
1135 } else {
1136 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
1137 linalg::Conv2DNhwcFhwcQOp>>(
1138 patterns->getContext());
1139 }
1140 patterns->add<
1141 // clang-format off
1142 ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
1143 DepthwiseConvConverter,
1144 MatMulConverter,
1145 AvgPool2dConverter,
1146 TransposeConverter
1147 >(patterns->getContext());
1148
1149 patterns->add<
1150 MaxPool2dConverter
1151 >(arg: converter, args: patterns->getContext());
1152 // clang-format on
1153}
1154

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp