//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa dialect to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

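// Pads the input tensor with `padAttr` when any padding amount is non-zero.
// `pad` holds interleaved [low, high] padding pairs, one pair per input
// dimension.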
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return rewriter.create<tensor::PadOp>(
      loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
      highIndices, padValue);
}

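// Broadcasts the integer bias over the convolution result and adds the two,
// sign-extending the bias to the accumulator element type when the types
// differ. The bias/result/output indexing maps are supplied by the caller.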
static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
            builder.create<linalg::YieldOp>(loc, added);
          })
      .getResult(0);
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
                                                Location loc, Value source,
                                                Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  int64_t resultRank = resultTy.getRank();
  int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // For a rank-1 source tensor with a single element, TOSA specifies that the
  // value is broadcast over every output element, so we need a constant affine
  // map as an edge case.
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  // Create maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps = {
      // Broadcast the last dimension of the bias to all output dimensions.
      AffineMap::get(/*dimCount=*/resultRank,
                     /*symbolCount=*/0, sourceDims, rewriter.getContext()),

      // Output indexing map.
      rewriter.getMultiDimIdentityMap(resultRank)};

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}

// Calculate the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
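// For example, with IH=14, pad_top=pad_bottom=1, KH=3, dilation_y=1, and
// stride_y=2: H = ((14+1+1-(1*(3-1)+1))/2)+1 = ((16-3)/2)+1 = 7, using the
// integer (floor) division performed below.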
static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D.
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the depthwise convolution op
// due to a shape mismatch.
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

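// Converts tosa.conv2d/tosa.conv3d into the linalg named convolution given by
// `LinalgConvOp` (or `LinalgConvQOp` for quantized inputs), applying input
// padding, transposing the kernel when the target op expects a different
// layout, and broadcasting the bias into the accumulator.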
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
    bool isQuantized = op.getQuantizationInfo().has_value();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants a HWCF kernel layout.
      bool wantHwcf =
          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
        SmallVector<int64_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
        Value weightPermValue =
            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermValue);
      }
    }

    // For Conv3D, transpose the kernel to match the dimension ordering of the
    // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it
    // is better to map directly and then transpose later if desired.
    if (5 == inputTy.getRank()) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int64_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
      Value weightPermValue =
          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermValue);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    rewriter.replaceOp(op, conv);
    return success();
  }
};

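// Converts tosa.depthwise_conv2d into linalg.depthwise_conv_2d_nhwc_hwcm (or
// its quantized variant), then collapses the trailing channel-multiplier
// dimension back into the channel dimension and adds the bias.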
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), resultETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

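// Converts tosa.matmul into linalg.batch_matmul, or
// linalg.quantized_batch_matmul when quantization info is present,
// accumulating into a zero-filled tensor.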
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();
    if (!op.getQuantizationInfo()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

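// Converts tosa.fully_connected into linalg.matmul (or
// linalg.quantized_matmul) by transposing the weight matrix and broadcasting
// the bias into the accumulator tensor.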
class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = cast<ShapedType>(op.getType());
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = cast<ShapedType>(weight.getType());
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (!op.getQuantizationInfo()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, broadcastBias)
                         ->getResult(0);

      rewriter.replaceOp(op, matmul);
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
    auto weightZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                broadcastBias)
            ->getResult(0);

    rewriter.replaceOp(op, matmul);
    return success();
  }
};

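// Converts tosa.max_pool2d into linalg.pooling_nhwc_max. Padding is applied
// with the smallest representable value of the element type so that it never
// wins the max reduction.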
class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    TypedValue<TensorType> input = op.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays
      int64_t index = dim - 1;

      // Input height/width
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);

      // Kernel height/width
      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);

      // Output height/width
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    TypedValue<TensorType> input = op.getInput();
    ShapedType inputTy = input.getType();

    ShapedType resultTy = op.getType();
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));

    if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledEmptyTensor, strideAttr, dilationAttr);
    return success();
  }
};

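// Converts tosa.avg_pool2d into linalg.pooling_nhwc_sum followed by a
// linalg.generic that divides each summed window by the number of valid
// (non-padding) input elements it covers, taking the quantized rescale path
// when quantization info is present.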
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type.
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determine what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just a
          // division; for quantized values, input normalization must also be
          // applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                poolVal, multiplier, shift,
                                                rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply the output
            // zero point.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(),
                                        quantizationInfo.getOutputZp()));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

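// Converts tosa.transpose with a constant permutation into linalg.transpose on
// a tensor.empty destination of the permuted shape.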
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    SmallVector<int64_t> constantPerms;
    if (failed(op.getConstantPerms(constantPerms)))
      return failure();

    Location loc = op.getLoc();
    // The verifier should have made sure we have a valid permutation tensor.
    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit = rewriter.create<tensor::EmptyOp>(
        loc, permutedSizes, op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit, constantPerms);
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter,
      TransposeConverter
  >(patterns->getContext());
  // clang-format on
}