1 | //===- ConversionUtils.cpp ------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Utility functions for TOSA lowering |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" |
14 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
15 | |
16 | using namespace mlir; |
17 | using namespace mlir::tosa; |
18 | |
19 | SmallVector<utils::IteratorType> |
20 | mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) { |
21 | return SmallVector<utils::IteratorType>(nParallelLoops, |
22 | utils::IteratorType::parallel); |
23 | } |
24 | |
25 | SmallVector<Value> |
26 | mlir::tosa::condenseValues(const SmallVector<Value> &values) { |
27 | SmallVector<Value> condensedValues; |
28 | for (auto value : values) |
29 | if (value) |
30 | condensedValues.push_back(Elt: value); |
31 | return condensedValues; |
32 | } |
33 | |
34 | Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min, |
35 | Value max, OpBuilder &rewriter) { |
36 | Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max); |
37 | return rewriter.create<arith::MaximumFOp>(loc, minValue, min); |
38 | } |
39 | |
40 | Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max, |
41 | OpBuilder &rewriter) { |
42 | auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg); |
43 | return rewriter.create<arith::MinSIOp>(loc, max, minOrArg); |
44 | } |
45 | |
46 | bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) { |
47 | uint64_t bitwidth = ty.getIntOrFloatBitWidth(); |
48 | if (ty.getSignedness() == IntegerType::Unsigned) { |
49 | uint64_t uvalue = value; |
50 | APInt intMin = APInt::getMinValue(numBits: bitwidth); |
51 | APInt intMax = APInt::getMaxValue(numBits: bitwidth); |
52 | return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue(); |
53 | } |
54 | |
55 | APInt intMin = APInt::getSignedMinValue(numBits: bitwidth); |
56 | APInt intMax = APInt::getSignedMaxValue(numBits: bitwidth); |
57 | return value >= intMin.getSExtValue() && value <= intMax.getSExtValue(); |
58 | } |
59 | |
60 | namespace { |
61 | // Given two tensors of high and low ranks, derive the output shape |
62 | // to reshape the lower rank to. |
63 | // Examples: |
64 | // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c]. |
65 | // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c]. |
66 | // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. |
67 | // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. |
68 | // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. |
69 | LogicalResult |
70 | computeReshapeOutput(ArrayRef<int64_t> higherRankShape, |
71 | ArrayRef<int64_t> lowerRankShape, |
72 | SmallVectorImpl<int64_t> &reshapeOutputShape) { |
73 | // Initialize new shapes with [1] * higherRank. |
74 | int64_t higherRank = higherRankShape.size(); |
75 | int64_t lowerRank = lowerRankShape.size(); |
76 | |
77 | reshapeOutputShape.assign(NumElts: higherRank, Elt: 1); |
78 | |
79 | int64_t higherRankDim; |
80 | int64_t lowerRankDim; |
81 | |
82 | for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; |
83 | i--, j--) { |
84 | higherRankDim = higherRankShape[i]; |
85 | lowerRankDim = lowerRankShape[j]; |
86 | |
87 | if (lowerRankDim == 1 && higherRankDim > 1) |
88 | reshapeOutputShape[i] = 1; |
89 | else if ((lowerRankDim > 1 && higherRankDim == 1) || |
90 | (lowerRankDim == higherRankDim)) |
91 | reshapeOutputShape[i] = lowerRankDim; |
92 | else if (higherRankDim != lowerRankDim) |
93 | return failure(); |
94 | } |
95 | return success(); |
96 | } |
97 | } // namespace |
98 | |
99 | LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc, |
100 | Value &input1, Value &input2) { |
101 | auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType()); |
102 | auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType()); |
103 | |
104 | if (!input1Ty || !input2Ty) { |
105 | return failure(); |
106 | } |
107 | |
108 | int64_t input1Rank = input1Ty.getRank(); |
109 | int64_t input2Rank = input2Ty.getRank(); |
110 | |
111 | if (input1Rank == input2Rank) |
112 | return success(); |
113 | |
114 | Value higherTensorValue, lowerTensorValue; |
115 | if (input1Rank > input2Rank) { |
116 | higherTensorValue = input1; |
117 | lowerTensorValue = input2; |
118 | } else { |
119 | higherTensorValue = input2; |
120 | lowerTensorValue = input1; |
121 | } |
122 | |
123 | ArrayRef<int64_t> higherRankShape = |
124 | llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape(); |
125 | ArrayRef<int64_t> lowerRankShape = |
126 | llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape(); |
127 | |
128 | SmallVector<int64_t, 4> reshapeOutputShape; |
129 | |
130 | if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) |
131 | .failed()) |
132 | return failure(); |
133 | |
134 | auto reshapeInputType = |
135 | llvm::cast<RankedTensorType>(lowerTensorValue.getType()); |
136 | auto reshapeOutputType = RankedTensorType::get( |
137 | ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType()); |
138 | |
139 | auto reshapeLower = rewriter.create<tosa::ReshapeOp>( |
140 | loc, reshapeOutputType, lowerTensorValue, |
141 | rewriter.getDenseI64ArrayAttr(reshapeOutputShape)); |
142 | |
143 | if (input1Rank > input2Rank) { |
144 | input1 = higherTensorValue; |
145 | input2 = reshapeLower.getResult(); |
146 | } else { |
147 | input1 = reshapeLower.getResult(); |
148 | input2 = higherTensorValue; |
149 | } |
150 | |
151 | return success(); |
152 | } |
153 | |