1 | //===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // These rewriters lower from the Tosa to the Arith dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/TosaToArith/TosaToArith.h" |
14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | #include "mlir/Dialect/Tosa/IR/TosaOps.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/IR/TypeUtilities.h" |
18 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace tosa; |
22 | |
23 | namespace { |
24 | |
25 | class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> { |
26 | public: |
27 | using OpRewritePattern<tosa::ConstOp>::OpRewritePattern; |
28 | |
29 | LogicalResult matchAndRewrite(tosa::ConstOp op, |
30 | PatternRewriter &rewriter) const final { |
31 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue()); |
32 | return success(); |
33 | } |
34 | }; |
35 | |
36 | Type matchContainerType(Type element, Type container) { |
37 | if (auto shapedTy = dyn_cast<ShapedType>(container)) |
38 | return shapedTy.clone(element); |
39 | |
40 | return element; |
41 | } |
42 | |
43 | TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { |
44 | if (auto shapedTy = dyn_cast<ShapedType>(type)) { |
45 | Type eTy = shapedTy.getElementType(); |
46 | APInt valueInt(eTy.getIntOrFloatBitWidth(), value); |
47 | return DenseIntElementsAttr::get(shapedTy, valueInt); |
48 | } |
49 | |
50 | return rewriter.getIntegerAttr(type, value); |
51 | } |
52 | |
53 | Value getConstantValue(Location loc, Type type, int64_t value, |
54 | PatternRewriter &rewriter) { |
55 | return rewriter.create<arith::ConstantOp>( |
56 | loc, getConstantAttr(type, value, rewriter)); |
57 | } |
58 | |
59 | // This converts the TOSA ApplyScale operator to a set of arithmetic ops, |
60 | // using 64-bit operations to perform the necessary multiply, bias, and shift. |
61 | class ApplyScaleGenericOpConverter |
62 | : public OpRewritePattern<tosa::ApplyScaleOp> { |
63 | public: |
64 | using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern; |
65 | |
66 | LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, |
67 | PatternRewriter &rewriter) const final { |
68 | Location loc = op.getLoc(); |
69 | Value value = op.getValue(); |
70 | Value multiplier32 = op.getMultiplier(); |
71 | |
72 | Type resultTy = op.getType(); |
73 | Type valueTy = value.getType(); |
74 | Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); |
75 | Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); |
76 | |
77 | Value zero = getConstantValue(loc, type: valueTy, value: 0, rewriter); |
78 | Value one64 = getConstantValue(loc, type: i64Ty, value: 1, rewriter); |
79 | Value thirtyOne32 = getConstantValue(loc, type: i32Ty, value: 31, rewriter); |
80 | |
81 | Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); |
82 | |
83 | // Compute the multiplication in 64-bits then select the high / low parts. |
84 | Value value64 = value; |
85 | if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) |
86 | value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value); |
87 | Value multiplier64 = |
88 | rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32); |
89 | Value multiply64 = |
90 | rewriter.create<arith::MulIOp>(loc, value64, multiplier64); |
91 | |
92 | // Apply normal rounding. |
93 | Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32); |
94 | Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64); |
95 | round = rewriter.create<arith::ShRUIOp>(loc, round, one64); |
96 | multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round); |
97 | |
98 | // Apply double rounding if necessary. |
99 | if (op.getDoubleRound()) { |
100 | int64_t roundInt = 1 << 30; |
101 | Value roundUp = getConstantValue(loc, type: i64Ty, value: roundInt, rewriter); |
102 | Value roundDown = getConstantValue(loc, type: i64Ty, value: -roundInt, rewriter); |
103 | Value positive = rewriter.create<arith::CmpIOp>( |
104 | loc, arith::CmpIPredicate::sge, value, zero); |
105 | Value dir = |
106 | rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown); |
107 | Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64); |
108 | Value valid = rewriter.create<arith::CmpIOp>( |
109 | loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); |
110 | multiply64 = |
111 | rewriter.create<arith::SelectOp>(loc, valid, val, multiply64); |
112 | } |
113 | |
114 | Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64); |
115 | Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64); |
116 | |
117 | rewriter.replaceOp(op, result32); |
118 | return success(); |
119 | } |
120 | }; |
121 | |
122 | class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> { |
123 | public: |
124 | using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern; |
125 | |
126 | LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, |
127 | PatternRewriter &rewriter) const final { |
128 | Location loc = op.getLoc(); |
129 | |
130 | Type resultTy = op.getType(); |
131 | Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); |
132 | |
133 | Value value = op.getValue(); |
134 | if (getElementTypeOrSelf(type: value.getType()).getIntOrFloatBitWidth() > 32) { |
135 | return failure(); |
136 | } |
137 | |
138 | Value value32 = op.getValue(); |
139 | Value multiplier32 = op.getMultiplier(); |
140 | Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift()); |
141 | |
142 | // Constants used during the scaling operation. |
143 | Value zero32 = getConstantValue(loc, type: i32Ty, value: 0, rewriter); |
144 | Value one32 = getConstantValue(loc, type: i32Ty, value: 1, rewriter); |
145 | Value two32 = getConstantValue(loc, type: i32Ty, value: 2, rewriter); |
146 | Value thirty32 = getConstantValue(loc, type: i32Ty, value: 30, rewriter); |
147 | Value thirtyTwo32 = getConstantValue(loc, type: i32Ty, value: 32, rewriter); |
148 | |
149 | // Compute the multiplication in 64-bits then select the high / low parts. |
150 | // Grab out the high/low of the computation |
151 | auto value64 = |
152 | rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32); |
153 | Value low32 = value64.getLow(); |
154 | Value high32 = value64.getHigh(); |
155 | |
156 | // Determine the direction and amount to shift the high bits. |
157 | Value shiftOver32 = rewriter.create<arith::CmpIOp>( |
158 | loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); |
159 | Value roundHighBits = rewriter.create<arith::CmpIOp>( |
160 | loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); |
161 | |
162 | Value shiftHighL = |
163 | rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32); |
164 | Value shiftHighR = |
165 | rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32); |
166 | |
167 | shiftHighL = |
168 | rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL); |
169 | shiftHighR = |
170 | rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32); |
171 | |
172 | // Conditionally perform our double round. |
173 | if (op.getDoubleRound()) { |
174 | Value negOne32 = getConstantValue(loc, type: i32Ty, value: -1, rewriter); |
175 | Value valuePositive = rewriter.create<arith::CmpIOp>( |
176 | loc, arith::CmpIPredicate::sge, value32, zero32); |
177 | |
178 | Value roundDir = |
179 | rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32); |
180 | roundDir = |
181 | rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32); |
182 | |
183 | Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32); |
184 | Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir); |
185 | Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32); |
186 | |
187 | Value shiftRound = |
188 | rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32); |
189 | |
190 | low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound); |
191 | high32 = rewriter.create<arith::AddIOp>(loc, high32, carry); |
192 | } |
193 | |
194 | // Conditionally apply rounding in the low bits. |
195 | { |
196 | Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32); |
197 | Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); |
198 | roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32, |
199 | roundBit); |
200 | |
201 | Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit); |
202 | Value wasRounded = rewriter.create<arith::CmpIOp>( |
203 | loc, arith::CmpIPredicate::ugt, low32, newLow32); |
204 | low32 = newLow32; |
205 | |
206 | Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded); |
207 | high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32); |
208 | } |
209 | |
210 | // Conditionally apply rounding in the high bits. |
211 | { |
212 | Value shiftSubOne = |
213 | rewriter.create<arith::SubIOp>(loc, shiftHighR, one32); |
214 | Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne); |
215 | roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit, |
216 | zero32); |
217 | high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit); |
218 | } |
219 | |
220 | // Combine the correct high/low bits into the final rescale result. |
221 | high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL); |
222 | high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR); |
223 | low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32); |
224 | low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32); |
225 | |
226 | // Apply the rounding behavior and shift to the final alignment. |
227 | Value result = rewriter.create<arith::AddIOp>(loc, low32, high32); |
228 | |
229 | // Truncate if necessary. |
230 | if (!getElementTypeOrSelf(type: resultTy).isInteger(width: 32)) { |
231 | result = rewriter.create<arith::TruncIOp>(loc, resultTy, result); |
232 | } |
233 | |
234 | rewriter.replaceOp(op, result); |
235 | return success(); |
236 | } |
237 | }; |
238 | |
239 | } // namespace |
240 | |
241 | void mlir::tosa::populateTosaToArithConversionPatterns( |
242 | RewritePatternSet *patterns) { |
243 | patterns->add<ConstOpConverter>(arg: patterns->getContext()); |
244 | } |
245 | |
246 | void mlir::tosa::populateTosaRescaleToArithConversionPatterns( |
247 | RewritePatternSet *patterns, bool include32Bit) { |
248 | patterns->add<ApplyScaleGenericOpConverter>(arg: patterns->getContext(), args: 100); |
249 | if (include32Bit) { |
250 | patterns->add<ApplyScale32BitOpConverter>(arg: patterns->getContext(), args: 200); |
251 | } |
252 | } |
253 | |