//===- TosaConvertIntegerTypeToSignless.cpp -------------------------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===-------------------------------------------------------------------------------===//
9
10// -----------
11// Motivation:
12// -----------
13
14// The TOSA specification uses a signless type system, which means that
15// information about signedness must be encapsulated by the operations
16// themselves. For example, tosa.rescale provides the attributes
17// `input_unsigned` and `output_unsigned` to indicate whether the input/output
18// should be interpreted as unsigned or signed.
19
20// The TOSA dialect, on the other hand, allows the use of signed or unsigned
21// types in addition to signless. As such, when converting from TOSA dialect to
22// other formats, we need to ensure that we conform to the TOSA specification.
23
24// ---------
25// Overview:
26// ---------
27
// This pass converts signed or unsigned integer types to signless. It currently
// does this greedily for all operators and may also change function
// signatures. If the signature of the entry-point function changes, it is the
// user's responsibility to track the signedness of its inputs and outputs
// separately.
33
34#include "mlir/Dialect/Func/IR/FuncOps.h"
35#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
36#include "mlir/Dialect/Tosa/IR/TosaOps.h"
37#include "mlir/Dialect/Tosa/Transforms/Passes.h"
38#include "mlir/Transforms/DialectConversion.h"
39
40namespace mlir {
41namespace tosa {
42
43#define GEN_PASS_DEF_TOSACONVERTINTEGERTYPETOSIGNLESS
44#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45
46namespace {
47class ToSignlessTensorTypeConverter : public TypeConverter {
48 static Type convertType(Type type) {
49 const auto tensorType = dyn_cast<TensorType>(Val&: type);
50 if (!tensorType)
51 return type;
52
53 const auto intType = dyn_cast<IntegerType>(Val: tensorType.getElementType());
54 if (!intType ||
55 intType.getSignedness() == IntegerType::SignednessSemantics::Signless)
56 return type;
57
58 const auto signlessType = IntegerType::get(
59 context: intType.getContext(), width: intType.getWidth(), signedness: IntegerType::Signless);
60 return tensorType.cloneWith(shape: std::nullopt, elementType: signlessType);
61 }
62
63public:
64 explicit ToSignlessTensorTypeConverter() { addConversion(callback&: convertType); }
65};
66
67class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
68public:
69 ConvertGenericOpWithIntegerTensorType(TypeConverter &typeConverter,
70 MLIRContext *context)
71 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
72
73 LogicalResult
74 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
75 ConversionPatternRewriter &rewriter) const final {
76 // Typically TOSA operators have a single result, but some have an
77 // arbitrary number. 4 seems like a good balance as an optimization
78 // hint for storing result types.
79 constexpr unsigned int numResults = 4;
80
81 // Convert integer types to signless
82 SmallVector<Type, numResults> resultTypes;
83 if (failed(Result: typeConverter->convertTypes(types: op->getResultTypes(), results&: resultTypes)))
84 return failure();
85
86 // Create new op with replaced operands and results
87 auto *newOp = Operation::create(
88 location: op->getLoc(), name: op->getName(), resultTypes, operands, attributes: op->getAttrs(),
89 properties: op->getPropertiesStorage(), successors: op->getSuccessors(), numRegions: op->getNumRegions());
90
91 // Handle regions in e.g. tosa.cond_if and tosa.while_loop
92 for (auto regions : llvm::zip(t: op->getRegions(), u: newOp->getRegions())) {
93 Region &before = std::get<0>(t&: regions);
94 Region &parent = std::get<1>(t&: regions);
95 rewriter.inlineRegionBefore(region&: before, parent, before: parent.end());
96 if (failed(Result: rewriter.convertRegionTypes(region: &parent, converter: *typeConverter)))
97 return failure();
98 }
99
100 // Replace with rewritten op
101 rewriter.insert(op: newOp);
102 rewriter.replaceOp(op, newValues: newOp->getResults());
103 return success();
104 }
105};
106
107class TosaConvertIntegerTypeToSignless
108 : public impl::TosaConvertIntegerTypeToSignlessBase<
109 TosaConvertIntegerTypeToSignless> {
110public:
111 void runOnOperation() override {
112 MLIRContext *context = &getContext();
113 ConversionTarget target(*context);
114 ToSignlessTensorTypeConverter typeConverter;
115
116 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
117 return typeConverter.isSignatureLegal(ty: op.getFunctionType()) &&
118 typeConverter.isLegal(region: &op.getBody());
119 });
120 target.markUnknownOpDynamicallyLegal(fn: [&](Operation *op) {
121 return typeConverter.isLegal(range: op->getOperandTypes()) &&
122 typeConverter.isLegal(range: op->getResultTypes());
123 });
124
125 RewritePatternSet patterns(context);
126 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
127 patterns, converter: typeConverter);
128 patterns.add<ConvertGenericOpWithIntegerTensorType>(arg&: typeConverter, args&: context);
129
130 if (failed(
131 Result: applyFullConversion(op: getOperation(), target, patterns: std::move(patterns))))
132 signalPassFailure();
133 }
134};
135
136} // namespace
137
138} // namespace tosa
139} // namespace mlir
140

source code of mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp