1//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Arith dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/TosaToArith/TosaToArith.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/IR/TypeUtilities.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20using namespace mlir;
21using namespace tosa;
22
23namespace {
24
25class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26public:
27 using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
28
29 LogicalResult matchAndRewrite(tosa::ConstOp op,
30 PatternRewriter &rewriter) const final {
31 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValues());
32 return success();
33 }
34};
35
36Type matchContainerType(Type element, Type container) {
37 if (auto shapedTy = dyn_cast<ShapedType>(container))
38 return shapedTy.clone(element);
39
40 return element;
41}
42
43TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
44 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
45 Type eTy = shapedTy.getElementType();
46 APInt valueInt(eTy.getIntOrFloatBitWidth(), value, /*isSigned=*/true);
47 return DenseIntElementsAttr::get(shapedTy, valueInt);
48 }
49
50 return rewriter.getIntegerAttr(type, value);
51}
52
53Value getConstantValue(Location loc, Type type, int64_t value,
54 PatternRewriter &rewriter) {
55 return rewriter.create<arith::ConstantOp>(
56 loc, getConstantAttr(type, value, rewriter));
57}
58
59// This converts the TOSA ApplyScale operator to a set of arithmetic ops,
60// using 64-bit operations to perform the necessary multiply, bias, and shift.
61class ApplyScaleGenericOpConverter
62 : public OpRewritePattern<tosa::ApplyScaleOp> {
63public:
64 using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
65
66 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
67 PatternRewriter &rewriter) const final {
68 StringRef roundingMode = op.getRoundingMode();
69 if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
70 return failure();
71 }
72
73 Location loc = op.getLoc();
74 Value value = op.getValue();
75 Value multiplier32 = op.getMultiplier();
76
77 Type resultTy = op.getType();
78 Type valueTy = value.getType();
79 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
80 Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
81
82 Value zero = getConstantValue(loc, type: valueTy, value: 0, rewriter);
83 Value one64 = getConstantValue(loc, type: i64Ty, value: 1, rewriter);
84 Value thirtyOne32 = getConstantValue(loc, type: i32Ty, value: 31, rewriter);
85
86 Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
87
88 // Compute the multiplication in 64-bits then select the high / low parts.
89 Value value64 = value;
90 if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
91 value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
92 Value multiplier64 =
93 rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
94 Value multiply64 =
95 rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
96
97 // Apply normal rounding.
98 Value shift64 = rewriter.create<arith::ExtUIOp>(loc, i64Ty, shift32);
99 Value round = rewriter.create<arith::ShLIOp>(loc, one64, shift64);
100 round = rewriter.create<arith::ShRUIOp>(loc, round, one64);
101 multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
102
103 // Apply double rounding if necessary.
104 if (op.getRoundingMode() == "DOUBLE_ROUND") {
105 int64_t roundInt = 1 << 30;
106 Value roundUp = getConstantValue(loc, type: i64Ty, value: roundInt, rewriter);
107 Value roundDown = getConstantValue(loc, type: i64Ty, value: -roundInt, rewriter);
108 Value positive = rewriter.create<arith::CmpIOp>(
109 loc, arith::CmpIPredicate::sge, value, zero);
110 Value dir =
111 rewriter.create<arith::SelectOp>(loc, positive, roundUp, roundDown);
112 Value val = rewriter.create<arith::AddIOp>(loc, dir, multiply64);
113 Value valid = rewriter.create<arith::CmpIOp>(
114 loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32);
115 multiply64 =
116 rewriter.create<arith::SelectOp>(loc, valid, val, multiply64);
117 }
118
119 Value result64 = rewriter.create<arith::ShRSIOp>(loc, multiply64, shift64);
120 Value result32 = rewriter.create<arith::TruncIOp>(loc, i32Ty, result64);
121
122 rewriter.replaceOp(op, result32);
123 return success();
124 }
125};
126
127class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
128public:
129 using OpRewritePattern<tosa::ApplyScaleOp>::OpRewritePattern;
130
131 LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
132 PatternRewriter &rewriter) const final {
133 StringRef roundingMode = op.getRoundingMode();
134 if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
135 return failure();
136 }
137
138 Location loc = op.getLoc();
139
140 Type resultTy = op.getType();
141 Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
142
143 Value value = op.getValue();
144 if (getElementTypeOrSelf(type: value.getType()).getIntOrFloatBitWidth() > 32) {
145 return failure();
146 }
147
148 Value value32 = op.getValue();
149 Value multiplier32 = op.getMultiplier();
150 Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
151
152 // Constants used during the scaling operation.
153 Value zero32 = getConstantValue(loc, type: i32Ty, value: 0, rewriter);
154 Value one32 = getConstantValue(loc, type: i32Ty, value: 1, rewriter);
155 Value two32 = getConstantValue(loc, type: i32Ty, value: 2, rewriter);
156 Value thirty32 = getConstantValue(loc, type: i32Ty, value: 30, rewriter);
157 Value thirtyTwo32 = getConstantValue(loc, type: i32Ty, value: 32, rewriter);
158
159 // Compute the multiplication in 64-bits then select the high / low parts.
160 // Grab out the high/low of the computation
161 auto value64 =
162 rewriter.create<arith::MulSIExtendedOp>(loc, value32, multiplier32);
163 Value low32 = value64.getLow();
164 Value high32 = value64.getHigh();
165
166 // Determine the direction and amount to shift the high bits.
167 Value shiftOver32 = rewriter.create<arith::CmpIOp>(
168 loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
169 Value roundHighBits = rewriter.create<arith::CmpIOp>(
170 loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32);
171
172 Value shiftHighL =
173 rewriter.create<arith::SubIOp>(loc, thirtyTwo32, shift32);
174 Value shiftHighR =
175 rewriter.create<arith::SubIOp>(loc, shift32, thirtyTwo32);
176
177 shiftHighL =
178 rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, shiftHighL);
179 shiftHighR =
180 rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
181
182 // Conditionally perform our double round.
183 if (op.getRoundingMode() == "DOUBLE_ROUND") {
184 Value negOne32 = getConstantValue(loc, type: i32Ty, value: -1, rewriter);
185 Value valuePositive = rewriter.create<arith::CmpIOp>(
186 loc, arith::CmpIPredicate::sge, value32, zero32);
187
188 Value roundDir =
189 rewriter.create<arith::SelectOp>(loc, valuePositive, one32, negOne32);
190 roundDir =
191 rewriter.create<arith::SelectOp>(loc, shiftOver32, roundDir, zero32);
192
193 Value shiftLow = rewriter.create<arith::ShRUIOp>(loc, low32, thirty32);
194 Value rounded = rewriter.create<arith::AddIOp>(loc, shiftLow, roundDir);
195 Value carry = rewriter.create<arith::ShRSIOp>(loc, rounded, two32);
196
197 Value shiftRound =
198 rewriter.create<arith::ShLIOp>(loc, roundDir, thirty32);
199
200 low32 = rewriter.create<arith::AddIOp>(loc, low32, shiftRound);
201 high32 = rewriter.create<arith::AddIOp>(loc, high32, carry);
202 }
203
204 // Conditionally apply rounding in the low bits.
205 {
206 Value shiftSubOne = rewriter.create<arith::SubIOp>(loc, shift32, one32);
207 Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
208 roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, zero32,
209 roundBit);
210
211 Value newLow32 = rewriter.create<arith::AddIOp>(loc, low32, roundBit);
212 Value wasRounded = rewriter.create<arith::CmpIOp>(
213 loc, arith::CmpIPredicate::ugt, low32, newLow32);
214 low32 = newLow32;
215
216 Value rounded32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, wasRounded);
217 high32 = rewriter.create<arith::AddIOp>(loc, high32, rounded32);
218 }
219
220 // Conditionally apply rounding in the high bits.
221 {
222 Value shiftSubOne =
223 rewriter.create<arith::SubIOp>(loc, shiftHighR, one32);
224 Value roundBit = rewriter.create<arith::ShLIOp>(loc, one32, shiftSubOne);
225 roundBit = rewriter.create<arith::SelectOp>(loc, roundHighBits, roundBit,
226 zero32);
227 high32 = rewriter.create<arith::AddIOp>(loc, high32, roundBit);
228 }
229
230 // Combine the correct high/low bits into the final rescale result.
231 high32 = rewriter.create<arith::ShLIOp>(loc, high32, shiftHighL);
232 high32 = rewriter.create<arith::ShRSIOp>(loc, high32, shiftHighR);
233 low32 = rewriter.create<arith::ShRUIOp>(loc, low32, shift32);
234 low32 = rewriter.create<arith::SelectOp>(loc, shiftOver32, zero32, low32);
235
236 // Apply the rounding behavior and shift to the final alignment.
237 Value result = rewriter.create<arith::AddIOp>(loc, low32, high32);
238
239 // Truncate if necessary.
240 if (!getElementTypeOrSelf(type: resultTy).isInteger(width: 32)) {
241 result = rewriter.create<arith::TruncIOp>(loc, resultTy, result);
242 }
243
244 rewriter.replaceOp(op, result);
245 return success();
246 }
247};
248
249} // namespace
250
251void mlir::tosa::populateTosaToArithConversionPatterns(
252 RewritePatternSet *patterns) {
253 patterns->add<ConstOpConverter>(arg: patterns->getContext());
254}
255
256void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
257 RewritePatternSet *patterns, bool include32Bit) {
258 patterns->add<ApplyScaleGenericOpConverter>(arg: patterns->getContext(), args: 100);
259 if (include32Bit) {
260 patterns->add<ApplyScale32BitOpConverter>(arg: patterns->getContext(), args: 200);
261 }
262}
263

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/TosaToArith/TosaToArith.cpp