//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//
12
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <functional>
#include <numeric>
23
24using namespace mlir;
25using namespace tosa;
26
27namespace {
28
29// Infer the type to which the input of a 'tosa.reshape' op must be cast when
30// lowered.
31TensorType inferReshapeInputType(TypedValue<TensorType> input,
32 ArrayRef<int64_t> newShape) {
33 // No need to cast input for non-empty target shape
34 if (!newShape.empty())
35 return input.getType();
36
37 // The input type must be cast into a tensor with the same rank and all static
38 // dimensions set to 1. This prevents the generation of a
39 // tensor.collapse_shape op that converts a dynamically shaped tensor into a
40 // 0D tensor. While such construct is not incorrect on its own, bufferization
41 // cannot properly handle it at the moment, so we avoid it.
42 SmallVector<int64_t> shape(input.getType().getRank(), 1);
43 return input.getType().clone(shape);
44}
45
46// Infer the result type of 'tensor.expand_shape' in the collapse-expand
47// pair emitted for a 'tosa.reshape' op.
48TensorType inferReshapeExpandedType(TensorType inputType,
49 ArrayRef<int64_t> newShape) {
50 // Special case for 0D output tensor. Note: Watch out when using Type::clone()
51 // with just '{}', as it will invoke the incorrect overload.
52 if (newShape.empty())
53 return inputType.clone(shape: ArrayRef<int64_t>{});
54
55 // Check if the input is static, and if so, get its total size
56 bool inputIsStatic = inputType.hasStaticShape();
57 int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
58
59 // Compute result shape
60 auto resultShape =
61 llvm::map_to_vector(C&: newShape, F: [&](int64_t size) -> int64_t {
62 // If this is not a placeholder, do not change it.
63 if (size >= 0)
64 return size;
65
66 // If we do not know the total size of the tensor, keep this dimension
67 // dynamic in the result shape.
68 if (!inputIsStatic)
69 return ShapedType::kDynamic;
70
71 // Calculate the product of all elements in 'newShape' except for the -1
72 // placeholder, which we discard by negating the result.
73 int64_t totalSizeNoPlaceholder = -std::accumulate(
74 first: newShape.begin(), last: newShape.end(), init: 1, binary_op: std::multiplies<int64_t>());
75
76 // If there is a 0 component in 'newShape', resolve the placeholder as
77 // 0.
78 if (totalSizeNoPlaceholder == 0)
79 return 0;
80
81 // Resolve the placeholder as the quotient between the total tensor size
82 // and the product of all other sizes.
83 return totalSize / totalSizeNoPlaceholder;
84 });
85
86 bool resultIsStatic = ShapedType::isStaticShape(dSizes: resultShape);
87
88 // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
89 // shaped input from being reshaped into a statically shaped result. We may
90 // simply turn the first result dimension dynamic to address this.
91 if (!inputIsStatic && resultIsStatic)
92 resultShape[0] = ShapedType::kDynamic;
93
94 // The 'tensor.expand_shape' op also forbids a statically shaped input from
95 // being reshaped into a dynamically shaped result, but the placeholder
96 // inference algorithm above guarantees that this will never be the case.
97 assert(!inputIsStatic || resultIsStatic);
98
99 // Create result type
100 return inputType.clone(shape: resultShape);
101}
102
103// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
104// pair emitted for a 'tosa.reshape' op.
105TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
106 auto lhsShape = lhsType.getShape();
107 auto rhsShape = rhsType.getShape();
108
109 if (lhsShape.empty() || rhsShape.empty())
110 return lhsType.clone(shape: ArrayRef<int64_t>{});
111
112 if (ShapedType::isDynamicShape(dSizes: lhsShape) ||
113 ShapedType::isDynamicShape(dSizes: rhsShape))
114 return lhsType.clone(shape: {ShapedType::kDynamic});
115
116 SmallVector<int64_t> intermediateShape;
117 unsigned currLhsDim = 0, currRhsDim = 0;
118 while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
119 int64_t rhsSize = rhsShape[currRhsDim];
120 int64_t lhsSize = lhsShape[currLhsDim];
121 while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
122 currRhsDim < rhsShape.size()) {
123 if (lhsSize < rhsSize) {
124 currLhsDim++;
125 if (currLhsDim < lhsShape.size()) {
126 lhsSize *= lhsShape[currLhsDim];
127 }
128 } else {
129 currRhsDim++;
130 if (currRhsDim < rhsShape.size()) {
131 rhsSize *= rhsShape[currRhsDim];
132 }
133 }
134 }
135 if (lhsSize == rhsSize) {
136 intermediateShape.push_back(Elt: lhsSize);
137 }
138 currRhsDim++;
139 currLhsDim++;
140 }
141
142 // Static shapes are guaranteed to be compatible by the op verifier, so all
143 // leftover dimensions should be 1.
144 for (; currLhsDim < lhsShape.size(); currLhsDim++) {
145 assert(lhsShape[currLhsDim] == 1);
146 }
147 for (; currRhsDim < rhsShape.size(); currRhsDim++) {
148 assert(rhsShape[currRhsDim] == 1);
149 }
150
151 return lhsType.clone(shape: intermediateShape);
152}
153
154SmallVector<ReassociationExprs>
155createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
156 Type dstType) {
157 auto srcShape = cast<TensorType>(Val&: srcType).getShape();
158 auto dstShape = cast<TensorType>(Val&: dstType).getShape();
159
160 if (srcShape.empty() || dstShape.empty())
161 return {};
162
163 if (ShapedType::isDynamicShape(dSizes: srcShape) ||
164 ShapedType::isDynamicShape(dSizes: dstShape)) {
165 assert(dstShape.size() == 1);
166 SmallVector<AffineExpr, 2> exprs;
167 for (auto i : llvm::seq<int64_t>(Size: srcShape.size()))
168 exprs.push_back(Elt: builder.getAffineDimExpr(position: i));
169 return {exprs};
170 }
171
172 SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
173 unsigned currSrcDim = 0, currDstDim = 0;
174 while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
175 int64_t dstSize = dstShape[currDstDim];
176 int64_t srcSize = srcShape[currSrcDim];
177 while (srcSize < dstSize && currSrcDim < srcShape.size()) {
178 reassociationMap[currDstDim].push_back(
179 Elt: builder.getAffineDimExpr(position: currSrcDim++));
180 srcSize *= srcShape[currSrcDim];
181 }
182 if (srcSize == dstSize) {
183 reassociationMap[currDstDim].push_back(
184 Elt: builder.getAffineDimExpr(position: currSrcDim++));
185 // If the next dim in collapsedShape is not 1, treat subsequent dims in
186 // expandedShape which are 1 to be collapsed.
187 if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
188 while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
189 reassociationMap[currDstDim].push_back(
190 Elt: builder.getAffineDimExpr(position: currSrcDim++));
191 }
192 }
193 }
194 currDstDim++;
195 }
196
197 // If the source and target shapes are compatible, both iterators must have
198 // reached the end. This condition is guaranteed by the op verifier for
199 // static shapes.
200 assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
201 return reassociationMap;
202}
203
204// Create a tensor.collapse_shape op that reshapes the input into the given
205// result type.
206Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
207 Value input) {
208 auto reassociationMap =
209 createReassociationMapForCollapse(builder, srcType: input.getType(), dstType: resultType);
210 return builder.createOrFold<tensor::CollapseShapeOp>(location: loc, args&: resultType, args&: input,
211 args&: reassociationMap);
212}
213
214// Create a tensor.expand_shape op that reshapes the input into the given result
215// type.
216Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
217 Value input) {
218 auto reassociationMap =
219 createReassociationMapForCollapse(builder, srcType: resultType, dstType: input.getType());
220 return builder.createOrFold<tensor::ExpandShapeOp>(location: loc, args&: resultType, args&: input,
221 args&: reassociationMap);
222}
223
224class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
225public:
226 using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
227
228 LogicalResult
229 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter) const final {
231 auto loc = reshape.getLoc();
232 auto resultType = cast_if_present<ShapedType>(
233 Val: getTypeConverter()->convertType(t: reshape.getType()));
234 if (!resultType) {
235 return rewriter.notifyMatchFailure(arg: reshape.getLoc(),
236 msg: "could not convert result type");
237 }
238 auto input = dyn_cast<TypedValue<TensorType>>(Val: adaptor.getInput1());
239 if (!input) {
240 return rewriter.notifyMatchFailure(arg: reshape.getLoc(),
241 msg: "expected input type to be tensor");
242 }
243
244 llvm::SmallVector<int64_t> newShape;
245 if (!tosa::getConstShapeValues(op: reshape.getShape().getDefiningOp(),
246 result_shape&: newShape)) {
247 return failure();
248 }
249
250 // Infer all intermediate types
251 auto inputType = inferReshapeInputType(input, newShape);
252 auto expandedType = inferReshapeExpandedType(inputType, newShape);
253 auto collapsedType = inferReshapeCollapsedType(lhsType: inputType, rhsType: expandedType);
254
255 // Cast input if needed
256 auto castInput =
257 rewriter.createOrFold<tensor::CastOp>(location: loc, args&: inputType, args&: input);
258
259 // Emit collaspe-expand pair
260 auto collapsed = createCollapse(builder&: rewriter, loc, resultType: collapsedType, input: castInput);
261 auto expanded = createExpand(builder&: rewriter, loc, resultType: expandedType, input: collapsed);
262
263 // Cast to final result type if needed
264 auto result =
265 rewriter.createOrFold<tensor::CastOp>(location: loc, args&: resultType, args&: expanded);
266 rewriter.replaceOp(op: reshape, newValues: result);
267 return success();
268 }
269};
270
271class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
272public:
273 using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;
274
275 LogicalResult
276 matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
277 ConversionPatternRewriter &rewriter) const final {
278 Location loc = sliceOp.getLoc();
279 Value input = adaptor.getInput1();
280 ShapedType resultType = cast<ShapedType>(Val: sliceOp.getType());
281 if (llvm::isa<UnrankedTensorType>(Val: resultType))
282 return failure();
283
284 ElementsAttr startElems;
285 ElementsAttr sizeElems;
286
287 if (!matchPattern(value: sliceOp.getStart(), pattern: m_Constant(bind_value: &startElems)))
288 return rewriter.notifyMatchFailure(
289 arg&: sliceOp, msg: "start of slice must be a static ranked shape");
290
291 if (!matchPattern(value: sliceOp.getSize(), pattern: m_Constant(bind_value: &sizeElems)))
292 return rewriter.notifyMatchFailure(
293 arg&: sliceOp, msg: "size of slice must be a static ranked shape");
294
295 llvm::SmallVector<int64_t> sliceStarts =
296 llvm::to_vector(Range: startElems.getValues<int64_t>());
297 llvm::SmallVector<int64_t> sliceSizes =
298 llvm::to_vector(Range: sizeElems.getValues<int64_t>());
299
300 SmallVector<int64_t> strides, sizes;
301 strides.resize(N: cast<ShapedType>(Val: sliceOp.getType()).getRank(), NV: 1);
302
303 SmallVector<Value> dynSizes;
304 for (const auto &i : llvm::enumerate(First&: sliceSizes)) {
305 int64_t size = i.value();
306 size_t index = i.index();
307 sizes.push_back(Elt: size == -1 ? ShapedType::kDynamic : size);
308 if (ShapedType::isStatic(dValue: sizes.back()))
309 continue;
310
311 auto dim = rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: index);
312 auto offset = rewriter.create<arith::ConstantOp>(
313 location: loc, args: rewriter.getIndexAttr(value: sliceStarts[index]));
314 dynSizes.push_back(Elt: rewriter.create<arith::SubIOp>(location: loc, args&: dim, args&: offset));
315 }
316
317 auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
318 location: sliceOp.getLoc(), args: sliceOp.getType(), args&: input, args: ValueRange({}), args&: dynSizes,
319 args: ValueRange({}), args: rewriter.getDenseI64ArrayAttr(values: sliceStarts),
320 args: rewriter.getDenseI64ArrayAttr(values: sizes),
321 args: rewriter.getDenseI64ArrayAttr(values: strides));
322
323 // Remove const_shape ops when it no longer has use point.
324 Operation *startConstShape = sliceOp.getStart().getDefiningOp();
325 if (startConstShape->getResult(idx: 0).hasOneUse())
326 rewriter.eraseOp(op: startConstShape);
327
328 Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
329 if (sizeConstShape->getResult(idx: 0).hasOneUse())
330 rewriter.eraseOp(op: sizeConstShape);
331
332 rewriter.replaceOp(op: sliceOp, newValues: newSliceOp.getResult());
333 return success();
334 }
335};
336
337class PadConverter : public OpConversionPattern<tosa::PadOp> {
338public:
339 using OpConversionPattern::OpConversionPattern;
340
341 LogicalResult
342 matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
343 ConversionPatternRewriter &rewriter) const final {
344 auto loc = padOp.getLoc();
345 auto input = padOp.getInput1();
346
347 ElementsAttr paddingElems;
348 if (!matchPattern(value: padOp.getPadding(), pattern: m_Constant(bind_value: &paddingElems))) {
349 return rewriter.notifyMatchFailure(
350 arg&: padOp, msg: "padding must be a static shape value");
351 }
352 llvm::SmallVector<int64_t> paddingVals;
353 for (auto idx : paddingElems.getValues<IntegerAttr>()) {
354 paddingVals.push_back(Elt: static_cast<int64_t>(idx.getInt()));
355 }
356
357 ShapedType inputTy = cast<ShapedType>(Val: input.getType());
358 int64_t rank = inputTy.getRank();
359
360 // Setup the default constantAttr.
361
362 Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
363 location: loc, args: padOp.getPadConst(),
364 args: ValueRange({rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)}));
365
366 if (!padConstant) {
367 return rewriter.notifyMatchFailure(
368 arg&: padOp, msg: "tosa.pad was unable to determine the pad constant value.");
369 }
370
371 SmallVector<OpFoldResult, 3> lowValues;
372 SmallVector<OpFoldResult, 3> highValues;
373
374 lowValues.reserve(N: rank);
375 highValues.reserve(N: rank);
376
377 for (int i = 0; i < rank; i++) {
378 Value lowVal = rewriter.create<arith::ConstantOp>(
379 location: loc, args: rewriter.getIndexAttr(value: paddingVals[2 * i]));
380 Value highVal = rewriter.create<arith::ConstantOp>(
381 location: loc, args: rewriter.getIndexAttr(value: paddingVals[2 * i + 1]));
382 lowValues.push_back(Elt: lowVal);
383 highValues.push_back(Elt: highVal);
384 }
385
386 auto newPadOp = rewriter.create<tensor::PadOp>(
387 location: loc, args: padOp.getType(), args&: input, args&: lowValues, args&: highValues, args&: padConstant);
388
389 rewriter.replaceOp(op: padOp, newValues: newPadOp.getResult());
390 return success();
391 }
392};
393
394struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
395 using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;
396
397 LogicalResult
398 matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
399 ConversionPatternRewriter &rewriter) const override {
400 auto resultType = dyn_cast<RankedTensorType>(Val: op.getType());
401
402 Location loc = op.getLoc();
403 int axis = op.getAxis();
404 Value axisValue =
405 rewriter.create<arith::ConstantOp>(location: loc, args: rewriter.getIndexAttr(value: axis));
406 int64_t rank = resultType.getRank();
407
408 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1));
409 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0));
410 SmallVector<OpFoldResult> sizes =
411 tensor::getMixedSizes(builder&: rewriter, loc: op.getLoc(), value: adaptor.getOperands()[0]);
412
413 // Pre-compute the offsets along the axis dimension.
414 // The axisOffsets will be of size rank + 1, where the last value
415 // will hold the total size of the tensor along the 'axis' dimension.
416 SmallVector<OpFoldResult> axisOffsets;
417 axisOffsets.push_back(Elt: rewriter.getIndexAttr(value: 0));
418 axisOffsets.push_back(Elt: sizes[axis]);
419
420 for (auto arg : adaptor.getOperands().drop_front()) {
421 auto size = rewriter.createOrFold<tensor::DimOp>(location: loc, args&: arg, args&: axisValue);
422 auto currentOffset =
423 getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: axisOffsets.back());
424 auto total =
425 rewriter.createOrFold<arith::AddIOp>(location: loc, args&: currentOffset, args&: size);
426 axisOffsets.push_back(Elt: getAsOpFoldResult(val: total));
427 }
428 sizes[axis] = axisOffsets.back();
429
430 // Compute the dynamic sizes of the tensor.empty operation.
431 // This is based off of the specified result type of the tosa.concat
432 // operation, since we don't want to change the result type of the operation
433 // during the conversion.
434 SmallVector<Value> dynDims;
435 for (int64_t i = 0; i < rank; ++i) {
436 if (resultType.isDynamicDim(idx: i)) {
437 dynDims.push_back(
438 Elt: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: sizes[i]));
439 }
440 }
441
442 Value result = rewriter.create<tensor::EmptyOp>(
443 location: loc, args: resultType.getShape(), args: resultType.getElementType(), args&: dynDims);
444
445 for (auto [arg, offset] : llvm::zip(t: adaptor.getOperands(), u&: axisOffsets)) {
446 auto sizes = tensor::getMixedSizes(builder&: rewriter, loc: op.getLoc(), value: arg);
447 offsets[axis] = offset;
448 result = rewriter.createOrFold<tensor::InsertSliceOp>(
449 location: loc, args&: arg, args&: result, args&: offsets, args&: sizes, args&: strides);
450 }
451 rewriter.replaceOp(op, newValues: result);
452 return success();
453 }
454};
455
456} // namespace
457
458void mlir::tosa::populateTosaToTensorConversionPatterns(
459 const TypeConverter &converter, RewritePatternSet *patterns) {
460 patterns
461 ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
462 arg: converter, args: patterns->getContext());
463}
464

source code of mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp