1//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/Transforms/Passes.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Vector/IR/VectorOps.h"
13#include "mlir/IR/ImplicitLocOpBuilder.h"
14#include "mlir/IR/TypeUtilities.h"
15#include "mlir/Transforms/DialectConversion.h"
16
17namespace mlir {
18namespace arith {
19#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21} // namespace arith
22} // namespace mlir
23
24using namespace mlir;
25
26/// Create an integer or index constant.
27static Value createConst(Location loc, Type type, int value,
28 PatternRewriter &rewriter) {
29 auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return rewriter.create<arith::ConstantOp>(
32 loc, DenseElementsAttr::get(shapedTy, attr));
33 }
34
35 return rewriter.create<arith::ConstantOp>(loc, attr);
36}
37
38namespace {
39
40/// Expands CeilDivUIOp (n, m) into
41/// n == 0 ? 0 : ((n-1) / m) + 1
42struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
43 using OpRewritePattern::OpRewritePattern;
44 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
45 PatternRewriter &rewriter) const final {
46 Location loc = op.getLoc();
47 Value a = op.getLhs();
48 Value b = op.getRhs();
49 Value zero = createConst(loc, type: a.getType(), value: 0, rewriter);
50 Value compare =
51 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
52 Value one = createConst(loc, type: a.getType(), value: 1, rewriter);
53 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
57 return success();
58 }
59};
60
61/// Expands CeilDivSIOp (n, m) into
62/// 1) x = (m > 0) ? -1 : 1
63/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
64struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
65 using OpRewritePattern::OpRewritePattern;
66 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
67 PatternRewriter &rewriter) const final {
68 Location loc = op.getLoc();
69 Type type = op.getType();
70 Value a = op.getLhs();
71 Value b = op.getRhs();
72 Value plusOne = createConst(loc, type, value: 1, rewriter);
73 Value zero = createConst(loc, type, value: 0, rewriter);
74 Value minusOne = createConst(loc, type, value: -1, rewriter);
75 // Compute x = (b>0) ? -1 : 1.
76 Value compare =
77 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78 Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
79 // Compute positive res: 1 + ((x+a)/b).
80 Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81 Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82 Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83 // Compute negative res: - ((-a)/b).
84 Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85 Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86 Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
87 // Result is (a*b>0) ? pos result : neg result.
88 // Note, we want to avoid using a*b because of possible overflow.
89 // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
90 // not particuliarly care if a*b<0 is true or false when b is zero
91 // as this will result in an illegal divide. So `a*b<0` can be reformulated
92 // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
93 // We pick the first expression here.
94 Value aNeg =
95 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
96 Value aPos =
97 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
98 Value bNeg =
99 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
100 Value bPos =
101 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102 Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103 Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
104 Value compareRes =
105 rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
106 // Perform substitution and return success.
107 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
108 negRes);
109 return success();
110 }
111};
112
113/// Expands FloorDivSIOp (x, y) into
114/// z = x / y
115/// if (z * y != x && (x < 0) != (y < 0)) {
116/// return z - 1;
117/// } else {
118/// return z;
119/// }
120struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
121 using OpRewritePattern::OpRewritePattern;
122 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
123 PatternRewriter &rewriter) const final {
124 Location loc = op.getLoc();
125 Type type = op.getType();
126 Value a = op.getLhs();
127 Value b = op.getRhs();
128
129 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132 loc, arith::CmpIPredicate::ne, a, product);
133 Value zero = createConst(loc, type, value: 0, rewriter);
134
135 Value aNeg =
136 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
137 Value bNeg =
138 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
139
140 Value signOpposite = rewriter.create<arith::CmpIOp>(
141 loc, arith::CmpIPredicate::ne, aNeg, bNeg);
142 Value cond =
143 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
144
145 Value minusOne = createConst(loc, type, value: -1, rewriter);
146 Value quotientMinusOne =
147 rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
148
149 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
150 quotient);
151 return success();
152 }
153};
154
155template <typename OpTy, arith::CmpFPredicate pred>
156struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
157public:
158 using OpRewritePattern<OpTy>::OpRewritePattern;
159
160 LogicalResult matchAndRewrite(OpTy op,
161 PatternRewriter &rewriter) const final {
162 Value lhs = op.getLhs();
163 Value rhs = op.getRhs();
164
165 Location loc = op.getLoc();
166 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
167 static_assert(pred == arith::CmpFPredicate::UGT ||
168 pred == arith::CmpFPredicate::ULT,
169 "pred must be either UGT or ULT");
170 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
171 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
172
173 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
174 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
175 rhs, rhs);
176 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
177 return success();
178 }
179};
180
181template <typename OpTy, arith::CmpFPredicate pred>
182struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
183public:
184 using OpRewritePattern<OpTy>::OpRewritePattern;
185
186 LogicalResult matchAndRewrite(OpTy op,
187 PatternRewriter &rewriter) const final {
188 Value lhs = op.getLhs();
189 Value rhs = op.getRhs();
190
191 Location loc = op.getLoc();
192 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
193 static_assert(pred == arith::CmpFPredicate::UGT ||
194 pred == arith::CmpFPredicate::ULT,
195 "pred must be either UGT or ULT");
196 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
197 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
198
199 // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
200 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
201 lhs, lhs);
202 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
203 return success();
204 }
205};
206
207struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
208 using OpRewritePattern::OpRewritePattern;
209 LogicalResult matchAndRewrite(arith::ExtFOp op,
210 PatternRewriter &rewriter) const final {
211 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
212 auto operand = op.getOperand();
213 Type operandTy = operand.getType();
214 Type resultTy = op.getType();
215 Type operandETy = getElementTypeOrSelf(type: operandTy);
216 Type resultETy = getElementTypeOrSelf(type: resultTy);
217
218 if (!operandETy.isBF16() || !resultETy.isF32()) {
219 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
220 }
221
222 Type i16Ty = b.getI16Type();
223 Type i32Ty = b.getI32Type();
224 if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
225 i16Ty = shapedTy.clone(i16Ty);
226 i32Ty = shapedTy.clone(i32Ty);
227 }
228
229 Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
230 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
231
232 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
233 Value shl = b.create<arith::ShLIOp>(exti, c16);
234 Value result = b.create<arith::BitcastOp>(resultTy, shl);
235
236 rewriter.replaceOp(op, result);
237 return success();
238 }
239};
240
241struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
242 using OpRewritePattern::OpRewritePattern;
243 LogicalResult matchAndRewrite(arith::TruncFOp op,
244 PatternRewriter &rewriter) const final {
245 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
246 auto operand = op.getOperand();
247 Type operandTy = operand.getType();
248 Type resultTy = op.getType();
249 Type operandETy = getElementTypeOrSelf(type: operandTy);
250 Type resultETy = getElementTypeOrSelf(type: resultTy);
251
252 if (!operandETy.isF32() || !resultETy.isBF16()) {
253 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
254 }
255
256 if (op.getRoundingmodeAttr()) {
257 return rewriter.notifyMatchFailure(
258 op, "only applicable to default rounding mode.");
259 }
260
261 Type i16Ty = b.getI16Type();
262 Type i32Ty = b.getI32Type();
263 Type f32Ty = b.getF32Type();
264 if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
265 i16Ty = shapedTy.clone(i16Ty);
266 i32Ty = shapedTy.clone(i32Ty);
267 f32Ty = shapedTy.clone(f32Ty);
268 }
269
270 // Algorithm borrowed from this excellent code:
271 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
272 // There is a magic idea there, to let the addition of the rounding_bias to
273 // the mantissa simply overflow into the exponent bits. It's a bit of an
274 // aggressive, obfuscating optimization, but it is well-tested code, and it
275 // results in more concise and efficient IR.
276 // The case of NaN is handled separately (see isNaN and the final select).
277 // The case of infinities is NOT handled separately, which deserves an
278 // explanation. As the encoding of infinities has zero mantissa, the
279 // rounding-bias addition never carries into the exponent so that just gets
280 // truncated away, and as bfloat16 and float32 have the same number of
281 // exponent bits, that simple truncation is the desired outcome for
282 // infinities.
283 Value isNan =
284 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
285 // Constant used to make the rounding bias.
286 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
287 // Constant used to generate a quiet NaN.
288 Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
289 // Small constants used to address bits.
290 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
291 Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
292 // Reinterpret the input f32 value as bits.
293 Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
294 // Read bit 16 as a value in {0,1}.
295 Value bit16 =
296 b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
297 // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
298 // on bit 16, implementing the tie-breaking "to nearest even".
299 Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
300 // Add the rounding bias. Generally we want this to be added to the
301 // mantissa, but nothing prevents this to from carrying into the exponent
302 // bits, which would feel like a bug, but this is the magic trick here:
303 // when that happens, the mantissa gets reset to zero and the exponent
304 // gets incremented by the carry... which is actually exactly what we
305 // want.
306 Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
307 // Now that the rounding-bias has been added, truncating the low bits
308 // yields the correctly rounded result.
309 Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
310 Value normalCaseResult_i16 =
311 b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
312 // Select either the above-computed result, or a quiet NaN constant
313 // if the input was NaN.
314 Value select =
315 b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
316 Value result = b.create<arith::BitcastOp>(resultTy, select);
317 rewriter.replaceOp(op, result);
318 return success();
319 }
320};
321
322struct ArithExpandOpsPass
323 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
324 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
325
326 void runOnOperation() override {
327 RewritePatternSet patterns(&getContext());
328 ConversionTarget target(getContext());
329
330 arith::populateArithExpandOpsPatterns(patterns);
331
332 target.addLegalDialect<arith::ArithDialect>();
333 // clang-format off
334 target.addIllegalOp<
335 arith::CeilDivSIOp,
336 arith::CeilDivUIOp,
337 arith::FloorDivSIOp,
338 arith::MaximumFOp,
339 arith::MinimumFOp,
340 arith::MaxNumFOp,
341 arith::MinNumFOp
342 >();
343
344 if (includeBf16) {
345 arith::populateExpandBFloat16Patterns(patterns);
346 target.addDynamicallyLegalOp<arith::ExtFOp>(
347 [](arith::ExtFOp op) {
348 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
349 Type outETy = getElementTypeOrSelf(op.getType());
350 return !(inETy.isBF16() && outETy.isF32());
351 });
352
353 target.addDynamicallyLegalOp<arith::TruncFOp>(
354 [](arith::TruncFOp op) {
355 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
356 Type outETy = getElementTypeOrSelf(op.getType());
357 return !(inETy.isF32() && outETy.isBF16());
358 });
359 }
360
361 // clang-format on
362 if (failed(applyPartialConversion(getOperation(), target,
363 std::move(patterns))))
364 signalPassFailure();
365 }
366};
367
368} // namespace
369
370void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
371 RewritePatternSet &patterns) {
372 patterns
373 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
374 arg: patterns.getContext());
375}
376
377void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
378 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
379 arg: patterns.getContext());
380}
381
382void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
383 populateCeilFloorDivExpandOpsPatterns(patterns);
384 // clang-format off
385 patterns.add<
386 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
387 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
388 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
389 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
390 >(patterns.getContext());
391 // clang-format on
392}
393

source code of mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp