//===- ConversionUtils.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility functions for TOSA lowering
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

using namespace mlir;
using namespace mlir::tosa;

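// Returns `nParallelLoops` parallel iterator-type attributes, e.g. for
// building the iterator list of a linalg.generic during TOSA lowering.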
SmallVector<utils::IteratorType>
mlir::tosa::getNParallelLoopsAttrs(unsigned nParallelLoops) {
  return SmallVector<utils::IteratorType>(nParallelLoops,
                                          utils::IteratorType::parallel);
}

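// Drops null Values from `values`, preserving the relative order of the
// remaining ones.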
SmallVector<Value>
mlir::tosa::condenseValues(const SmallVector<Value> &values) {
  SmallVector<Value> condensedValues;
  for (auto value : values)
    if (value)
      condensedValues.push_back(value);
  return condensedValues;
}

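// Clamps `arg` to the inclusive range [min, max] using arith.minimumf and
// arith.maximumf, so a NaN input propagates through the clamp.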
Value mlir::tosa::clampFloatHelper(Location loc, Value arg, Value min,
                                   Value max, OpBuilder &rewriter) {
  Value minValue = rewriter.create<arith::MinimumFOp>(loc, arg, max);
  return rewriter.create<arith::MaximumFOp>(loc, minValue, min);
}

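// Clamps `arg` to the inclusive range [min, max] using unsigned or signed
// integer comparisons depending on `isUnsigned`.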
Value mlir::tosa::clampIntHelper(Location loc, Value arg, Value min, Value max,
                                 OpBuilder &rewriter, bool isUnsigned) {
  if (isUnsigned) {
    auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
    return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
  }
  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}

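// Returns true if `value` is representable in the integer type `ty`, honoring
// the type's signedness.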
bool mlir::tosa::validIntegerRange(IntegerType ty, int64_t value) {
  uint64_t bitwidth = ty.getIntOrFloatBitWidth();
  if (ty.getSignedness() == IntegerType::Unsigned) {
    uint64_t uvalue = value;
    APInt intMin = APInt::getMinValue(bitwidth);
    APInt intMax = APInt::getMaxValue(bitwidth);
    return uvalue >= intMin.getZExtValue() && uvalue <= intMax.getZExtValue();
  }

  APInt intMin = APInt::getSignedMinValue(bitwidth);
  APInt intMax = APInt::getSignedMaxValue(bitwidth);
  return value >= intMin.getSExtValue() && value <= intMax.getSExtValue();
}

namespace {
// Given two tensors of higher and lower rank, derive the output shape into
// which the lower-ranked tensor should be reshaped.
// Examples:
// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], higher=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], higher=[a, b, c], [] reshaped into [1, 1, 1].
LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                     ArrayRef<int64_t> lowerRankShape,
                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize the output shape to [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();
  reshapeOutputShape.assign(higherRank, 1);

  int64_t higherRankDim;
  int64_t lowerRankDim;
  const int64_t rankDiff = higherRank - lowerRank;

  for (int64_t i = lowerRank - 1; i >= 0; i--) {
    higherRankDim = higherRankShape[i + rankDiff];
    lowerRankDim = lowerRankShape[i];

    // The aligned dimensions must be broadcast-compatible: equal, or one of
    // them 1.
    if (lowerRankDim != 1 && higherRankDim != 1 &&
        lowerRankDim != higherRankDim)
      return failure();

    reshapeOutputShape[i + rankDiff] = lowerRankDim == 1 ? 1 : lowerRankDim;
  }
  return success();
}
} // namespace

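// Reshapes the lower-ranked of `input1`/`input2` so that both operands end up
// with the same rank, prepending size-1 dimensions via a tosa.reshape. Fails
// if either operand is unranked or the shapes are not broadcast-compatible.
//
// A minimal usage sketch inside a rewrite pattern (the `op`, `lhs` and `rhs`
// locals are hypothetical, not part of this file):
//
//   Value lhs = op.getInput1(), rhs = op.getInput2();
//   if (failed(EqualizeRanks(rewriter, op.getLoc(), lhs, rhs)))
//     return rewriter.notifyMatchFailure(op, "ranks cannot be equalized");
//   // lhs and rhs now have equal rank and can feed broadcasting TOSA ops.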
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                        Value &input1, Value &input2) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  return EqualizeRanks(builder, input1, input2);
}

LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
                                        Value &input1, Value &input2) {
  auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
  auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());

  if (!input1Ty || !input2Ty) {
    return failure();
  }

  int64_t input1Rank = input1Ty.getRank();
  int64_t input2Rank = input2Ty.getRank();

  if (input1Rank == input2Rank)
    return success();

  Value higherTensorValue, lowerTensorValue;
  if (input1Rank > input2Rank) {
    higherTensorValue = input1;
    lowerTensorValue = input2;
  } else {
    higherTensorValue = input2;
    lowerTensorValue = input1;
  }

  ArrayRef<int64_t> higherRankShape =
      llvm::cast<RankedTensorType>(higherTensorValue.getType()).getShape();
  ArrayRef<int64_t> lowerRankShape =
      llvm::cast<RankedTensorType>(lowerTensorValue.getType()).getShape();

  SmallVector<int64_t, 4> reshapeOutputShape;

  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
          .failed())
    return failure();

  auto reshapeInputType =
      llvm::cast<RankedTensorType>(lowerTensorValue.getType());
  auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
  auto reshapeOutputShapeValue = getTosaConstShape(builder, reshapeOutputShape);

  auto reshapeLower = builder.create<tosa::ReshapeOp>(
      reshapeOutputType, lowerTensorValue, reshapeOutputShapeValue);

  if (input1Rank > input2Rank) {
    input1 = higherTensorValue;
    input2 = reshapeLower.getResult();
  } else {
    input1 = reshapeLower.getResult();
    input2 = higherTensorValue;
  }

  return success();
}

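// Materializes `shape` as a tosa.const_shape op producing a !tosa.shape value;
// dynamic dimensions are encoded as -1 (see convertFromMlirShape below).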
Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
                                    llvm::ArrayRef<int64_t> shape) {
  auto attr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
  auto type = mlir::tosa::shapeType::get(builder.getContext(), shape.size());
  mlir::Operation *mlir_op = builder.create<tosa::ConstShapeOp>(type, attr);
  return mlir_op->getResult(0);
}

Value mlir::tosa::getTosaConstShape(PatternRewriter &rewriter, Location loc,
                                    llvm::ArrayRef<int64_t> shape) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  return getTosaConstShape(builder, shape);
}

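// Converts an MLIR shape to the TOSA convention, mapping each dynamic
// dimension to -1 and leaving static sizes untouched.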
SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return ShapedType::isDynamic(dim) ? -1 : dim;
  }));
}

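// If `op` is a tosa.const_shape op, appends its constant values to
// `result_shape` and returns true; returns false for any other (or null) op.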
bool mlir::tosa::getConstShapeValues(Operation *op,
                                     llvm::SmallVector<int64_t> &result_shape) {
  if (!op) {
    return false;
  }
  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
    Attribute constOpAttr = constOp->getAttr("values");
    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
    for (int i = 0; i < elementsAttr.size(); i++) {
      int64_t val = elementsAttr.getValues<int64_t>()[i];
      result_shape.push_back(val);
    }
    return true;
  }
  // Any other defining op is not a constant shape; report failure.
  return false;
}

// Returns the int64_t values contained in `attr`. A splat attribute is
// expanded to `rank` copies of its value.
SmallVector<int64_t>
mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
  if (attr.isSplat()) {
    int64_t v = attr.getSplatValue<APInt>().getSExtValue();
    return SmallVector<int64_t>(rank, v);
  }

  if (auto int_array_attr = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
    SmallVector<int64_t> vec;
    for (APInt val : int_array_attr.getValues<APInt>()) {
      vec.push_back(val.getSExtValue());
    }
    return vec;
  }
  return {};
}