//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
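//
// For example (illustrative): when reshaping a 'tensor<?x?xf32>' input to a
// rank-0 result, the input is first cast to 'tensor<1x1xf32>', so the
// 'tensor.collapse_shape' emitted later never collapses a dynamically shaped
// tensor into a 0D tensor.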
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // dimensions statically set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
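//
// For example (illustrative): for a static input of type 'tensor<2x3x4xf32>'
// (24 elements) and a new shape of [6, -1], the -1 placeholder resolves to
// 24 / 6 = 4 and the inferred expanded type is 'tensor<6x4xf32>'.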
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape
  bool resultIsStatic = true;
  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
    // If this is not a placeholder, do not change it
    if (size >= 0)
      return size;

    // If we do not know the total size of the tensor, keep this dimension
    // dynamic in the result shape.
    if (!inputIsStatic) {
      resultIsStatic = false;
      return ShapedType::kDynamic;
    }

    // Calculate the product of all elements in 'newShape' except for the -1
    // placeholder, which we discard by negating the result.
    int64_t totalSizeNoPlaceholder = -std::accumulate(
        newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
    if (totalSizeNoPlaceholder == 0)
      return 0;

    // Resolve the placeholder as the quotient between the total tensor size
    // and the product of all other sizes.
    return totalSize / totalSizeNoPlaceholder;
  });

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
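//
// For example (illustrative): for an lhs type 'tensor<2x3xf32>' and an rhs
// type 'tensor<3x2xf32>', the only run of dimensions with a matching product
// is the whole shape (2 * 3 == 3 * 2 == 6), so the inferred collapsed type is
// 'tensor<6xf32>'.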
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

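// Compute the reassociation map that 'tensor.collapse_shape' (or, with the
// arguments swapped, 'tensor.expand_shape') uses to group the dimensions of
// 'srcType' into the dimensions of 'dstType'.
//
// For example (illustrative): collapsing 'tensor<2x3x1xf32>' into
// 'tensor<6xf32>' yields the single group [[d0, d1, d2]].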
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in dstShape is not 1 (or we are at the last dst dim),
      // also fold any subsequent size-1 dims of srcShape into the current
      // group.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

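// Lower 'tosa.reshape' to a 'tensor.collapse_shape' + 'tensor.expand_shape'
// pair, with 'tensor.cast' ops inserted where the inferred intermediate types
// differ from the original operand/result types.
//
// Illustrative example (the exact printed form may vary across MLIR
// versions):
//
//   %0 = tosa.reshape %arg0 {new_shape = array<i64: 3, 2>}
//          : (tensor<2x3xf32>) -> tensor<3x2xf32>
//
// becomes, roughly:
//
//   %c = tensor.collapse_shape %arg0 [[0, 1]]
//          : tensor<2x3xf32> into tensor<6xf32>
//   %e = tensor.expand_shape %c [[0, 1]]
//          : tensor<6xf32> into tensor<3x2xf32>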
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = reshape.getResult().getType();
    auto input = reshape.getInput1();
    auto newShape = reshape.getNewShape();

    // Infer all intermediate types
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

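// Lower 'tosa.slice' to 'tensor.extract_slice'. Static start/size attributes
// are forwarded directly; a size of -1 becomes a dynamic result size computed
// as 'dim(input, i) - start[i]'.
//
// Illustrative example (syntax may vary): slicing a 'tensor<?xf32>' with
// start = [2] and size = [-1] emits a 'tensor.dim' + 'arith.subi' to compute
// the dynamic size, followed by 'tensor.extract_slice %in[2] [%size] [1]'.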
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();
    SmallVector<int64_t> strides, sizes;
    ArrayRef<int64_t> starts = sliceOp.getStart();
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(starts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

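// Lower 'tosa.pad' to 'tensor.pad'. The pad value is taken from the optional
// 'pad_const' operand if present, otherwise it defaults to zero (or to the
// input zero point for quantized integer types). The low/high padding amounts
// are extracted element-wise from the 'padding' tensor operand.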
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Determine the pad constant value: either the explicit 'pad_const'
    // operand, or a default of zero (the input zero point for quantized
    // types).
    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    Value lowIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.create<arith::ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), lowVal);
      highVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

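// Lower 'tosa.concat' to a 'tensor.empty' of the result shape followed by one
// 'tensor.insert_slice' per operand, with the offsets along the concatenation
// axis accumulated from the operand sizes.
//
// Illustrative example (syntax may vary): concatenating two 'tensor<2x3xf32>'
// operands along axis 0 emits a 'tensor.empty() : tensor<4x3xf32>' and two
// insert_slice ops at offsets [0, 0] and [2, 0].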
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension.
    // The axisOffsets will hold one entry per operand plus one, where the
    // last value is the total size of the result along the 'axis' dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based on the result type specified for the tosa.concat
    // operation, since we do not want to change that type during the
    // conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

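// Minimal usage sketch (illustrative; the actual pass driver lives elsewhere
// and may differ, e.g. in which ops and dialects it marks legal or illegal):
//
//   RewritePatternSet patterns(&getContext());
//   mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
//   ConversionTarget target(getContext());
//   target.addIllegalOp<tosa::ReshapeOp, tosa::SliceOp, tosa::PadOp,
//                       tosa::ConcatOp>();
//   target.addLegalDialect<tensor::TensorDialect, arith::ArithDialect>();
//   if (failed(applyFullConversion(getOperation(), target,
//                                  std::move(patterns))))
//     signalPassFailure();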
void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
    ConcatConverter,
    PadConverter,
    ReshapeConverter,
    SliceConverter
  >(patterns->getContext());
}