1 | //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // These rewriters lower from the Tosa to the Arith dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/TosaToArith/TosaToArith.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/IR/TypeUtilities.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace tosa; |
22 | |
23 | namespace { |
24 | |
25 | class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> { |
26 | public: |
27 | using OpRewritePattern<tosa::ConstOp>::OpRewritePattern; |
28 | |
29 | LogicalResult matchAndRewrite(tosa::ConstOp op, |
30 | PatternRewriter &rewriter) const final { |
31 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues()); |
32 | return success(); |
33 | } |
34 | }; |
35 | |
36 | Type matchContainerType(Type element, Type container) { |
37 | if (auto shapedTy = dyn_cast<ShapedType>(container)) |
38 | return shapedTy.clone(element); |
39 | |
40 | return element; |
41 | } |
42 | |
43 | TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { |
44 | if (auto shapedTy = dyn_cast<ShapedType>(type)) { |
45 | Type eTy = shapedTy.getElementType(); |
46 | APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true); |
47 | return DenseIntElementsAttr::get(shapedTy, valueInt); |
48 | } |
49 | |
50 | return rewriter.getIntegerAttr(type, value); |
51 | } |
52 | |
53 | Value getConstantValue(Location loc, Type type, int64_t value, |
54 | PatternRewriter &rewriter) { |
55 | return rewriter.create<arith::ConstantOp>( |
56 | loc, getConstantAttr(type, value, rewriter)); |
57 | } |
58 | |
59 | // This converts the TOSA ApplyScale operator to a set of arithmetic ops, |
60 | // using 64-bit operations to perform the necessary multiply, bias, and shift. |
61 | class ApplyScaleGenericOpConverter |
62 | : public OpRewritePattern<tosa::ApplyScaleOp> { |
63 | public: |
64 | using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern; |
65 | |
66 | LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, |
67 | PatternRewriter &rewriter) const final { |
68 | StringRef roundingMode = op.getRoundingMode(); |
69 | if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND" ) { |
70 | return failure(); |
71 | } |
72 | |
73 | Location loc = op.getLoc(); |
74 | Value value = op.getValue(); |
75 | Value multiplier32 = op.getMultiplier(); |
76 | |
77 | Type resultTy = op.getType(); |
78 | Type valueTy = value.getType(); |
79 | Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); |
80 | Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); |
81 | |
82 | Value zero = getConstantValue(loc, type: valueTy, value: 0, rewriter); |
83 | Value one64 = getConstantValue(loc, type: i64Ty, value: 1, rewriter); |
84 | Value thirtyOne32 = getConstantValue(loc, type: i32Ty, value: 31, rewriter); |
85 | |
86 | Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); |
87 | |
88 | // Compute the multiplication in 64-bits then select the high / low parts. |
89 | Value value64 = value; |
90 | if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) |
91 | value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value); |
92 | Value multiplier64 = |
93 | rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32); |
94 | Value multiply64 = |
95 | rewriter.create<arith::MulIOp>(loc, value64, multiplier64); |
96 | |
97 | // Apply normal rounding. |
98 | Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32); |
99 | Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64); |
100 | round = rewriter.create<arith::ShRUIOp>(loc, round, one64); |
101 | multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round); |
102 | |
103 | // Apply double rounding if necessary. |
104 | if (op.getRoundingMode() == "DOUBLE_ROUND" ) { |
105 | int64_t roundInt = 1 << 30; |
106 | Value roundUp = getConstantValue(loc, type: i64Ty, value: roundInt, rewriter); |
107 | Value roundDown = getConstantValue(loc, type: i64Ty, value: -roundInt, rewriter); |
108 | Value positive = rewriter.create<arith::CmpIOp>( |
109 | loc, arith::CmpIPredicate::sge, value, zero); |
110 | Value dir = |
111 | rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown); |
112 | Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64); |
113 | Value valid = rewriter.create<arith::CmpIOp>( |
114 | loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); |
115 | multiply64 = |
116 | rewriter.create<arith::SelectOp>(loc, valid, val, multiply64); |
117 | } |
118 | |
119 | Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64); |
120 | Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64); |
121 | |
122 | rewriter.replaceOp(op, result32); |
123 | return success(); |
124 | } |
125 | }; |
126 | |
127 | class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> { |
128 | public: |
129 | using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern; |
130 | |
131 | LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, |
132 | PatternRewriter &rewriter) const final { |
133 | StringRef roundingMode = op.getRoundingMode(); |
134 | if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND" ) { |
135 | return failure(); |
136 | } |
137 | |
138 | Location loc = op.getLoc(); |
139 | |
140 | Type resultTy = op.getType(); |
141 | Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); |
142 | |
143 | Value value = op.getValue(); |
144 | if (getElementTypeOrSelf(type: value.getType()).getIntOrFloatBitWidth() > 32) { |
145 | return failure(); |
146 | } |
147 | |
148 | Value value32 = op.getValue(); |
149 | Value multiplier32 = op.getMultiplier(); |
150 | Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); |
151 | |
152 | // Constants used during the scaling operation. |
153 | Value zero32 = getConstantValue(loc, type: i32Ty, value: 0, rewriter); |
154 | Value one32 = getConstantValue(loc, type: i32Ty, value: 1, rewriter); |
155 | Value two32 = getConstantValue(loc, type: i32Ty, value: 2, rewriter); |
156 | Value thirty32 = getConstantValue(loc, type: i32Ty, value: 30, rewriter); |
157 | Value thirtyTwo32 = getConstantValue(loc, type: i32Ty, value: 32, rewriter); |
158 | |
159 | // Compute the multiplication in 64-bits then select the high / low parts. |
160 | // Grab out the high/low of the computation |
161 | auto value64 = |
162 | rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32); |
163 | Value low32 = value64.getLow(); |
164 | Value high32 = value64.getHigh(); |
165 | |
166 | // Determine the direction and amount to shift the high bits. |
167 | Value shiftOver32 = rewriter.create<arith::CmpIOp>( |
168 | loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); |
169 | Value roundHighBits = rewriter.create<arith::CmpIOp>( |
170 | loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); |
171 | |
172 | Value shiftHighL = |
173 | rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32); |
174 | Value shiftHighR = |
175 | rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32); |
176 | |
177 | shiftHighL = |
178 | rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL); |
179 | shiftHighR = |
180 | rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32); |
181 | |
182 | // Conditionally perform our double round. |
183 | if (op.getRoundingMode() == "DOUBLE_ROUND" ) { |
184 | Value negOne32 = getConstantValue(loc, type: i32Ty, value: -1, rewriter); |
185 | Value valuePositive = rewriter.create<arith::CmpIOp>( |
186 | loc, arith::CmpIPredicate::sge, value32, zero32); |
187 | |
188 | Value roundDir = |
189 | rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32); |
190 | roundDir = |
191 | rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32); |
192 | |
193 | Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32); |
194 | Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir); |
195 | Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32); |
196 | |
197 | Value shiftRound = |
198 | rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32); |
199 | |
200 | low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound); |
201 | high32 = rewriter.create<arith::AddIOp>(loc, high32, carry); |
202 | } |
203 | |
204 | // Conditionally apply rounding in the low bits. |
205 | { |
206 | Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32); |
207 | Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); |
208 | roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32, |
209 | roundBit); |
210 | |
211 | Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit); |
212 | Value wasRounded = rewriter.create<arith::CmpIOp>( |
213 | loc, arith::CmpIPredicate::ugt, low32, newLow32); |
214 | low32 = newLow32; |
215 | |
216 | Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded); |
217 | high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32); |
218 | } |
219 | |
220 | // Conditionally apply rounding in the high bits. |
221 | { |
222 | Value shiftSubOne = |
223 | rewriter.create<arith::SubIOp>(loc, shiftHighR, one32); |
224 | Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); |
225 | roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit, |
226 | zero32); |
227 | high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit); |
228 | } |
229 | |
230 | // Combine the correct high/low bits into the final rescale result. |
231 | high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL); |
232 | high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR); |
233 | low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32); |
234 | low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32); |
235 | |
236 | // Apply the rounding behavior and shift to the final alignment. |
237 | Value result = rewriter.create<arith::AddIOp>(loc, low32, high32); |
238 | |
239 | // Truncate if necessary. |
240 | if (!getElementTypeOrSelf(type: resultTy).isInteger(width: 32)) { |
241 | result = rewriter.create<arith::TruncIOp>(loc, resultTy, result); |
242 | } |
243 | |
244 | rewriter.replaceOp(op, result); |
245 | return success(); |
246 | } |
247 | }; |
248 | |
249 | } // namespace |
250 | |
251 | void mlir::tosa::populateTosaToArithConversionPatterns( |
252 | RewritePatternSet *patterns) { |
253 | patterns->add<ConstOpConverter>(arg: patterns->getContext()); |
254 | } |
255 | |
256 | void mlir::tosa::populateTosaRescaleToArithConversionPatterns( |
257 | RewritePatternSet *patterns, bool include32Bit) { |
258 | patterns->add<ApplyScaleGenericOpConverter>(arg: patterns->getContext(), args: 100); |
259 | if (include32Bit) { |
260 | patterns->add<ApplyScale32BitOpConverter>(arg: patterns->getContext(), args: 200); |
261 | } |
262 | } |
263 | |