//===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace arith {
#define GEN_PASS_DEF_ARITHEXPANDOPSPASS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace arith
} // namespace mlir

using namespace mlir;

26/// Create an integer or index constant.
27static Value createConst(Location loc, Type type, int value,
28 PatternRewriter &rewriter) {
29 auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31 return rewriter.create<arith::ConstantOp>(
32 loc, DenseElementsAttr::get(shapedTy, attr));
33 }
34 return rewriter.create<arith::ConstantOp>(loc, attr);
35}
36
37/// Creates shapedType using shape from cloneFrom and base type from cloneTo
38static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
39 if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
40 return shapedTy.clone(cloneTo);
41 }
42 return cloneTo;
43}
44
namespace {

47/// Expands CeilDivUIOp (n, m) into
48/// n == 0 ? 0 : ((n-1) / m) + 1
49struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
50 using OpRewritePattern::OpRewritePattern;
51 LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
52 PatternRewriter &rewriter) const final {
53 Location loc = op.getLoc();
54 Value a = op.getLhs();
55 Value b = op.getRhs();
56 Value zero = createConst(loc, type: a.getType(), value: 0, rewriter);
57 Value compare =
58 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
59 Value one = createConst(loc, type: a.getType(), value: 1, rewriter);
60 Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
61 Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
62 Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
63 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
64 return success();
65 }
66};
67
68/// Expands CeilDivSIOp (a, b) into
69/// z = a / b
70/// if (z * b != a && (a < 0) == (b < 0)) {
71/// return z + 1;
72/// } else {
73/// return z;
74/// }
75struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
76 using OpRewritePattern::OpRewritePattern;
77 LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
78 PatternRewriter &rewriter) const final {
79 Location loc = op.getLoc();
80 Type type = op.getType();
81 Value a = op.getLhs();
82 Value b = op.getRhs();
83
84 Value zero = createConst(loc, type, value: 0, rewriter);
85 Value one = createConst(loc, type, value: 1, rewriter);
86
87 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
88 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
89 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
90 loc, arith::CmpIPredicate::ne, a, product);
91
92 Value aNeg =
93 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
94 Value bNeg =
95 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
96
97 Value signEqual = rewriter.create<arith::CmpIOp>(
98 loc, arith::CmpIPredicate::eq, aNeg, bNeg);
99 Value cond =
100 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
101
102 Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
103
104 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
105 quotient);
106 return success();
107 }
108};
109
110/// Expands FloorDivSIOp (x, y) into
111/// z = x / y
112/// if (z * y != x && (x < 0) != (y < 0)) {
113/// return z - 1;
114/// } else {
115/// return z;
116/// }
117struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
118 using OpRewritePattern::OpRewritePattern;
119 LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
120 PatternRewriter &rewriter) const final {
121 Location loc = op.getLoc();
122 Type type = op.getType();
123 Value a = op.getLhs();
124 Value b = op.getRhs();
125
126 Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
127 Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
128 Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
129 loc, arith::CmpIPredicate::ne, a, product);
130 Value zero = createConst(loc, type, value: 0, rewriter);
131
132 Value aNeg =
133 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
134 Value bNeg =
135 rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
136
137 Value signOpposite = rewriter.create<arith::CmpIOp>(
138 loc, arith::CmpIPredicate::ne, aNeg, bNeg);
139 Value cond =
140 rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
141
142 Value minusOne = createConst(loc, type, value: -1, rewriter);
143 Value quotientMinusOne =
144 rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
145
146 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
147 quotient);
148 return success();
149 }
150};
151
152template <typename OpTy, arith::CmpIPredicate pred>
153struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
154public:
155 using OpRewritePattern<OpTy>::OpRewritePattern;
156
157 LogicalResult matchAndRewrite(OpTy op,
158 PatternRewriter &rewriter) const final {
159 Value lhs = op.getLhs();
160 Value rhs = op.getRhs();
161
162 Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs);
163 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs);
164 return success();
165 }
166};
167
168template <typename OpTy, arith::CmpFPredicate pred>
169struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
170public:
171 using OpRewritePattern<OpTy>::OpRewritePattern;
172
173 LogicalResult matchAndRewrite(OpTy op,
174 PatternRewriter &rewriter) const final {
175 Value lhs = op.getLhs();
176 Value rhs = op.getRhs();
177
178 Location loc = op.getLoc();
179 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
180 static_assert(pred == arith::CmpFPredicate::UGT ||
181 pred == arith::CmpFPredicate::ULT,
182 "pred must be either UGT or ULT");
183 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
184 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
185
186 // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
187 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
188 rhs, rhs);
189 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
190 return success();
191 }
192};
193
194template <typename OpTy, arith::CmpFPredicate pred>
195struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
196public:
197 using OpRewritePattern<OpTy>::OpRewritePattern;
198
199 LogicalResult matchAndRewrite(OpTy op,
200 PatternRewriter &rewriter) const final {
201 Value lhs = op.getLhs();
202 Value rhs = op.getRhs();
203
204 Location loc = op.getLoc();
205 // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
206 static_assert(pred == arith::CmpFPredicate::UGT ||
207 pred == arith::CmpFPredicate::ULT,
208 "pred must be either UGT or ULT");
209 Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
210 Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
211
212 // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
213 Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
214 lhs, lhs);
215 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
216 return success();
217 }
218};
219
220struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
221 using OpRewritePattern::OpRewritePattern;
222 LogicalResult matchAndRewrite(arith::ExtFOp op,
223 PatternRewriter &rewriter) const final {
224 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
225 auto operand = op.getOperand();
226 Type operandTy = operand.getType();
227 Type resultTy = op.getType();
228 Type operandETy = getElementTypeOrSelf(type: operandTy);
229 Type resultETy = getElementTypeOrSelf(type: resultTy);
230
231 if (!operandETy.isBF16() || !resultETy.isF32()) {
232 return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
233 }
234
235 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
236 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
237
238 Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
239 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
240
241 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
242 Value shl = b.create<arith::ShLIOp>(exti, c16);
243 Value result = b.create<arith::BitcastOp>(resultTy, shl);
244
245 rewriter.replaceOp(op, result);
246 return success();
247 }
248};
249
250struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
251 using OpRewritePattern::OpRewritePattern;
252 LogicalResult matchAndRewrite(arith::TruncFOp op,
253 PatternRewriter &rewriter) const final {
254 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
255 auto operand = op.getOperand();
256 Type operandTy = operand.getType();
257 Type resultTy = op.getType();
258 Type operandETy = getElementTypeOrSelf(type: operandTy);
259 Type resultETy = getElementTypeOrSelf(type: resultTy);
260
261 if (!operandETy.isF32() || !resultETy.isBF16()) {
262 return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
263 }
264
265 if (op.getRoundingmodeAttr()) {
266 return rewriter.notifyMatchFailure(
267 op, "only applicable to default rounding mode.");
268 }
269
270 Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
271 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
272
273 // Algorithm borrowed from this excellent code:
274 // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
275 // There is a magic idea there, to let the addition of the rounding_bias to
276 // the mantissa simply overflow into the exponent bits. It's a bit of an
277 // aggressive, obfuscating optimization, but it is well-tested code, and it
278 // results in more concise and efficient IR.
279 // The case of NaN is handled separately (see isNaN and the final select).
280 // The case of infinities is NOT handled separately, which deserves an
281 // explanation. As the encoding of infinities has zero mantissa, the
282 // rounding-bias addition never carries into the exponent so that just gets
283 // truncated away, and as bfloat16 and float32 have the same number of
284 // exponent bits, that simple truncation is the desired outcome for
285 // infinities.
286 Value isNan =
287 b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
288 // Constant used to make the rounding bias.
289 Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
290 // Constant used to generate a quiet NaN.
291 Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
292 // Small constants used to address bits.
293 Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
294 Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
295 // Reinterpret the input f32 value as bits.
296 Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
297 // Read bit 16 as a value in {0,1}.
298 Value bit16 =
299 b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
300 // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
301 // on bit 16, implementing the tie-breaking "to nearest even".
302 Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
303 // Add the rounding bias. Generally we want this to be added to the
304 // mantissa, but nothing prevents this to from carrying into the exponent
305 // bits, which would feel like a bug, but this is the magic trick here:
306 // when that happens, the mantissa gets reset to zero and the exponent
307 // gets incremented by the carry... which is actually exactly what we
308 // want.
309 Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
310 // Now that the rounding-bias has been added, truncating the low bits
311 // yields the correctly rounded result.
312 Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
313 Value normalCaseResultI16 =
314 b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
315 // Select either the above-computed result, or a quiet NaN constant
316 // if the input was NaN.
317 Value select =
318 b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
319 Value result = b.create<arith::BitcastOp>(resultTy, select);
320 rewriter.replaceOp(op, result);
321 return success();
322 }
323};
324
325struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
326 using OpRewritePattern::OpRewritePattern;
327 LogicalResult matchAndRewrite(arith::ExtFOp op,
328 PatternRewriter &rewriter) const final {
329 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
330 Value operand = op.getOperand();
331 Type operandTy = operand.getType();
332 Type resultTy = op.getType();
333 Type operandETy = getElementTypeOrSelf(type: operandTy);
334 Type resultETy = getElementTypeOrSelf(type: resultTy);
335
336 if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
337 return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
338 }
339
340 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
341 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
342 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
343
344 Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
345 // create constants for NaNs
346 Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
347 Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
348 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
349
350 Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
351 Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
352
353 Value isNan =
354 b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
355 // select for NaNs
356 f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
357 Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
358 if (resultETy.getIntOrFloatBitWidth() < 32) {
359 result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
360 op.getFastmathAttr());
361 } else if (resultETy.getIntOrFloatBitWidth() > 32) {
362 result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
363 }
364 rewriter.replaceOp(op, result);
365 return success();
366 }
367};
368
369/*
370TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
371Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
372they all map to NaN in F8E8M0 Type.
373*/
374struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
375 using OpRewritePattern::OpRewritePattern;
376 LogicalResult matchAndRewrite(arith::TruncFOp op,
377 PatternRewriter &rewriter) const final {
378 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
379 Value operand = op.getOperand();
380 Type operandTy = operand.getType();
381 Type operandETy = getElementTypeOrSelf(type: operandTy);
382 Type resultTy = op.getType();
383 Type resultETy = getElementTypeOrSelf(type: resultTy);
384 if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
385 return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
386 }
387
388 if (op.getRoundingmodeAttr()) {
389 return rewriter.notifyMatchFailure(
390 op, "only applicable to default rounding mode.");
391 }
392
393 Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
394 Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
395 Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
396
397 if (operandETy.getIntOrFloatBitWidth() < 32) {
398 operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
399 } else if (operandETy.getIntOrFloatBitWidth() > 32) {
400 operand = b.create<arith::TruncFOp>(
401 f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
402 }
403 Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
404 Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
405 Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
406 Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
407 Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
408 rewriter.replaceOp(op, result);
409 return success();
410 }
411};
412
413struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
414 using OpRewritePattern::OpRewritePattern;
415 LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
416 PatternRewriter &rewriter) const final {
417 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
418 Value inputOperand = op.getIn();
419 Value scaleOperand = op.getScale();
420 Type scaleTy = scaleOperand.getType();
421 Type scaleETy = getElementTypeOrSelf(val: scaleOperand);
422 // allow implicit exponent extraction from 16/32 bits floats
423 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
424 scaleETy = b.getF8E8M0Type();
425 scaleTy = cloneToShapedType(cloneFrom: scaleTy, cloneTo: scaleETy);
426 scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
427 op.getFastmathAttr());
428 }
429 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
430 return rewriter.notifyMatchFailure(
431 op, "scaling_extf is using scales of type which can not be converted "
432 "to f8E8M0FNU");
433 }
434 Type resultTy = op.getType();
435 // extf on scale will essentially create floating point number
436 // of type resulTy that is 2^scale and will also propagate NaNs
437 Value scaleExt =
438 b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
439 Value inputExt =
440 b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
441 Value result =
442 b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
443 rewriter.replaceOp(op, result);
444 return success();
445 }
446};
447
448/*
449Expands arith.ScalingTruncFOp(in, scale) into
450 scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
451 result = arith.truncf(in / (2^scale))
452 */
453struct ScalingTruncFOpConverter
454 : public OpRewritePattern<arith::ScalingTruncFOp> {
455 using OpRewritePattern::OpRewritePattern;
456 LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
457 PatternRewriter &rewriter) const final {
458 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
459 Value inputOperand = op.getIn();
460 Value scaleOperand = op.getScale();
461 Type scaleTy = scaleOperand.getType();
462 Type scaleETy = getElementTypeOrSelf(val: scaleOperand);
463 // allow implicit exponent extraction from 16/32 bits floats
464 if (scaleETy.getIntOrFloatBitWidth() >= 16) {
465 scaleETy = b.getF8E8M0Type();
466 scaleTy = cloneToShapedType(cloneFrom: scaleTy, cloneTo: scaleETy);
467 scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
468 op.getFastmathAttr());
469 }
470 if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
471 return rewriter.notifyMatchFailure(
472 op, "scaling_truncf is using scales type which can not be converted "
473 "to f8E8M0FNU");
474 }
475 Type resultTy = op.getType();
476 Type inputTy = inputOperand.getType();
477 // this will create a floating point number of type
478 // inputTy that is 2^scale and will also propagate NaNs
479 scaleOperand =
480 b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
481 Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
482 op.getFastmathAttr());
483 Value resultCast = b.create<arith::TruncFOp>(
484 resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
485 rewriter.replaceOp(op, resultCast);
486 return success();
487 }
488};
489
490struct ArithExpandOpsPass
491 : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
492 using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
493
494 void runOnOperation() override {
495 RewritePatternSet patterns(&getContext());
496 ConversionTarget target(getContext());
497
498 arith::populateArithExpandOpsPatterns(patterns);
499
500 target.addLegalDialect<arith::ArithDialect>();
501 // clang-format off
502 target.addIllegalOp<
503 arith::CeilDivSIOp,
504 arith::CeilDivUIOp,
505 arith::FloorDivSIOp,
506 arith::MaxSIOp,
507 arith::MaxUIOp,
508 arith::MinSIOp,
509 arith::MinUIOp,
510 arith::MaximumFOp,
511 arith::MinimumFOp,
512 arith::MaxNumFOp,
513 arith::MinNumFOp,
514 arith::ScalingExtFOp,
515 arith::ScalingTruncFOp
516 >();
517
518 if (includeBf16) {
519 arith::populateExpandBFloat16Patterns(patterns);
520 }
521 if (includeF8E8M0) {
522 arith::populateExpandF8E8M0Patterns(patterns);
523 }
524
525 target.addDynamicallyLegalOp<arith::ExtFOp>(
526 [=](arith::ExtFOp op) {
527 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
528 Type outETy = getElementTypeOrSelf(op.getType());
529 bool legalTypes = true;
530 if (includeBf16)
531 legalTypes &= !(inETy.isBF16() && outETy.isF32());
532 if (includeF8E8M0)
533 legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
534 return legalTypes;
535 });
536
537 target.addDynamicallyLegalOp<arith::TruncFOp>(
538 [=](arith::TruncFOp op) {
539 Type inETy = getElementTypeOrSelf(op.getOperand().getType());
540 Type outETy = getElementTypeOrSelf(op.getType());
541 bool legalTypes = true;
542 if (includeBf16)
543 legalTypes &= !(inETy.isF32() && outETy.isBF16());
544 if (includeF8E8M0)
545 legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
546 return legalTypes;
547 });
548
549 // clang-format on
550 if (failed(applyPartialConversion(getOperation(), target,
551 std::move(patterns))))
552 signalPassFailure();
553 }
554};
555
} // namespace

558void mlir::arith::populateCeilFloorDivExpandOpsPatterns(
559 RewritePatternSet &patterns) {
560 patterns
561 .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
562 arg: patterns.getContext());
563}
564
565void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
566 patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
567 arg: patterns.getContext());
568}
569
570void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
571 patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
572 arg: patterns.getContext());
573}
574
575void mlir::arith::populateExpandScalingExtTruncPatterns(
576 RewritePatternSet &patterns) {
577 patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
578 arg: patterns.getContext());
579}
580
581void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
582 populateCeilFloorDivExpandOpsPatterns(patterns);
583 populateExpandScalingExtTruncPatterns(patterns);
584 // clang-format off
585 patterns.add<
586 MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
587 MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>,
588 MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>,
589 MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>,
590 MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
591 MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
592 MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
593 MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
594 >(patterns.getContext());
595 // clang-format on
596}
597

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp