//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/Interfaces/InferTypeOpInterface.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return rewriter.create<tensor::PadOp>(
      loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
      highIndices, padValue);
}

static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
            builder.create<linalg::YieldOp>(loc, added);
          })
      .getResult(0);
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
                                                Location loc, Value source,
                                                Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  int64_t resultRank = resultTy.getRank();
  int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // In the case of a rank-one source tensor with a single element, TOSA
  // specifies that the value be broadcast, meaning we need an edge case for a
  // constant map.
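  // For example, a rank-one bias of shape {1} broadcast into a rank-3 result
  // uses the constant indexing map (d0, d1, d2) -> (0).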
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  // Creating maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps = {
      // Broadcast the last dimension of the bias to all output dimensions.
      AffineMap::get(/*dimCount=*/resultRank,
                     /*symbolCount=*/0, sourceDims, rewriter.getContext()),

      // Output indexing map.
      rewriter.getMultiDimIdentityMap(resultRank)};

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
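// Worked example: IH = 16, pad_top = pad_bottom = 1, KH = 3, dilation_y = 1
// and stride_y = 2 give H = ((16+1+1-(1*(3-1)+1))/2)+1 = (15/2)+1 = 8,
// using integer division.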

static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the depthwise convolution op
// due to a shape mismatch.
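// For example, with outputRank = 4 the reassociation is
// [[d0], [d1], [d2], [d3, d4]], collapsing the rank-5 depthwise result
// (N, H, W, C, M) back to the rank-4 output shape (N, H, W, C * M).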
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
    bool isQuantized = op.getQuantizationInfo().has_value();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants a HWCF kernel layout.
      bool wantHwcf =
          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
        SmallVector<int64_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
        Value weightPermValue =
            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermValue);
      }
    }

    // For Conv3D transpose the kernel to match dimension ordering of the linalg
    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
    // map directly and then transpose later if desired.
    if (5 == inputTy.getRank()) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int64_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
      Value weightPermValue =
          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermValue);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    rewriter.replaceOp(op, conv);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), resultETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();
    if (!op.getQuantizationInfo()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = cast<ShapedType>(op.getType());
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = cast<ShapedType>(weight.getType());
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (!op.getQuantizationInfo()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, broadcastBias)
                         ->getResult(0);

      rewriter.replaceOp(op, matmul);
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
    auto outputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, outputZp},
                broadcastBias)
            ->getResult(0);

    rewriter.replaceOp(op, matmul);
    return success();
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    TypedValue<TensorType> input = op.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays
      int64_t index = dim - 1;

      // Input height/width
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);

      // Kernel height/width
      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);

      // Output height/width
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    TypedValue<TensorType> input = op.getInput();
    ShapedType inputTy = input.getType();

    ShapedType resultTy = op.getType();
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));

    if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledEmptyTensor, strideAttr, dilationAttr);
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determines what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just
          // a division; for quantized values, input normalization has to be
          // applied as well.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
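            // Worked example: count = 9 gives k = 4,
            // numerator = ((1 << 30) + 1) << 4, multiplier = numerator / 9 and
            // shift = 34, so the scale below approximates a division by 9.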

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                poolVal, multiplier, shift,
                                                rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply output
            // zeropoint.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(),
                                        quantizationInfo.getOutputZp()));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    SmallVector<int64_t> constantPerms;
    if (failed(op.getConstantPerms(constantPerms)))
      return failure();

    Location loc = op.getLoc();
    // The verifier should have made sure we have a valid permutation tensor.
    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit = rewriter.create<tensor::EmptyOp>(
        loc, permutedSizes, op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit, constantPerms);
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter,
      TransposeConverter
  >(patterns->getContext());
  // clang-format on
}
