| 1 | //===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // These rewriters lower from the Tosa to the Tensor dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" |
| 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 17 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 18 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
| 19 | #include "mlir/IR/PatternMatch.h" |
| 20 | #include "mlir/Transforms/DialectConversion.h" |
| 21 | |
| 22 | #include <numeric> |
| 23 | |
| 24 | using namespace mlir; |
| 25 | using namespace tosa; |
| 26 | |
| 27 | namespace { |
| 28 | |
| 29 | // Infer the type to which the input of a 'tosa.reshape' op must be cast when |
| 30 | // lowered. |
| 31 | TensorType inferReshapeInputType(TypedValue<TensorType> input, |
| 32 | ArrayRef<int64_t> newShape) { |
| 33 | // No need to cast input for non-empty target shape |
| 34 | if (!newShape.empty()) |
| 35 | return input.getType(); |
| 36 | |
| 37 | // The input type must be cast into a tensor with the same rank and all static |
| 38 | // dimensions set to 1. This prevents the generation of a |
| 39 | // tensor.collapse_shape op that converts a dynamically shaped tensor into a |
| 40 | // 0D tensor. While such construct is not incorrect on its own, bufferization |
| 41 | // cannot properly handle it at the moment, so we avoid it. |
| 42 | SmallVector<int64_t> shape(input.getType().getRank(), 1); |
| 43 | return input.getType().clone(shape); |
| 44 | } |
| 45 | |
| 46 | // Infer the result type of 'tensor.expand_shape' in the collapse-expand |
| 47 | // pair emitted for a 'tosa.reshape' op. |
| 48 | TensorType inferReshapeExpandedType(TensorType inputType, |
| 49 | ArrayRef<int64_t> newShape) { |
| 50 | // Special case for 0D output tensor. Note: Watch out when using Type::clone() |
| 51 | // with just '{}', as it will invoke the incorrect overload. |
| 52 | if (newShape.empty()) |
| 53 | return inputType.clone(shape: ArrayRef<int64_t>{}); |
| 54 | |
| 55 | // Check if the input is static, and if so, get its total size |
| 56 | bool inputIsStatic = inputType.hasStaticShape(); |
| 57 | int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1; |
| 58 | |
| 59 | // Compute result shape |
| 60 | auto resultShape = |
| 61 | llvm::map_to_vector(C&: newShape, F: [&](int64_t size) -> int64_t { |
| 62 | // If this is not a placeholder, do not change it. |
| 63 | if (size >= 0) |
| 64 | return size; |
| 65 | |
| 66 | // If we do not know the total size of the tensor, keep this dimension |
| 67 | // dynamic in the result shape. |
| 68 | if (!inputIsStatic) |
| 69 | return ShapedType::kDynamic; |
| 70 | |
| 71 | // Calculate the product of all elements in 'newShape' except for the -1 |
| 72 | // placeholder, which we discard by negating the result. |
| 73 | int64_t totalSizeNoPlaceholder = -std::accumulate( |
| 74 | first: newShape.begin(), last: newShape.end(), init: 1, binary_op: std::multiplies<int64_t>()); |
| 75 | |
| 76 | // If there is a 0 component in 'newShape', resolve the placeholder as |
| 77 | // 0. |
| 78 | if (totalSizeNoPlaceholder == 0) |
| 79 | return 0; |
| 80 | |
| 81 | // Resolve the placeholder as the quotient between the total tensor size |
| 82 | // and the product of all other sizes. |
| 83 | return totalSize / totalSizeNoPlaceholder; |
| 84 | }); |
| 85 | |
| 86 | bool resultIsStatic = ShapedType::isStaticShape(dSizes: resultShape); |
| 87 | |
| 88 | // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically |
| 89 | // shaped input from being reshaped into a statically shaped result. We may |
| 90 | // simply turn the first result dimension dynamic to address this. |
| 91 | if (!inputIsStatic && resultIsStatic) |
| 92 | resultShape[0] = ShapedType::kDynamic; |
| 93 | |
| 94 | // The 'tensor.expand_shape' op also forbids a statically shaped input from |
| 95 | // being reshaped into a dynamically shaped result, but the placeholder |
| 96 | // inference algorithm above guarantees that this will never be the case. |
| 97 | assert(!inputIsStatic || resultIsStatic); |
| 98 | |
| 99 | // Create result type |
| 100 | return inputType.clone(shape: resultShape); |
| 101 | } |
| 102 | |
| 103 | // Infer the result type of 'tensor.collapse_shape' in the collapse-expand |
| 104 | // pair emitted for a 'tosa.reshape' op. |
| 105 | TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) { |
| 106 | auto lhsShape = lhsType.getShape(); |
| 107 | auto rhsShape = rhsType.getShape(); |
| 108 | |
| 109 | if (lhsShape.empty() || rhsShape.empty()) |
| 110 | return lhsType.clone(shape: ArrayRef<int64_t>{}); |
| 111 | |
| 112 | if (ShapedType::isDynamicShape(dSizes: lhsShape) || |
| 113 | ShapedType::isDynamicShape(dSizes: rhsShape)) |
| 114 | return lhsType.clone(shape: {ShapedType::kDynamic}); |
| 115 | |
| 116 | SmallVector<int64_t> intermediateShape; |
| 117 | unsigned currLhsDim = 0, currRhsDim = 0; |
| 118 | while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) { |
| 119 | int64_t rhsSize = rhsShape[currRhsDim]; |
| 120 | int64_t lhsSize = lhsShape[currLhsDim]; |
| 121 | while (lhsSize != rhsSize && currLhsDim < lhsShape.size() && |
| 122 | currRhsDim < rhsShape.size()) { |
| 123 | if (lhsSize < rhsSize) { |
| 124 | currLhsDim++; |
| 125 | if (currLhsDim < lhsShape.size()) { |
| 126 | lhsSize *= lhsShape[currLhsDim]; |
| 127 | } |
| 128 | } else { |
| 129 | currRhsDim++; |
| 130 | if (currRhsDim < rhsShape.size()) { |
| 131 | rhsSize *= rhsShape[currRhsDim]; |
| 132 | } |
| 133 | } |
| 134 | } |
| 135 | if (lhsSize == rhsSize) { |
| 136 | intermediateShape.push_back(Elt: lhsSize); |
| 137 | } |
| 138 | currRhsDim++; |
| 139 | currLhsDim++; |
| 140 | } |
| 141 | |
| 142 | // Static shapes are guaranteed to be compatible by the op verifier, so all |
| 143 | // leftover dimensions should be 1. |
| 144 | for (; currLhsDim < lhsShape.size(); currLhsDim++) { |
| 145 | assert(lhsShape[currLhsDim] == 1); |
| 146 | } |
| 147 | for (; currRhsDim < rhsShape.size(); currRhsDim++) { |
| 148 | assert(rhsShape[currRhsDim] == 1); |
| 149 | } |
| 150 | |
| 151 | return lhsType.clone(shape: intermediateShape); |
| 152 | } |
| 153 | |
| 154 | SmallVector<ReassociationExprs> |
| 155 | createReassociationMapForCollapse(OpBuilder &builder, Type srcType, |
| 156 | Type dstType) { |
| 157 | auto srcShape = cast<TensorType>(Val&: srcType).getShape(); |
| 158 | auto dstShape = cast<TensorType>(Val&: dstType).getShape(); |
| 159 | |
| 160 | if (srcShape.empty() || dstShape.empty()) |
| 161 | return {}; |
| 162 | |
| 163 | if (ShapedType::isDynamicShape(dSizes: srcShape) || |
| 164 | ShapedType::isDynamicShape(dSizes: dstShape)) { |
| 165 | assert(dstShape.size() == 1); |
| 166 | SmallVector<AffineExpr, 2> exprs; |
| 167 | for (auto i : llvm::seq<int64_t>(Size: srcShape.size())) |
| 168 | exprs.push_back(Elt: builder.getAffineDimExpr(position: i)); |
| 169 | return {exprs}; |
| 170 | } |
| 171 | |
| 172 | SmallVector<ReassociationExprs> reassociationMap(dstShape.size()); |
| 173 | unsigned currSrcDim = 0, currDstDim = 0; |
| 174 | while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { |
| 175 | int64_t dstSize = dstShape[currDstDim]; |
| 176 | int64_t srcSize = srcShape[currSrcDim]; |
| 177 | while (srcSize < dstSize && currSrcDim < srcShape.size()) { |
| 178 | reassociationMap[currDstDim].push_back( |
| 179 | Elt: builder.getAffineDimExpr(position: currSrcDim++)); |
| 180 | srcSize *= srcShape[currSrcDim]; |
| 181 | } |
| 182 | if (srcSize == dstSize) { |
| 183 | reassociationMap[currDstDim].push_back( |
| 184 | Elt: builder.getAffineDimExpr(position: currSrcDim++)); |
| 185 | // If the next dim in collapsedShape is not 1, treat subsequent dims in |
| 186 | // expandedShape which are 1 to be collapsed. |
| 187 | if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) { |
| 188 | while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { |
| 189 | reassociationMap[currDstDim].push_back( |
| 190 | Elt: builder.getAffineDimExpr(position: currSrcDim++)); |
| 191 | } |
| 192 | } |
| 193 | } |
| 194 | currDstDim++; |
| 195 | } |
| 196 | |
| 197 | // If the source and target shapes are compatible, both iterators must have |
| 198 | // reached the end. This condition is guaranteed by the op verifier for |
| 199 | // static shapes. |
| 200 | assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size()); |
| 201 | return reassociationMap; |
| 202 | } |
| 203 | |
| 204 | // Create a tensor.collapse_shape op that reshapes the input into the given |
| 205 | // result type. |
| 206 | Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType, |
| 207 | Value input) { |
| 208 | auto reassociationMap = |
| 209 | createReassociationMapForCollapse(builder, srcType: input.getType(), dstType: resultType); |
| 210 | return builder.createOrFold<tensor::CollapseShapeOp>(location: loc, args&: resultType, args&: input, |
| 211 | args&: reassociationMap); |
| 212 | } |
| 213 | |
| 214 | // Create a tensor.expand_shape op that reshapes the input into the given result |
| 215 | // type. |
| 216 | Value createExpand(OpBuilder &builder, Location loc, TensorType resultType, |
| 217 | Value input) { |
| 218 | auto reassociationMap = |
| 219 | createReassociationMapForCollapse(builder, srcType: resultType, dstType: input.getType()); |
| 220 | return builder.createOrFold<tensor::ExpandShapeOp>(location: loc, args&: resultType, args&: input, |
| 221 | args&: reassociationMap); |
| 222 | } |
| 223 | |
| 224 | class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> { |
| 225 | public: |
| 226 | using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern; |
| 227 | |
| 228 | LogicalResult |
| 229 | matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, |
| 230 | ConversionPatternRewriter &rewriter) const final { |
| 231 | auto loc = reshape.getLoc(); |
| 232 | auto resultType = cast_if_present<ShapedType>( |
| 233 | Val: getTypeConverter()->convertType(t: reshape.getType())); |
| 234 | if (!resultType) { |
| 235 | return rewriter.notifyMatchFailure(arg: reshape.getLoc(), |
| 236 | msg: "could not convert result type" ); |
| 237 | } |
| 238 | auto input = dyn_cast<TypedValue<TensorType>>(Val: adaptor.getInput1()); |
| 239 | if (!input) { |
| 240 | return rewriter.notifyMatchFailure(arg: reshape.getLoc(), |
| 241 | msg: "expected input type to be tensor" ); |
| 242 | } |
| 243 | |
| 244 | llvm::SmallVector<int64_t> newShape; |
| 245 | if (!tosa::getConstShapeValues(op: reshape.getShape().getDefiningOp(), |
| 246 | result_shape&: newShape)) { |
| 247 | return failure(); |
| 248 | } |
| 249 | |
| 250 | // Infer all intermediate types |
| 251 | auto inputType = inferReshapeInputType(input, newShape); |
| 252 | auto expandedType = inferReshapeExpandedType(inputType, newShape); |
| 253 | auto collapsedType = inferReshapeCollapsedType(lhsType: inputType, rhsType: expandedType); |
| 254 | |
| 255 | // Cast input if needed |
| 256 | auto castInput = |
| 257 | rewriter.createOrFold<tensor::CastOp>(location: loc, args&: inputType, args&: input); |
| 258 | |
| 259 | // Emit collaspe-expand pair |
| 260 | auto collapsed = createCollapse(builder&: rewriter, loc, resultType: collapsedType, input: castInput); |
| 261 | auto expanded = createExpand(builder&: rewriter, loc, resultType: expandedType, input: collapsed); |
| 262 | |
| 263 | // Cast to final result type if needed |
| 264 | auto result = |
| 265 | rewriter.createOrFold<tensor::CastOp>(location: loc, args&: resultType, args&: expanded); |
| 266 | rewriter.replaceOp(op: reshape, newValues: result); |
| 267 | return success(); |
| 268 | } |
| 269 | }; |
| 270 | |
| 271 | class SliceConverter : public OpConversionPattern<tosa::SliceOp> { |
| 272 | public: |
| 273 | using OpConversionPattern<tosa::SliceOp>::OpConversionPattern; |
| 274 | |
| 275 | LogicalResult |
| 276 | matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor, |
| 277 | ConversionPatternRewriter &rewriter) const final { |
| 278 | Location loc = sliceOp.getLoc(); |
| 279 | Value input = adaptor.getInput1(); |
| 280 | ShapedType resultType = cast<ShapedType>(Val: sliceOp.getType()); |
| 281 | if (llvm::isa<UnrankedTensorType>(Val: resultType)) |
| 282 | return failure(); |
| 283 | |
| 284 | ElementsAttr startElems; |
| 285 | ElementsAttr sizeElems; |
| 286 | |
| 287 | if (!matchPattern(value: sliceOp.getStart(), pattern: m_Constant(bind_value: &startElems))) |
| 288 | return rewriter.notifyMatchFailure( |
| 289 | arg&: sliceOp, msg: "start of slice must be a static ranked shape" ); |
| 290 | |
| 291 | if (!matchPattern(value: sliceOp.getSize(), pattern: m_Constant(bind_value: &sizeElems))) |
| 292 | return rewriter.notifyMatchFailure( |
| 293 | arg&: sliceOp, msg: "size of slice must be a static ranked shape" ); |
| 294 | |
| 295 | llvm::SmallVector<int64_t> sliceStarts = |
| 296 | llvm::to_vector(Range: startElems.getValues<int64_t>()); |
| 297 | llvm::SmallVector<int64_t> sliceSizes = |
| 298 | llvm::to_vector(Range: sizeElems.getValues<int64_t>()); |
| 299 | |
| 300 | SmallVector<int64_t> strides, sizes; |
| 301 | strides.resize(N: cast<ShapedType>(Val: sliceOp.getType()).getRank(), NV: 1); |
| 302 | |
| 303 | SmallVector<Value> dynSizes; |
| 304 | for (const auto &i : llvm::enumerate(First&: sliceSizes)) { |
| 305 | int64_t size = i.value(); |
| 306 | size_t index = i.index(); |
| 307 | sizes.push_back(Elt: size == -1 ? ShapedType::kDynamic : size); |
| 308 | if (ShapedType::isStatic(dValue: sizes.back())) |
| 309 | continue; |
| 310 | |
| 311 | auto dim = rewriter.create<tensor::DimOp>(location: loc, args&: input, args&: index); |
| 312 | auto offset = rewriter.create<arith::ConstantOp>( |
| 313 | location: loc, args: rewriter.getIndexAttr(value: sliceStarts[index])); |
| 314 | dynSizes.push_back(Elt: rewriter.create<arith::SubIOp>(location: loc, args&: dim, args&: offset)); |
| 315 | } |
| 316 | |
| 317 | auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>( |
| 318 | location: sliceOp.getLoc(), args: sliceOp.getType(), args&: input, args: ValueRange({}), args&: dynSizes, |
| 319 | args: ValueRange({}), args: rewriter.getDenseI64ArrayAttr(values: sliceStarts), |
| 320 | args: rewriter.getDenseI64ArrayAttr(values: sizes), |
| 321 | args: rewriter.getDenseI64ArrayAttr(values: strides)); |
| 322 | |
| 323 | // Remove const_shape ops when it no longer has use point. |
| 324 | Operation *startConstShape = sliceOp.getStart().getDefiningOp(); |
| 325 | if (startConstShape->getResult(idx: 0).hasOneUse()) |
| 326 | rewriter.eraseOp(op: startConstShape); |
| 327 | |
| 328 | Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); |
| 329 | if (sizeConstShape->getResult(idx: 0).hasOneUse()) |
| 330 | rewriter.eraseOp(op: sizeConstShape); |
| 331 | |
| 332 | rewriter.replaceOp(op: sliceOp, newValues: newSliceOp.getResult()); |
| 333 | return success(); |
| 334 | } |
| 335 | }; |
| 336 | |
| 337 | class PadConverter : public OpConversionPattern<tosa::PadOp> { |
| 338 | public: |
| 339 | using OpConversionPattern::OpConversionPattern; |
| 340 | |
| 341 | LogicalResult |
| 342 | matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor, |
| 343 | ConversionPatternRewriter &rewriter) const final { |
| 344 | auto loc = padOp.getLoc(); |
| 345 | auto input = padOp.getInput1(); |
| 346 | |
| 347 | ElementsAttr paddingElems; |
| 348 | if (!matchPattern(value: padOp.getPadding(), pattern: m_Constant(bind_value: &paddingElems))) { |
| 349 | return rewriter.notifyMatchFailure( |
| 350 | arg&: padOp, msg: "padding must be a static shape value" ); |
| 351 | } |
| 352 | llvm::SmallVector<int64_t> paddingVals; |
| 353 | for (auto idx : paddingElems.getValues<IntegerAttr>()) { |
| 354 | paddingVals.push_back(Elt: static_cast<int64_t>(idx.getInt())); |
| 355 | } |
| 356 | |
| 357 | ShapedType inputTy = cast<ShapedType>(Val: input.getType()); |
| 358 | int64_t rank = inputTy.getRank(); |
| 359 | |
| 360 | // Setup the default constantAttr. |
| 361 | |
| 362 | Value padConstant = rewriter.createOrFold<tensor::ExtractOp>( |
| 363 | location: loc, args: padOp.getPadConst(), |
| 364 | args: ValueRange({rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)})); |
| 365 | |
| 366 | if (!padConstant) { |
| 367 | return rewriter.notifyMatchFailure( |
| 368 | arg&: padOp, msg: "tosa.pad was unable to determine the pad constant value." ); |
| 369 | } |
| 370 | |
| 371 | SmallVector<OpFoldResult, 3> lowValues; |
| 372 | SmallVector<OpFoldResult, 3> highValues; |
| 373 | |
| 374 | lowValues.reserve(N: rank); |
| 375 | highValues.reserve(N: rank); |
| 376 | |
| 377 | for (int i = 0; i < rank; i++) { |
| 378 | Value lowVal = rewriter.create<arith::ConstantOp>( |
| 379 | location: loc, args: rewriter.getIndexAttr(value: paddingVals[2 * i])); |
| 380 | Value highVal = rewriter.create<arith::ConstantOp>( |
| 381 | location: loc, args: rewriter.getIndexAttr(value: paddingVals[2 * i + 1])); |
| 382 | lowValues.push_back(Elt: lowVal); |
| 383 | highValues.push_back(Elt: highVal); |
| 384 | } |
| 385 | |
| 386 | auto newPadOp = rewriter.create<tensor::PadOp>( |
| 387 | location: loc, args: padOp.getType(), args&: input, args&: lowValues, args&: highValues, args&: padConstant); |
| 388 | |
| 389 | rewriter.replaceOp(op: padOp, newValues: newPadOp.getResult()); |
| 390 | return success(); |
| 391 | } |
| 392 | }; |
| 393 | |
| 394 | struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> { |
| 395 | using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern; |
| 396 | |
| 397 | LogicalResult |
| 398 | matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, |
| 399 | ConversionPatternRewriter &rewriter) const override { |
| 400 | auto resultType = dyn_cast<RankedTensorType>(Val: op.getType()); |
| 401 | |
| 402 | Location loc = op.getLoc(); |
| 403 | int axis = op.getAxis(); |
| 404 | Value axisValue = |
| 405 | rewriter.create<arith::ConstantOp>(location: loc, args: rewriter.getIndexAttr(value: axis)); |
| 406 | int64_t rank = resultType.getRank(); |
| 407 | |
| 408 | SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1)); |
| 409 | SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0)); |
| 410 | SmallVector<OpFoldResult> sizes = |
| 411 | tensor::getMixedSizes(builder&: rewriter, loc: op.getLoc(), value: adaptor.getOperands()[0]); |
| 412 | |
| 413 | // Pre-compute the offsets along the axis dimension. |
| 414 | // The axisOffsets will be of size rank + 1, where the last value |
| 415 | // will hold the total size of the tensor along the 'axis' dimension. |
| 416 | SmallVector<OpFoldResult> axisOffsets; |
| 417 | axisOffsets.push_back(Elt: rewriter.getIndexAttr(value: 0)); |
| 418 | axisOffsets.push_back(Elt: sizes[axis]); |
| 419 | |
| 420 | for (auto arg : adaptor.getOperands().drop_front()) { |
| 421 | auto size = rewriter.createOrFold<tensor::DimOp>(location: loc, args&: arg, args&: axisValue); |
| 422 | auto currentOffset = |
| 423 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: axisOffsets.back()); |
| 424 | auto total = |
| 425 | rewriter.createOrFold<arith::AddIOp>(location: loc, args&: currentOffset, args&: size); |
| 426 | axisOffsets.push_back(Elt: getAsOpFoldResult(val: total)); |
| 427 | } |
| 428 | sizes[axis] = axisOffsets.back(); |
| 429 | |
| 430 | // Compute the dynamic sizes of the tensor.empty operation. |
| 431 | // This is based off of the specified result type of the tosa.concat |
| 432 | // operation, since we don't want to change the result type of the operation |
| 433 | // during the conversion. |
| 434 | SmallVector<Value> dynDims; |
| 435 | for (int64_t i = 0; i < rank; ++i) { |
| 436 | if (resultType.isDynamicDim(idx: i)) { |
| 437 | dynDims.push_back( |
| 438 | Elt: getValueOrCreateConstantIndexOp(b&: rewriter, loc, ofr: sizes[i])); |
| 439 | } |
| 440 | } |
| 441 | |
| 442 | Value result = rewriter.create<tensor::EmptyOp>( |
| 443 | location: loc, args: resultType.getShape(), args: resultType.getElementType(), args&: dynDims); |
| 444 | |
| 445 | for (auto [arg, offset] : llvm::zip(t: adaptor.getOperands(), u&: axisOffsets)) { |
| 446 | auto sizes = tensor::getMixedSizes(builder&: rewriter, loc: op.getLoc(), value: arg); |
| 447 | offsets[axis] = offset; |
| 448 | result = rewriter.createOrFold<tensor::InsertSliceOp>( |
| 449 | location: loc, args&: arg, args&: result, args&: offsets, args&: sizes, args&: strides); |
| 450 | } |
| 451 | rewriter.replaceOp(op, newValues: result); |
| 452 | return success(); |
| 453 | } |
| 454 | }; |
| 455 | |
| 456 | } // namespace |
| 457 | |
| 458 | void mlir::tosa::populateTosaToTensorConversionPatterns( |
| 459 | const TypeConverter &converter, RewritePatternSet *patterns) { |
| 460 | patterns |
| 461 | ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>( |
| 462 | arg: converter, args: patterns->getContext()); |
| 463 | } |
| 464 | |