//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
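// For illustration (not exercised by this file directly): with an empty target
// shape, an input of type 'tensor<?x?xf32>' is cast to 'tensor<1x1xf32>' so
// that the subsequent collapse to a 0-D tensor starts from a static shape.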
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // static dimensions set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
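// For example (illustrative only): for an input of type 'tensor<2x3x4xf32>'
// and 'newShape' = [6, -1], the -1 placeholder resolves to 24 / 6 = 4 and the
// inferred expanded type is 'tensor<6x4xf32>'. If the input shape were
// dynamic, the placeholder would instead stay dynamic in the result.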
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // If this is not a placeholder, do not change it.
        if (size >= 0)
          return size;

        // If we do not know the total size of the tensor, keep this dimension
        // dynamic in the result shape.
        if (!inputIsStatic)
          return ShapedType::kDynamic;

        // Calculate the product of all elements in 'newShape' except for the
        // -1 placeholder, which we discard by negating the result.
        int64_t totalSizeNoPlaceholder = -std::accumulate(
            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

        // If there is a 0 component in 'newShape', resolve the placeholder as
        // 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Resolve the placeholder as the quotient between the total tensor
        // size and the product of all other sizes.
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = !ShapedType::isDynamicShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
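// For example (illustrative only): collapsing between 'tensor<2x12xf32>' and
// 'tensor<4x6xf32>' yields the intermediate type 'tensor<24xf32>', since 2x12
// and 4x6 only agree once both are folded into a single dimension of 24.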
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

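// Compute the reassociation indices used by the 'tensor.collapse_shape' (and,
// reversed, 'tensor.expand_shape') ops emitted for a 'tosa.reshape'. As an
// illustrative example, collapsing 'tensor<2x12xf32>' into 'tensor<24xf32>'
// produces the single reassociation group [d0, d1].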
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in dstShape is not 1, fold any subsequent unit dims in
      // srcShape into the current reassociation group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

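// Lowers 'tosa.reshape' by casting the input if needed, collapsing it to the
// common intermediate type, expanding it to the inferred result shape, and
// finally casting to the converted result type. As an illustrative sketch
// (exact IR depends on folding), reshaping 'tensor<2x12xf32>' into
// 'tensor<4x6xf32>' collapses to an intermediate 'tensor<24xf32>' and then
// expands to 'tensor<4x6xf32>', with both casts folding away.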
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<ShapedType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape)) {
      return failure();
    }

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit the collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

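// Lowers 'tosa.slice' to 'tensor.extract_slice', using the constant start and
// size operands as static offsets and sizes (with unit strides). For example
// (illustrative sketch), slicing a 'tensor<8x8xf32>' with start [1, 2] and
// size [4, 4] becomes a 'tensor.extract_slice' with offsets [1, 2], sizes
// [4, 4], and strides [1, 1]. A size of -1 is lowered to a dynamic size.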
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
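    // A slice size of -1 means "extend to the end of this dimension"; such
    // sizes become dynamic and are materialized as (dim size - start offset).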
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());

    // Remove the const_shape ops if the slice op was their only user.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    return success();
  }
};

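// Lowers 'tosa.pad' to 'tensor.pad', translating the flattened TOSA padding
// values into per-dimension low/high pads. For example (illustrative sketch),
// padding a 'tensor<4x4xf32>' with padding [1, 1, 2, 2] becomes a 'tensor.pad'
// with low = [1, 2] and high = [1, 2], yielding a 'tensor<6x8xf32>'.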
class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Extract the value to pad with from the pad_const operand.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(),
        ValueRange({rewriter.create<arith::ConstantIndexOp>(loc, 0)}));

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value lowVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

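// Lowers 'tosa.concat' to a 'tensor.empty' of the result shape followed by one
// 'tensor.insert_slice' per operand, each inserted at its running offset along
// the concatenation axis. For example (illustrative sketch), concatenating two
// 'tensor<2x3xf32>' values along axis 0 creates a 'tensor<4x3xf32>' empty
// tensor and inserts the operands at offsets [0, 0] and [2, 0].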
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. The axisOffsets vector
    // holds one entry per operand plus a final entry that is the total size of
    // the result tensor along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation. These are
    // derived from the result type specified on the tosa.concat op, since the
    // conversion must not change the op's result type.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

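    // Insert each operand into the result tensor at its pre-computed offset
    // along the concatenation axis.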
    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

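// Registers the patterns above. A minimal usage sketch (the surrounding pass
// boilerplate here is an assumption for illustration, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaToTensorConversionPatterns(typeConverter,
//                                                      &patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();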
void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}