| 1 | //===- TosaConvertIntegerTypeToSignless.cpp |
| 2 | //-------------------------------------------===// |
| 3 | // |
| 4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 5 | // See https://llvm.org/LICENSE.txt for license information. |
| 6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | // |
| 8 | //===-------------------------------------------------------------------------------===// |
| 9 | |
| 10 | // ----------- |
| 11 | // Motivation: |
| 12 | // ----------- |
| 13 | |
| 14 | // The TOSA specification uses a signless type system, which means that |
| 15 | // information about signedness must be encapsulated by the operations |
| 16 | // themselves. For example, tosa.rescale provides the attributes |
| 17 | // `input_unsigned` and `output_unsigned` to indicate whether the input/output |
| 18 | // should be interpreted as unsigned or signed. |
| 19 | |
| 20 | // The TOSA dialect, on the other hand, allows the use of signed or unsigned |
| 21 | // types in addition to signless. As such, when converting from TOSA dialect to |
| 22 | // other formats, we need to ensure that we conform to the TOSA specification. |
| 23 | |
| 24 | // --------- |
| 25 | // Overview: |
| 26 | // --------- |
| 27 | |
| 28 | // This pass converts signed or unsigned integer types to signless. It currently |
| 29 | // does this greedily for all operators and can also change the signature of the |
| 30 | // function. Should the signature of the entrypoint function change, it will be |
| 31 | // the responsibility of the user to carry signedness information of the inputs |
| 32 | // and outputs independently. |
| 33 | |
| 34 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 35 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| 36 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
| 37 | #include "mlir/Dialect/Tosa/Transforms/Passes.h" |
| 38 | #include "mlir/Transforms/DialectConversion.h" |
| 39 | |
| 40 | namespace mlir { |
| 41 | namespace tosa { |
| 42 | |
| 43 | #define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS |
| 44 | #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" |
| 45 | |
| 46 | namespace { |
| 47 | class ToSignlessTensorTypeConverter : public TypeConverter { |
| 48 | static Type convertType(Type type) { |
| 49 | const auto tensorType = dyn_cast<TensorType>(Val&: type); |
| 50 | if (!tensorType) |
| 51 | return type; |
| 52 | |
| 53 | const auto intType = dyn_cast<IntegerType>(Val: tensorType.getElementType()); |
| 54 | if (!intType || |
| 55 | intType.getSignedness() == IntegerType::SignednessSemantics::Signless) |
| 56 | return type; |
| 57 | |
| 58 | const auto signlessType = IntegerType::get( |
| 59 | context: intType.getContext(), width: intType.getWidth(), signedness: IntegerType::Signless); |
| 60 | return tensorType.cloneWith(shape: std::nullopt, elementType: signlessType); |
| 61 | } |
| 62 | |
| 63 | public: |
| 64 | explicit ToSignlessTensorTypeConverter() { addConversion(callback&: convertType); } |
| 65 | }; |
| 66 | |
| 67 | class ConvertGenericOpWithIntegerTensorType : public ConversionPattern { |
| 68 | public: |
| 69 | ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter, |
| 70 | MLIRContext *context) |
| 71 | : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {} |
| 72 | |
| 73 | LogicalResult |
| 74 | matchAndRewrite(Operation *op, ArrayRef<Value> operands, |
| 75 | ConversionPatternRewriter &rewriter) const final { |
| 76 | // Typically TOSA operators have a single result, but some have an |
| 77 | // arbitrary number. 4 seems like a good balance as an optimization |
| 78 | // hint for storing result types. |
| 79 | constexpr unsigned int numResults = 4; |
| 80 | |
| 81 | // Convert integer types to signless |
| 82 | SmallVector<Type, numResults> resultTypes; |
| 83 | if (failed(Result: typeConverter->convertTypes(types: op->getResultTypes(), results&: resultTypes))) |
| 84 | return failure(); |
| 85 | |
| 86 | // Create new op with replaced operands and results |
| 87 | auto *newOp = Operation::create( |
| 88 | location: op->getLoc(), name: op->getName(), resultTypes, operands, attributes: op->getAttrs(), |
| 89 | properties: op->getPropertiesStorage(), successors: op->getSuccessors(), numRegions: op->getNumRegions()); |
| 90 | |
| 91 | // Handle regions in e.g. tosa.cond_if and tosa.while_loop |
| 92 | for (auto regions : llvm::zip(t: op->getRegions(), u: newOp->getRegions())) { |
| 93 | Region &before = std::get<0>(t&: regions); |
| 94 | Region &parent = std::get<1>(t&: regions); |
| 95 | rewriter.inlineRegionBefore(region&: before, parent, before: parent.end()); |
| 96 | if (failed(Result: rewriter.convertRegionTypes(region: &parent, converter: *typeConverter))) |
| 97 | return failure(); |
| 98 | } |
| 99 | |
| 100 | // Replace with rewritten op |
| 101 | rewriter.insert(op: newOp); |
| 102 | rewriter.replaceOp(op, newValues: newOp->getResults()); |
| 103 | return success(); |
| 104 | } |
| 105 | }; |
| 106 | |
| 107 | class TosaConvertIntegerTypeToSignless |
| 108 | : public impl::TosaConvertIntegerTypeToSignlessBase< |
| 109 | TosaConvertIntegerTypeToSignless> { |
| 110 | public: |
| 111 | void runOnOperation() override { |
| 112 | MLIRContext *context = &getContext(); |
| 113 | ConversionTarget target(*context); |
| 114 | ToSignlessTensorTypeConverter typeConverter; |
| 115 | |
| 116 | target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) { |
| 117 | return typeConverter.isSignatureLegal(ty: op.getFunctionType()) && |
| 118 | typeConverter.isLegal(region: &op.getBody()); |
| 119 | }); |
| 120 | target.markUnknownOpDynamicallyLegal(fn: [&](Operation *op) { |
| 121 | return typeConverter.isLegal(range: op->getOperandTypes()) && |
| 122 | typeConverter.isLegal(range: op->getResultTypes()); |
| 123 | }); |
| 124 | |
| 125 | RewritePatternSet patterns(context); |
| 126 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( |
| 127 | patterns, converter: typeConverter); |
| 128 | patterns.add<ConvertGenericOpWithIntegerTensorType>(arg&: typeConverter, args&: context); |
| 129 | |
| 130 | if (failed( |
| 131 | Result: applyFullConversion(op: getOperation(), target, patterns: std::move(patterns)))) |
| 132 | signalPassFailure(); |
| 133 | } |
| 134 | }; |
| 135 | |
| 136 | } // namespace |
| 137 | |
| 138 | } // namespace tosa |
| 139 | } // namespace mlir |
| 140 | |