1//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/Arith.h"
10#include "mlir/Dialect/Arith/Transforms/Passes.h"
11#include "mlir/Dialect/Vector/IR/VectorOps.h"
12#include "mlir/IR/BuiltinTypeInterfaces.h"
13#include "mlir/IR/ImplicitLocOpBuilder.h"
14#include "mlir/IR/Location.h"
15#include "mlir/IR/TypeUtilities.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace mlir {
19namespace arith {
20#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
21#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
22} // namespace arith
23} // namespace mlir
24
25using namespace mlir;
26
27/// Create an integer or index constant.
28static Value createConst(Location loc, Type type, int value,
29 PatternRewriter &rewriter) {
30 auto attr = rewriter.getIntegerAttr(type: getElementTypeOrSelf(type), value);
31 if (auto shapedTy = dyn_cast<ShapedType>(Val&: type)) {
32 return rewriter.create<arith::ConstantOp>(
33 location: loc, args: DenseElementsAttr::get(type: shapedTy, values: attr));
34 }
35 return rewriter.create<arith::ConstantOp>(location: loc, args&: attr);
36}
37
38/// Create a float constant.
39static Value createFloatConst(Location loc, Type type, APFloat value,
40 PatternRewriter &rewriter) {
41 auto attr = rewriter.getFloatAttr(type: getElementTypeOrSelf(type), value);
42 if (auto shapedTy = dyn_cast<ShapedType>(Val&: type)) {
43 return rewriter.create<arith::ConstantOp>(
44 location: loc, args: DenseElementsAttr::get(type: shapedTy, values: attr));
45 }
46
47 return rewriter.create<arith::ConstantOp>(location: loc, args&: attr);
48}
49
50/// Creates shapedType using shape from cloneFrom and base type from cloneTo
51static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
52 if (auto shapedTy = dyn_cast<ShapedType>(Val&: cloneFrom)) {
53 return shapedTy.clone(elementType: cloneTo);
54 }
55 return cloneTo;
56}
57
58namespace {
59
60/// Expands CeilDivUIOp (n, m) into
61/// n == 0 ? 0 : ((n-1) / m) + 1
62struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
63 using OpRewritePattern::OpRewritePattern;
64 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
65 PatternRewriter &rewriter) const final {
66 Location loc = op.getLoc();
67 Value a = op.getLhs();
68 Value b = op.getRhs();
69 Value zero = createConst(loc, type: a.getType(), value: 0, rewriter);
70 Value compare =
71 rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::eq, args&: a, args&: zero);
72 Value one = createConst(loc, type: a.getType(), value: 1, rewriter);
73 Value minusOne = rewriter.create<arith::SubIOp>(location: loc, args&: a, args&: one);
74 Value quotient = rewriter.create<arith::DivUIOp>(location: loc, args&: minusOne, args&: b);
75 Value plusOne = rewriter.create<arith::AddIOp>(location: loc, args&: quotient, args&: one);
76 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, args&: compare, args&: zero, args&: plusOne);
77 return success();
78 }
79};
80
81/// Expands CeilDivSIOp (a, b) into
82/// z = a / b
83/// if (z * b != a && (a < 0) == (b < 0)) {
84/// return z + 1;
85/// } else {
86/// return z;
87/// }
88struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
89 using OpRewritePattern::OpRewritePattern;
90 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
91 PatternRewriter &rewriter) const final {
92 Location loc = op.getLoc();
93 Type type = op.getType();
94 Value a = op.getLhs();
95 Value b = op.getRhs();
96
97 Value zero = createConst(loc, type, value: 0, rewriter);
98 Value one = createConst(loc, type, value: 1, rewriter);
99
100 Value quotient = rewriter.create<arith::DivSIOp>(location: loc, args&: a, args&: b);
101 Value product = rewriter.create<arith::MulIOp>(location: loc, args&: quotient, args&: b);
102 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
103 location: loc, args: arith::CmpIPredicate::ne, args&: a, args&: product);
104
105 Value aNeg =
106 rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt, args&: a, args&: zero);
107 Value bNeg =
108 rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt, args&: b, args&: zero);
109
110 Value signEqual = rewriter.create<arith::CmpIOp>(
111 location: loc, args: arith::CmpIPredicate::eq, args&: aNeg, args&: bNeg);
112 Value cond =
113 rewriter.create<arith::AndIOp>(location: loc, args&: notEqualDivisor, args&: signEqual);
114
115 Value quotientPlusOne = rewriter.create<arith::AddIOp>(location: loc, args&: quotient, args&: one);
116
117 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, args&: cond, args&: quotientPlusOne,
118 args&: quotient);
119 return success();
120 }
121};
122
123/// Expands FloorDivSIOp (x, y) into
124/// z = x / y
125/// if (z * y != x && (x < 0) != (y < 0)) {
126/// return z - 1;
127/// } else {
128/// return z;
129/// }
130struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
131 using OpRewritePattern::OpRewritePattern;
132 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
133 PatternRewriter &rewriter) const final {
134 Location loc = op.getLoc();
135 Type type = op.getType();
136 Value a = op.getLhs();
137 Value b = op.getRhs();
138
139 Value quotient = rewriter.create<arith::DivSIOp>(location: loc, args&: a, args&: b);
140 Value product = rewriter.create<arith::MulIOp>(location: loc, args&: quotient, args&: b);
141 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
142 location: loc, args: arith::CmpIPredicate::ne, args&: a, args&: product);
143 Value zero = createConst(loc, type, value: 0, rewriter);
144
145 Value aNeg =
146 rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt, args&: a, args&: zero);
147 Value bNeg =
148 rewriter.create<arith::CmpIOp>(location: loc, args: arith::CmpIPredicate::slt, args&: b, args&: zero);
149
150 Value signOpposite = rewriter.create<arith::CmpIOp>(
151 location: loc, args: arith::CmpIPredicate::ne, args&: aNeg, args&: bNeg);
152 Value cond =
153 rewriter.create<arith::AndIOp>(location: loc, args&: notEqualDivisor, args&: signOpposite);
154
155 Value minusOne = createConst(loc, type, value: -1, rewriter);
156 Value quotientMinusOne =
157 rewriter.create<arith::AddIOp>(location: loc, args&: quotient, args&: minusOne);
158
159 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, args&: cond, args&: quotientMinusOne,
160 args&: quotient);
161 return success();
162 }
163};
164
165template <typename OpTy, arith::CmpIPredicate pred>
166struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
167public:
168 using OpRewritePattern<OpTy>::OpRewritePattern;
169
170 LogicalResult matchAndRewrite(OpTy op,
171 PatternRewriter &rewriter) const final {
172 Value lhs = op.getLhs();
173 Value rhs = op.getRhs();
174
175 Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
176 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
177 return success();
178 }
179};
180
181template <typename OpTy, arith::CmpFPredicate pred>
182struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
183public:
184 using OpRewritePattern<OpTy>::OpRewritePattern;
185
186 LogicalResult matchAndRewrite(OpTy op,
187 PatternRewriter &rewriter) const final {
188 Value lhs = op.getLhs();
189 Value rhs = op.getRhs();
190
191 Location loc = op.getLoc();
192 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
193 static_assert(pred == arith::CmpFPredicate::UGT ||
194 pred == arith::CmpFPredicate::ULT,
195 "pred must be either UGT or ULT");
196 Value cmp = rewriter.create<arith::CmpFOp>(location: loc, args: pred, args&: lhs, args&: rhs);
197 Value select = rewriter.create<arith::SelectOp>(location: loc, args&: cmp, args&: lhs, args&: rhs);
198
199 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
200 Value isNaN = rewriter.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::UNO,
201 args&: rhs, args&: rhs);
202 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
203 return success();
204 }
205};
206
207template <typename OpTy, arith::CmpFPredicate pred>
208struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
209public:
210 using OpRewritePattern<OpTy>::OpRewritePattern;
211
212 LogicalResult matchAndRewrite(OpTy op,
213 PatternRewriter &rewriter) const final {
214 Value lhs = op.getLhs();
215 Value rhs = op.getRhs();
216
217 Location loc = op.getLoc();
218 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
219 static_assert(pred == arith::CmpFPredicate::UGT ||
220 pred == arith::CmpFPredicate::ULT,
221 "pred must be either UGT or ULT");
222 Value cmp = rewriter.create<arith::CmpFOp>(location: loc, args: pred, args&: lhs, args&: rhs);
223 Value select = rewriter.create<arith::SelectOp>(location: loc, args&: cmp, args&: lhs, args&: rhs);
224
225 // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
226 Value isNaN = rewriter.create<arith::CmpFOp>(location: loc, args: arith::CmpFPredicate::UNO,
227 args&: lhs, args&: lhs);
228 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
229 return success();
230 }
231};
232
233struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
234 using OpRewritePattern::OpRewritePattern;
235 LogicalResult matchAndRewrite(arith::ExtFOp op,
236 PatternRewriter &rewriter) const final {
237 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
238 auto operand = op.getOperand();
239 Type operandTy = operand.getType();
240 Type resultTy = op.getType();
241 Type operandETy = getElementTypeOrSelf(type: operandTy);
242 Type resultETy = getElementTypeOrSelf(type: resultTy);
243
244 if (!operandETy.isBF16() || !resultETy.isF32()) {
245 return rewriter.notifyMatchFailure(arg&: op, msg: "not a ext of bf16 to f32.");
246 }
247
248 Type i16Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI16Type());
249 Type i32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI32Type());
250
251 Value bitcast = b.create<arith::BitcastOp>(args&: i16Ty, args&: operand);
252 Value exti = b.create<arith::ExtUIOp>(args&: i32Ty, args&: bitcast);
253
254 Value c16 = createConst(loc: op.getLoc(), type: i32Ty, value: 16, rewriter);
255 Value shl = b.create<arith::ShLIOp>(args&: exti, args&: c16);
256 Value result = b.create<arith::BitcastOp>(args&: resultTy, args&: shl);
257
258 rewriter.replaceOp(op, newValues: result);
259 return success();
260 }
261};
262
/// Expands arith.truncf f32 -> bf16 (default rounding mode only) by adding a
/// rounding bias to the raw f32 bits and truncating, which implements
/// round-to-nearest-even; NaN inputs are mapped to a quiet-NaN constant.
struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(type: operandTy);
    Type resultETy = getElementTypeOrSelf(type: resultTy);

    if (!operandETy.isF32() || !resultETy.isBF16()) {
      return rewriter.notifyMatchFailure(arg&: op, msg: "not a trunc of f32 to bf16.");
    }

    // The bias trick below bakes in round-to-nearest-even, so bail out on any
    // explicit rounding-mode request.
    if (op.getRoundingmodeAttr()) {
      return rewriter.notifyMatchFailure(
          arg&: op, msg: "only applicable to default rounding mode.");
    }

    Type i16Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI16Type());
    Type i32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI32Type());

    // Algorithm borrowed from this excellent code:
    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
    // There is a magic idea there, to let the addition of the rounding_bias to
    // the mantissa simply overflow into the exponent bits. It's a bit of an
    // aggressive, obfuscating optimization, but it is well-tested code, and it
    // results in more concise and efficient IR.
    // The case of NaN is handled separately (see isNaN and the final select).
    // The case of infinities is NOT handled separately, which deserves an
    // explanation. As the encoding of infinities has zero mantissa, the
    // rounding-bias addition never carries into the exponent so that just gets
    // truncated away, and as bfloat16 and float32 have the same number of
    // exponent bits, that simple truncation is the desired outcome for
    // infinities.
    // UNE(x, x) holds exactly when x is NaN.
    Value isNan =
        b.create<arith::CmpFOp>(args: arith::CmpFPredicate::UNE, args&: operand, args&: operand);
    // Constant used to make the rounding bias.
    Value c7FFF = createConst(loc: op.getLoc(), type: i32Ty, value: 0x7fff, rewriter);
    // Constant used to generate a quiet NaN.
    Value c7FC0I16 = createConst(loc: op.getLoc(), type: i16Ty, value: 0x7fc0, rewriter);
    // Small constants used to address bits.
    Value c16 = createConst(loc: op.getLoc(), type: i32Ty, value: 16, rewriter);
    Value c1 = createConst(loc: op.getLoc(), type: i32Ty, value: 1, rewriter);
    // Reinterpret the input f32 value as bits.
    Value bitcast = b.create<arith::BitcastOp>(args&: i32Ty, args&: operand);
    // Read bit 16 as a value in {0,1}.
    Value bit16 =
        b.create<arith::AndIOp>(args: b.create<arith::ShRUIOp>(args&: bitcast, args&: c16), args&: c1);
    // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
    // on bit 16, implementing the tie-breaking "to nearest even".
    Value roundingBias = b.create<arith::AddIOp>(args&: bit16, args&: c7FFF);
    // Add the rounding bias. Generally we want this to be added to the
    // mantissa, but nothing prevents this to from carrying into the exponent
    // bits, which would feel like a bug, but this is the magic trick here:
    // when that happens, the mantissa gets reset to zero and the exponent
    // gets incremented by the carry... which is actually exactly what we
    // want.
    Value biased = b.create<arith::AddIOp>(args&: bitcast, args&: roundingBias);
    // Now that the rounding-bias has been added, truncating the low bits
    // yields the correctly rounded result.
    Value biasedAndShifted = b.create<arith::ShRUIOp>(args&: biased, args&: c16);
    Value normalCaseResultI16 =
        b.create<arith::TruncIOp>(args&: i16Ty, args&: biasedAndShifted);
    // Select either the above-computed result, or a quiet NaN constant
    // if the input was NaN.
    Value select =
        b.create<arith::SelectOp>(args&: isNan, args&: c7FC0I16, args&: normalCaseResultI16);
    Value result = b.create<arith::BitcastOp>(args&: resultTy, args&: select);
    rewriter.replaceOp(op, newValues: result);
    return success();
  }
};
337
338/// In this implementation of extf we take advantage of some key patterns we
339/// notice between the binary representation of an F4E2M1 value and its
340/// corresponding value in F32.
341///
342/// Note: x is sign bit
343/// | Binary | F4E2M1 | f32[23:32]
344/// | x000 | 0.0 | x000 0000 00
345/// | x001 | 0.5 | x011 1111 00
346/// | x010 | 1.0 | x011 1111 10
347/// | x011 | 1.5 | x011 1111 11
348/// | x100 | 2.0 | x010 0000 00
349/// | x101 | 3.0 | x010 0000 01
350/// | x110 | 4.0 | x010 0000 10
351/// | x111 | 6.0 | x010 0000 11
352///
353/// 1) There are only two versions of bits [25:31] in the f32 result
354/// F4E2M1 bits[2:3] decide whether:
355/// - F32 bits[25:31] = 0011 1111
356/// - F32 bits[25:31] = 0010 0000
357/// Exception is zero where
358/// - F32 bits[25:31] = 0000 0000
359///
360/// 2) F4E2M1 bits[1:2] = F32 bits[23:24]
361/// Exception is 0.5 where
362/// - F4E2M1 bits[1:2] = 01, F32 bits[23:24] = 00
363///
364/// 3) F4E2M1 bits[4] = F32 bits[32] (sign bits are equal)
365///
366/// 4) F32 bits[1:22] = 0
367struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
368 using OpRewritePattern::OpRewritePattern;
369 LogicalResult matchAndRewrite(arith::ExtFOp op,
370 PatternRewriter &rewriter) const final {
371 Location loc = op.getLoc();
372 ImplicitLocOpBuilder b(loc, rewriter);
373 Value operand = op.getOperand();
374 Type operandTy = operand.getType();
375 Type resultTy = op.getType();
376 Type operandETy = getElementTypeOrSelf(type: operandTy);
377 Type resultETy = getElementTypeOrSelf(type: resultTy);
378
379 if (!isa<Float4E2M1FNType>(Val: operandETy))
380 return rewriter.notifyMatchFailure(arg&: op, msg: "not a ext of F4E2M1FN");
381
382 Type f32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getF32Type());
383 Type i4Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI4Type());
384 Type i32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI32Type());
385 Value i4Bits = b.create<arith::BitcastOp>(args&: i4Ty, args&: operand);
386
387 Value c0x0 = createConst(loc, type: i4Ty, value: 0x0, rewriter);
388 Value c0x1 = createConst(loc, type: i4Ty, value: 0x1, rewriter);
389 Value c0x2 = createConst(loc, type: i4Ty, value: 0x2, rewriter);
390 Value c0x4 = createConst(loc, type: i4Ty, value: 0x4, rewriter);
391
392 // Set last Exponent bit and Mantissa.
393 Value c0x00000014 = createConst(loc, type: i32Ty, value: 0x14, rewriter);
394 Value bits1To24 = b.create<arith::ShLIOp>(args&: i4Bits, args&: c0x2);
395 Value isHalf =
396 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: i4Bits, args&: c0x1);
397 bits1To24 = b.create<arith::SelectOp>(args&: isHalf, args&: c0x0, args&: bits1To24);
398 bits1To24 = b.create<arith::ExtUIOp>(args&: i32Ty, args&: bits1To24);
399 bits1To24 = b.create<arith::ShLIOp>(args&: bits1To24, args&: c0x00000014);
400
401 // Set first 7 bits of Exponent.
402 Value zeroExpBits = createConst(loc, type: i32Ty, value: 0x00000000, rewriter);
403 Value highExpBits = createConst(loc, type: i32Ty, value: 0x40000000, rewriter);
404 Value lowExpBits = createConst(loc, type: i32Ty, value: 0x3f000000, rewriter);
405 Value useLargerExp =
406 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::uge, args&: i4Bits, args&: c0x4);
407 Value bits25To31 =
408 b.create<arith::SelectOp>(args&: useLargerExp, args&: highExpBits, args&: lowExpBits);
409 Value zeroExp =
410 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: i4Bits, args&: c0x0);
411 bits25To31 = b.create<arith::SelectOp>(args&: zeroExp, args&: zeroExpBits, args&: bits25To31);
412
413 // Set sign.
414 Value c0x80000000 = createConst(loc, type: i32Ty, value: 0x80000000, rewriter);
415 Value c0x8 = createConst(loc, type: i4Ty, value: 0x8, rewriter);
416 Value negative =
417 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::uge, args&: i4Bits, args&: c0x8);
418 Value bit32 = b.create<arith::SelectOp>(args&: negative, args&: c0x80000000, args&: zeroExpBits);
419
420 // Add segments together.
421 Value bits1To31 = b.create<arith::AddIOp>(args&: bits1To24, args&: bits25To31);
422 Value bits1To32 = b.create<arith::AddIOp>(args&: bits1To31, args&: bit32);
423 Value result = b.create<arith::BitcastOp>(args&: f32Ty, args&: bits1To32);
424 if (!isa<Float32Type>(Val: resultETy))
425 result = b.create<arith::TruncFOp>(args&: resultTy, args&: result);
426
427 rewriter.replaceOp(op, newValues: result);
428 return success();
429 }
430};
431
432struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
433 using OpRewritePattern::OpRewritePattern;
434 LogicalResult matchAndRewrite(arith::ExtFOp op,
435 PatternRewriter &rewriter) const final {
436 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
437 Value operand = op.getOperand();
438 Type operandTy = operand.getType();
439 Type resultTy = op.getType();
440 Type operandETy = getElementTypeOrSelf(type: operandTy);
441 Type resultETy = getElementTypeOrSelf(type: resultTy);
442
443 if (!llvm::isa<Float8E8M0FNUType>(Val: operandETy)) {
444 return rewriter.notifyMatchFailure(arg&: op, msg: "not a ext of F8E8M0FNU");
445 }
446
447 Type i8Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI8Type());
448 Type i32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI32Type());
449 Type f32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getF32Type());
450
451 Value bitcast = b.create<arith::BitcastOp>(args&: i8Ty, args&: operand);
452 // create constants for NaNs
453 Value cF8NaN = createConst(loc: op.getLoc(), type: i8Ty, value: 0xff, rewriter);
454 Value cF32NaN = createConst(loc: op.getLoc(), type: i32Ty, value: 0xffffffff, rewriter);
455 Value cF32MantissaWidth = createConst(loc: op->getLoc(), type: i32Ty, value: 23, rewriter);
456
457 Value exti = b.create<arith::ExtUIOp>(args&: i32Ty, args&: bitcast);
458 Value f32Bits = b.create<arith::ShLIOp>(args&: exti, args&: cF32MantissaWidth);
459
460 Value isNan =
461 b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: bitcast, args&: cF8NaN);
462 // select for NaNs
463 f32Bits = b.create<arith::SelectOp>(args&: isNan, args&: cF32NaN, args&: f32Bits);
464 Value result = b.create<arith::BitcastOp>(args&: f32Ty, args&: f32Bits);
465 if (resultETy.getIntOrFloatBitWidth() < 32) {
466 result = b.create<arith::TruncFOp>(args&: resultTy, args&: result, args: nullptr,
467 args: op.getFastmathAttr());
468 } else if (resultETy.getIntOrFloatBitWidth() > 32) {
469 result = b.create<arith::ExtFOp>(args&: resultTy, args&: result, args: op.getFastmathAttr());
470 }
471 rewriter.replaceOp(op, newValues: result);
472 return success();
473 }
474};
475
/// Conversion from F32 to F4E2M1 according to the OCP Spec:
/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
///
/// The spec requires us to perform Round to Nearest, Ties to Even.
///
/// This means that after rounding, we should break ties by choosing the option
/// which results in a mantissa of 0 in the least significant digit.
///
/// Table of representable values in F4E2M1:
///
/// Note: x is sign bit
/// | Binary | F4E2M1 | F32[23:32]
/// | x000   | 0.0    | x000 0000 00
/// | x001   | 0.5    | x011 1111 00
/// | x010   | 1.0    | x011 1111 10
/// | x011   | 1.5    | x011 1111 11
/// | x100   | 2.0    | x010 0000 00
/// | x101   | 3.0    | x010 0000 01
/// | x110   | 4.0    | x010 0000 10
/// | x111   | 6.0    | x010 0000 11
///
/// Conversion procedure (step numbers match the comments in the code below):
///   Step 0: Clamp to representable bounds.
///   Step 1: Set the sign bit.
///   Step 2: Convert exponent by adjusting bias.
///   Step 3: Set mantissa to first bit.
///   Step 4: Special consideration for subnormal and zero exponent.
///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or
///           subnormal.
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(type: operandTy);
    Type resultETy = getElementTypeOrSelf(type: resultTy);

    Type i4Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI4Type());
    Type i8Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI8Type());
    Type i32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI32Type());
    Type f32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getF32Type());

    // Widen non-f32 inputs to f32 first so the bit layout below is known.
    if (!isa<Float32Type>(Val: operandETy))
      operand = b.create<arith::ExtFOp>(args&: f32Ty, args&: operand);
    if (!isa<Float4E2M1FNType>(Val: resultETy))
      return rewriter.notifyMatchFailure(arg&: op, msg: "not a trunc of F4E2M1FN");

    Value c0x1 = createConst(loc, type: i4Ty, value: 1, rewriter);
    Value c0x3 = createConst(loc, type: i4Ty, value: 3, rewriter);
    Value c0x00000016 = createConst(loc, type: i32Ty, value: 22, rewriter);
    Value c0x00 = createConst(loc, type: i8Ty, value: 0x00, rewriter);
    Value c0xff = createConst(loc, type: i8Ty, value: 0xff, rewriter);
    Value zeroExpBits = createConst(loc, type: i32Ty, value: 0, rewriter);

    // Step 0: Clamp to bounds. +-6.0 are the largest-magnitude f4E2M1 values.
    Value cHigherBound = createFloatConst(loc, type: f32Ty, value: APFloat(6.0f), rewriter);
    Value cLowerBound = createFloatConst(loc, type: f32Ty, value: APFloat(-6.0f), rewriter);
    Value operandClamped = b.create<arith::MinNumFOp>(args&: cHigherBound, args&: operand);
    operandClamped = b.create<arith::MaxNumFOp>(args&: cLowerBound, args&: operandClamped);
    Value f32Bits = b.create<arith::BitcastOp>(args&: i32Ty, args&: operandClamped);

    // Step 1: Set sign bit.
    Value cF32ExpManWidth = createConst(loc, type: i32Ty, value: 31, rewriter); // sign bit position
    Value f32Sign = b.create<arith::ShRUIOp>(args&: f32Bits, args&: cF32ExpManWidth);
    Value f4Sign = b.create<arith::TruncIOp>(args&: i4Ty, args&: f32Sign);
    Value f4Bits = b.create<arith::ShLIOp>(args&: f4Sign, args&: c0x3);

    // Step 2: Convert exponent by adjusting bias (f32 bias 127 vs f4 bias 1).
    Value biasAdjustment = createConst(loc, type: i32Ty, value: 0x7e, rewriter);
    Value cF4MantissaWidth = c0x1; // 1
    Value cF32MantissaWidth = createConst(loc, type: i32Ty, value: 23, rewriter); // 23
    Value f32SignExp = b.create<arith::ShRUIOp>(args&: f32Bits, args&: cF32MantissaWidth);
    Value biasAdjustedSignExp =
        b.create<arith::SubIOp>(args&: f32SignExp, args&: biasAdjustment);
    Value f4Exp = b.create<arith::TruncIOp>(args&: i4Ty, args&: biasAdjustedSignExp);
    f4Exp = b.create<arith::ShLIOp>(args&: f4Exp, args&: cF4MantissaWidth);
    f4Bits = b.create<arith::AddIOp>(args&: f4Bits, args&: f4Exp);

    // Step 3: Set mantissa to first bit (the f32 mantissa MSB).
    Value cF32FirstBitMask = createConst(loc, type: i32Ty, value: 0x400000, rewriter);
    Value man1Bit = b.create<arith::AndIOp>(args&: f32Bits, args&: cF32FirstBitMask);
    man1Bit = b.create<arith::ShRUIOp>(args&: man1Bit, args&: c0x00000016);
    Value f4Man = b.create<arith::TruncIOp>(args&: i4Ty, args&: man1Bit);
    f4Bits = b.create<arith::AddIOp>(args&: f4Bits, args&: f4Man);

    // Step 4: Special consideration for conversion to 0.5.
    Value cF32MantissaMask = createConst(loc, type: i32Ty, value: 0x7fffff, rewriter);
    Value f8Exp = b.create<arith::TruncIOp>(args&: i8Ty, args&: biasAdjustedSignExp);
    // Adjusted exponent <= 0 means the value is below the smallest f4 normal.
    Value isSubnormal =
        b.create<arith::CmpIOp>(args: arith::CmpIPredicate::sle, args&: f8Exp, args&: c0x00);
    // Adjusted exponent of -1 (0xff as i8) with a non-zero mantissa rounds to
    // the subnormal 0.5.
    Value isNegOneExp =
        b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: f8Exp, args&: c0xff);
    Value man23Bits = b.create<arith::AndIOp>(args&: f32Bits, args&: cF32MantissaMask);
    Value isNonZeroMan = b.create<arith::CmpIOp>(args: arith::CmpIPredicate::ugt,
                                                 args&: man23Bits, args&: zeroExpBits);
    Value roundToHalf = b.create<arith::AndIOp>(args&: isNegOneExp, args&: isNonZeroMan);
    Value isZeroExp =
        b.create<arith::CmpIOp>(args: arith::CmpIPredicate::eq, args&: f8Exp, args&: c0x00);
    Value subnormalF4Bits = createConst(loc, type: i4Ty, value: 0xf, rewriter);
    Value halfF4Bits = createConst(loc, type: i4Ty, value: 0x0, rewriter);
    Value subResult =
        b.create<arith::SelectOp>(args&: isSubnormal, args&: subnormalF4Bits, args&: f4Bits);
    subResult = b.create<arith::SelectOp>(args&: roundToHalf, args&: halfF4Bits, args&: subResult);
    f4Bits = b.create<arith::SelectOp>(args&: isZeroExp, args&: f4Bits, args&: subResult);

    // Step 5: Round up if necessary.
    Value cF32Last22BitMask = createConst(loc, type: i32Ty, value: 0x3fffff, rewriter);
    Value cRound = createConst(loc, type: i32Ty, value: 0x200000, rewriter); // 010 0000...
    Value man22Bits = b.create<arith::AndIOp>(args&: f32Bits, args&: cF32Last22BitMask);
    Value shouldRound =
        b.create<arith::CmpIOp>(args: arith::CmpIPredicate::uge, args&: man22Bits, args&: cRound);
    shouldRound = b.create<arith::OrIOp>(args&: shouldRound, args&: isSubnormal);
    Value roundedF4Bits = b.create<arith::AddIOp>(args&: f4Bits, args&: c0x1);
    f4Bits = b.create<arith::SelectOp>(args&: shouldRound, args&: roundedF4Bits, args&: f4Bits);

    Value result = b.create<arith::BitcastOp>(args&: resultTy, args&: f4Bits);
    rewriter.replaceOp(op, newValues: result);
    return success();
  }
};
599
600/*
601TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
602Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
603they all map to NaN in F8E8M0 Type.
604*/
605struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
606 using OpRewritePattern::OpRewritePattern;
607 LogicalResult matchAndRewrite(arith::TruncFOp op,
608 PatternRewriter &rewriter) const final {
609 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
610 Value operand = op.getOperand();
611 Type operandTy = operand.getType();
612 Type operandETy = getElementTypeOrSelf(type: operandTy);
613 Type resultTy = op.getType();
614 Type resultETy = getElementTypeOrSelf(type: resultTy);
615 if (!llvm::isa<Float8E8M0FNUType>(Val: resultETy)) {
616 return rewriter.notifyMatchFailure(arg&: op, msg: "not a truncf to f8E8M0FNU");
617 }
618
619 if (op.getRoundingmodeAttr()) {
620 return rewriter.notifyMatchFailure(
621 arg&: op, msg: "only applicable to default rounding mode.");
622 }
623
624 Type i8Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI8Type());
625 Type i32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getI32Type());
626 Type f32Ty = cloneToShapedType(cloneFrom: operandTy, cloneTo: b.getF32Type());
627
628 if (operandETy.getIntOrFloatBitWidth() < 32) {
629 operand = b.create<arith::ExtFOp>(args&: f32Ty, args&: operand, args: op.getFastmathAttr());
630 } else if (operandETy.getIntOrFloatBitWidth() > 32) {
631 operand = b.create<arith::TruncFOp>(
632 args&: f32Ty, args&: operand, args: op.getRoundingmodeAttr(), args: op.getFastmathAttr());
633 }
634 Value f32Bits = b.create<arith::BitcastOp>(args&: i32Ty, args&: operand);
635 Value cF32MantissaWidth = createConst(loc: op->getLoc(), type: i32Ty, value: 23, rewriter);
636 Value f32SignExp = b.create<arith::ShRUIOp>(args&: f32Bits, args&: cF32MantissaWidth);
637 Value exp8Bits = b.create<arith::TruncIOp>(args&: i8Ty, args&: f32SignExp);
638 Value result = b.create<arith::BitcastOp>(args&: resultTy, args&: exp8Bits);
639 rewriter.replaceOp(op, newValues: result);
640 return success();
641 }
642};
643
/// Expands arith.scaling_extf(in, scale) into
///   scale8 = arith.truncf(scale) : scaleTy -> f8E8M0FNU   (if scale >= 16 bit)
///   result = arith.extf(in) * arith.extf(scale8)
/// where extf of the e8m0 scale produces 2^scale in the result type.
struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const final {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Value inputOperand = op.getIn();
    Value scaleOperand = op.getScale();
    Type scaleTy = scaleOperand.getType();
    Type scaleETy = getElementTypeOrSelf(val: scaleOperand);
    // allow implicit exponent extraction from 16/32 bits floats
    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
      scaleETy = b.getF8E8M0Type();
      scaleTy = cloneToShapedType(cloneFrom: scaleTy, cloneTo: scaleETy);
      scaleOperand = b.create<arith::TruncFOp>(args&: scaleTy, args&: scaleOperand, args: nullptr,
                                               args: op.getFastmathAttr());
    }
    // Scales that are neither f8E8M0FNU nor wide enough to convert are
    // rejected here.
    if (!llvm::isa<Float8E8M0FNUType>(Val: scaleETy)) {
      return rewriter.notifyMatchFailure(
          arg&: op, msg: "scaling_extf is using scales of type which can not be converted "
              "to f8E8M0FNU");
    }
    Type resultTy = op.getType();
    // extf on scale will essentially create floating point number
    // of type resulTy that is 2^scale and will also propagate NaNs
    Value scaleExt =
        b.create<arith::ExtFOp>(args&: resultTy, args&: scaleOperand, args: op.getFastmathAttr());
    Value inputExt =
        b.create<arith::ExtFOp>(args&: resultTy, args&: inputOperand, args: op.getFastmathAttr());
    Value result =
        b.create<arith::MulFOp>(args&: inputExt, args&: scaleExt, args: op.getFastmathAttr());
    rewriter.replaceOp(op, newValues: result);
    return success();
  }
};
678
679/*
680Expands arith.ScalingTruncFOp(in, scale) into
681 scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
682 result = arith.truncf(in / (2^scale))
683 */
684struct ScalingTruncFOpConverter
685 : public OpRewritePattern<arith::ScalingTruncFOp> {
686 using OpRewritePattern::OpRewritePattern;
687 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
688 PatternRewriter &rewriter) const final {
689 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
690 Value inputOperand = op.getIn();
691 Value scaleOperand = op.getScale();
692 Type scaleTy = scaleOperand.getType();
693 Type scaleETy = getElementTypeOrSelf(val: scaleOperand);
694 // allow implicit exponent extraction from 16/32 bits floats
695 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
696 scaleETy = b.getF8E8M0Type();
697 scaleTy = cloneToShapedType(cloneFrom: scaleTy, cloneTo: scaleETy);
698 scaleOperand = b.create<arith::TruncFOp>(args&: scaleTy, args&: scaleOperand, args: nullptr,
699 args: op.getFastmathAttr());
700 }
701 if (!llvm::isa<Float8E8M0FNUType>(Val: scaleETy)) {
702 return rewriter.notifyMatchFailure(
703 arg&: op, msg: "scaling_truncf is using scales type which can not be converted "
704 "to f8E8M0FNU");
705 }
706 Type resultTy = op.getType();
707 Type inputTy = inputOperand.getType();
708 // this will create a floating point number of type
709 // inputTy that is 2^scale and will also propagate NaNs
710 scaleOperand =
711 b.create<arith::ExtFOp>(args&: inputTy, args&: scaleOperand, args: op.getFastmathAttr());
712 Value result = b.create<arith::DivFOp>(args&: inputOperand, args&: scaleOperand,
713 args: op.getFastmathAttr());
714 Value resultCast = b.create<arith::TruncFOp>(
715 args&: resultTy, args&: result, args: op.getRoundingmodeAttr(), args: op.getFastmathAttr());
716 rewriter.replaceOp(op, newValues: resultCast);
717 return success();
718 }
719};
720
721struct ArithExpandOpsPass
722 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
723 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
724
725 void runOnOperation() override {
726 RewritePatternSet patterns(&getContext());
727 ConversionTarget target(getContext());
728
729 arith::populateArithExpandOpsPatterns(patterns);
730
731 target.addLegalDialect<arith::ArithDialect>();
732 target.addLegalDialect<vector::VectorDialect>();
733
734 // clang-format off
735 target.addIllegalOp<
736 arith::CeilDivSIOp,
737 arith::CeilDivUIOp,
738 arith::FloorDivSIOp,
739 arith::MaxSIOp,
740 arith::MaxUIOp,
741 arith::MinSIOp,
742 arith::MinUIOp,
743 arith::MaximumFOp,
744 arith::MinimumFOp,
745 arith::MaxNumFOp,
746 arith::MinNumFOp,
747 arith::ScalingExtFOp,
748 arith::ScalingTruncFOp
749 >();
750
751 if (includeBf16)
752 arith::populateExpandBFloat16Patterns(patterns);
753 if (includeF8E8M0)
754 arith::populateExpandF8E8M0Patterns(patterns);
755 if (includeF4E2M1)
756 arith::populateExpandF4E2M1Patterns(patterns);
757
758 target.addDynamicallyLegalOp<arith::ExtFOp>(
759 callback: [=](arith::ExtFOp op) {
760 Type inETy = getElementTypeOrSelf(type: op.getOperand().getType());
761 Type outETy = getElementTypeOrSelf(type: op.getType());
762 bool legalTypes = true;
763 if (includeBf16)
764 legalTypes &= !(inETy.isBF16() && outETy.isF32());
765 if (includeF8E8M0)
766 legalTypes &= !llvm::isa<Float8E8M0FNUType>(Val: inETy);
767 if (includeF4E2M1)
768 legalTypes &= !llvm::isa<Float4E2M1FNType>(Val: inETy);
769 return legalTypes;
770 });
771
772 target.addDynamicallyLegalOp<arith::TruncFOp>(
773 callback: [=](arith::TruncFOp op) {
774 Type inETy = getElementTypeOrSelf(type: op.getOperand().getType());
775 Type outETy = getElementTypeOrSelf(type: op.getType());
776 bool legalTypes = true;
777 if (includeBf16)
778 legalTypes &= !(inETy.isF32() && outETy.isBF16());
779 if (includeF8E8M0)
780 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(Val: outETy));
781 if (includeF4E2M1)
782 legalTypes &= !llvm::isa<Float4E2M1FNType>(Val: outETy);
783 return legalTypes;
784 });
785
786 // clang-format on
787 if (failed(Result: applyPartialConversion(op: getOperation(), target,
788 patterns: std::move(patterns))))
789 signalPassFailure();
790 }
791};
792
793} // namespace
794
795void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
796 RewritePatternSet &patterns) {
797 patterns
798 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
799 arg: patterns.getContext());
800}
801
802void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
803 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
804 arg: patterns.getContext());
805}
806
807void mlir::arith::populateExpandF4E2M1Patterns(RewritePatternSet &patterns) {
808 patterns.add<F4E2M1ExtFOpConverter, F4E2M1TruncFOpConverter>(
809 arg: patterns.getContext());
810}
811
812void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
813 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
814 arg: patterns.getContext());
815}
816
817void mlir::arith::populateExpandScalingExtTruncPatterns(
818 RewritePatternSet &patterns) {
819 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
820 arg: patterns.getContext());
821}
822
823void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
824 populateCeilFloorDivExpandOpsPatterns(patterns);
825 populateExpandScalingExtTruncPatterns(patterns);
826 // clang-format off
827 patterns.add<
828 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
829 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
830 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
831 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
832 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
833 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
834 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
835 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
836 >(arg: patterns.getContext());
837 // clang-format on
838}
839

source code of mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp