//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
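// For example, an input of type 'tensor<?x?xf32>' being reshaped to a 0D
// result is first cast to 'tensor<1x1xf32>' (same rank, all unit dimensions).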
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape.
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
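// For example, with a static input type 'tensor<2x3xf32>' and newShape
// {3, -1}, the -1 placeholder resolves to 6 / 3 = 2, so the expanded type is
// 'tensor<3x2xf32>'. With a dynamic input such as 'tensor<?x3xf32>', the
// placeholder stays dynamic and the expanded type is 'tensor<3x?xf32>'.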
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
  // with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  bool resultIsStatic = true;
  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
    // If this is not a placeholder, do not change it
    if (size >= 0)
      return size;

    // If we do not know the total size of the tensor, keep this dimension
    // dynamic in the result shape.
    if (!inputIsStatic) {
      resultIsStatic = false;
      return ShapedType::kDynamic;
    }

    // Calculate the product of all elements in 'newShape' except for the -1
    // placeholder, which we discard by negating the result.
    int64_t totalSizeNoPlaceholder = -std::accumulate(
        newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
    if (totalSizeNoPlaceholder == 0)
      return 0;

    // Resolve the placeholder as the quotient between the total tensor size
    // and the product of all other sizes.
    return totalSize / totalSizeNoPlaceholder;
  });

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
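// For example, given lhsType 'tensor<2x3xf32>' and rhsType 'tensor<3x2xf32>',
// the greedy dimension matching below only finds agreement on the full
// product, so the collapsed (intermediate) type is 'tensor<6xf32>'.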
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

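// Compute the reassociation indices for the 'tensor.collapse_shape' (or, with
// swapped arguments, 'tensor.expand_shape') op that maps 'srcType' onto
// 'dstType'. For example, collapsing 'tensor<2x3x4xf32>' into
// 'tensor<6x4xf32>' yields the reassociation map [[d0, d1], [d2]].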
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

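// Lowers 'tosa.reshape' to a 'tensor.cast' (if needed), followed by a
// 'tensor.collapse_shape' / 'tensor.expand_shape' pair, followed by another
// 'tensor.cast' to the original result type. For example, a 'tosa.reshape' of
// 'tensor<2x3xf32>' with new_shape [3, 2] collapses into 'tensor<6xf32>' and
// then expands into 'tensor<3x2xf32>'.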
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = reshape.getResult().getType();
    auto input = reshape.getInput1();
    auto newShape = reshape.getNewShape();

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

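// Lowers 'tosa.slice' to 'tensor.extract_slice'. Static start/size values map
// directly to the slice offsets/sizes; a size of -1 becomes a dynamic size
// computed as 'tensor.dim(input, d) - start[d]'.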
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();
    SmallVector<int64_t> strides, sizes;
    ArrayRef<int64_t> starts = sliceOp.getStart();
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(starts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

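// Lowers 'tosa.pad' to 'tensor.pad'. The low/high padding amounts for each
// dimension are extracted from the padding tensor, and the pad value comes
// from the optional pad_const operand or defaults to a zero (or zero-point)
// constant of the element type.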
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Setup the default constantAttr.

    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    Value lowIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), lowVal);
      highVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

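// Lowers 'tosa.concat' to a 'tensor.empty' op of the result shape followed by
// one 'tensor.insert_slice' per operand, each offset along the concatenation
// axis by the accumulated sizes of the preceding operands.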
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. The axisOffsets
    // vector will hold one entry per operand plus one, where the last value
    // is the total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based off of the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      ConcatConverter,
      PadConverter,
      ReshapeConverter,
      SliceConverter
  >(patterns->getContext());
}