//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
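// For example, a dynamically shaped input such as 'tensor<?x?xf32>' being
// reshaped to a 0D result is first cast to 'tensor<1x1xf32>', so the
// collapse below operates on a fully static type.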
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast the input for a non-empty target shape.
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
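// For example, with a static 'tensor<6xf32>' input and a target shape of
// [2, -1], the -1 placeholder resolves to 3 and the expanded type is
// 'tensor<2x3xf32>'; with a dynamic 'tensor<?xf32>' input the placeholder
// stays dynamic and the expanded type is 'tensor<2x?xf32>'.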
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // If this is not a placeholder, do not change it.
        if (size >= 0)
          return size;

        // If we do not know the total size of the tensor, keep this dimension
        // dynamic in the result shape.
        if (!inputIsStatic)
          return ShapedType::kDynamic;

        // Calculate the product of all elements in 'newShape' except for the
        // -1 placeholder, which we discard by negating the result.
        int64_t totalSizeNoPlaceholder = -std::accumulate(
            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

        // If there is a 0 component in 'newShape', resolve the placeholder as
        // 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Resolve the placeholder as the quotient between the total tensor
        // size and the product of all other sizes.
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
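// For example, for a reshape from 'tensor<2x3xf32>' to 'tensor<3x2xf32>',
// the largest common intermediate type is 'tensor<6xf32>': the input is
// collapsed to it and then expanded to the final shape.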
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

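// Compute the reassociation indices for a 'tensor.collapse_shape' that
// collapses 'srcType' into 'dstType'. For example, collapsing
// 'tensor<2x3x4xf32>' into 'tensor<6x4xf32>' yields [[d0, d1], [d2]].
// When either shape is dynamic, all source dimensions are grouped into a
// single 1D result dimension.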
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If this is the last dim of the collapsed (dst) shape, or the next dst
      // dim is not 1, also fold any subsequent unit dims of the expanded
      // (src) shape into the current group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

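// Lower 'tosa.reshape' to a (possibly folded) 'tensor.cast' of the input, a
// 'tensor.collapse_shape' / 'tensor.expand_shape' pair through the inferred
// intermediate type, and a final 'tensor.cast' to the converted result type.
// For example, reshaping 'tensor<2x3xf32>' to 'tensor<3x2xf32>' collapses to
// 'tensor<6xf32>' and then expands to 'tensor<3x2xf32>'; no-op casts and
// reshapes are folded away by createOrFold.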
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape)) {
      return failure();
    }

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit the collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

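// Lower 'tosa.slice' to 'tensor.extract_slice'. Static start and size
// operands become static offsets and sizes; a size of -1 becomes a dynamic
// size computed as 'tensor.dim(input, i) - start[i]'.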
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
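    // A size of -1 means "slice to the end of this dimension"; materialize it
    // as tensor.dim(input, index) - start[index] and mark it dynamic in the
    // static size list.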
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Remove the const_shape ops feeding 'start' and 'size' if this slice was
    // their only remaining use.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};

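// Lower 'tosa.pad' to 'tensor.pad'. The static padding values are split into
// per-dimension low/high amounts and the pad value is read from the
// 'pad_const' tensor operand with a 'tensor.extract'.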
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Extract the scalar pad value from the pad_const tensor operand.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(),
        ValueRange({rewriter.create<arith::ConstantIndexOp>(loc, 0)}));

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

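    // The padding values are interleaved per dimension as
    // [low0, high0, low1, high1, ...].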
    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

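// Lower 'tosa.concat' to a 'tensor.empty' of the result shape followed by one
// 'tensor.insert_slice' per operand, inserted at increasing offsets along the
// concatenation axis.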
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension.
    // 'axisOffsets' holds one entry per operand plus one; the last entry is
    // the total size of the result along the 'axis' dimension.
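    // For example, concatenating operands of sizes 2, 3, and 4 along 'axis'
    // yields axisOffsets = [0, 2, 5, 9].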
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // These are based on the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}