1 | //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

#include <cstdint>
16 | |
17 | namespace mlir { |
18 | namespace arith { |
19 | #define GEN_PASS_DEF_ARITHEXPANDOPSPASS |
20 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
21 | } // namespace arith |
22 | } // namespace mlir |
23 | |
24 | using namespace mlir; |
25 | |
26 | /// Create an integer or index constant. |
27 | static Value createConst(Location loc, Type type, int value, |
28 | PatternRewriter &rewriter) { |
29 | auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value); |
30 | if (auto shapedTy = dyn_cast<ShapedType>(type)) { |
31 | return rewriter.create<arith::ConstantOp>( |
32 | loc, DenseElementsAttr::get(shapedTy, attr)); |
33 | } |
34 | return rewriter.create<arith::ConstantOp>(loc, attr); |
35 | } |
36 | |
37 | /// Creates shapedType using shape from cloneFrom and base type from cloneTo |
38 | static Type cloneToShapedType(Type cloneFrom, Type cloneTo) { |
39 | if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) { |
40 | return shapedTy.clone(cloneTo); |
41 | } |
42 | return cloneTo; |
43 | } |
44 | |
45 | namespace { |
46 | |
47 | /// Expands CeilDivUIOp (n, m) into |
48 | /// n == 0 ? 0 : ((n-1) / m) + 1 |
49 | struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { |
50 | using OpRewritePattern::OpRewritePattern; |
51 | LogicalResult matchAndRewrite(arith::CeilDivUIOp op, |
52 | PatternRewriter &rewriter) const final { |
53 | Location loc = op.getLoc(); |
54 | Value a = op.getLhs(); |
55 | Value b = op.getRhs(); |
56 | Value zero = createConst(loc, type: a.getType(), value: 0, rewriter); |
57 | Value compare = |
58 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); |
59 | Value one = createConst(loc, type: a.getType(), value: 1, rewriter); |
60 | Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); |
61 | Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); |
62 | Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); |
63 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne); |
64 | return success(); |
65 | } |
66 | }; |
67 | |
68 | /// Expands CeilDivSIOp (a, b) into |
69 | /// z = a / b |
70 | /// if (z * b != a && (a < 0) == (b < 0)) { |
71 | /// return z + 1; |
72 | /// } else { |
73 | /// return z; |
74 | /// } |
75 | struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { |
76 | using OpRewritePattern::OpRewritePattern; |
77 | LogicalResult matchAndRewrite(arith::CeilDivSIOp op, |
78 | PatternRewriter &rewriter) const final { |
79 | Location loc = op.getLoc(); |
80 | Type type = op.getType(); |
81 | Value a = op.getLhs(); |
82 | Value b = op.getRhs(); |
83 | |
84 | Value zero = createConst(loc, type, value: 0, rewriter); |
85 | Value one = createConst(loc, type, value: 1, rewriter); |
86 | |
87 | Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b); |
88 | Value product = rewriter.create<arith::MulIOp>(loc, quotient, b); |
89 | Value notEqualDivisor = rewriter.create<arith::CmpIOp>( |
90 | loc, arith::CmpIPredicate::ne, a, product); |
91 | |
92 | Value aNeg = |
93 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); |
94 | Value bNeg = |
95 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); |
96 | |
97 | Value signEqual = rewriter.create<arith::CmpIOp>( |
98 | loc, arith::CmpIPredicate::eq, aNeg, bNeg); |
99 | Value cond = |
100 | rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual); |
101 | |
102 | Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); |
103 | |
104 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne, |
105 | quotient); |
106 | return success(); |
107 | } |
108 | }; |
109 | |
110 | /// Expands FloorDivSIOp (x, y) into |
111 | /// z = x / y |
112 | /// if (z * y != x && (x < 0) != (y < 0)) { |
113 | /// return z - 1; |
114 | /// } else { |
115 | /// return z; |
116 | /// } |
117 | struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { |
118 | using OpRewritePattern::OpRewritePattern; |
119 | LogicalResult matchAndRewrite(arith::FloorDivSIOp op, |
120 | PatternRewriter &rewriter) const final { |
121 | Location loc = op.getLoc(); |
122 | Type type = op.getType(); |
123 | Value a = op.getLhs(); |
124 | Value b = op.getRhs(); |
125 | |
126 | Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b); |
127 | Value product = rewriter.create<arith::MulIOp>(loc, quotient, b); |
128 | Value notEqualDivisor = rewriter.create<arith::CmpIOp>( |
129 | loc, arith::CmpIPredicate::ne, a, product); |
130 | Value zero = createConst(loc, type, value: 0, rewriter); |
131 | |
132 | Value aNeg = |
133 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); |
134 | Value bNeg = |
135 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); |
136 | |
137 | Value signOpposite = rewriter.create<arith::CmpIOp>( |
138 | loc, arith::CmpIPredicate::ne, aNeg, bNeg); |
139 | Value cond = |
140 | rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite); |
141 | |
142 | Value minusOne = createConst(loc, type, value: -1, rewriter); |
143 | Value quotientMinusOne = |
144 | rewriter.create<arith::AddIOp>(loc, quotient, minusOne); |
145 | |
146 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne, |
147 | quotient); |
148 | return success(); |
149 | } |
150 | }; |
151 | |
152 | template <typename OpTy, arith::CmpIPredicate pred> |
153 | struct MaxMinIOpConverter : public OpRewritePattern<OpTy> { |
154 | public: |
155 | using OpRewritePattern<OpTy>::OpRewritePattern; |
156 | |
157 | LogicalResult matchAndRewrite(OpTy op, |
158 | PatternRewriter &rewriter) const final { |
159 | Value lhs = op.getLhs(); |
160 | Value rhs = op.getRhs(); |
161 | |
162 | Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs); |
163 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs); |
164 | return success(); |
165 | } |
166 | }; |
167 | |
168 | template <typename OpTy, arith::CmpFPredicate pred> |
169 | struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> { |
170 | public: |
171 | using OpRewritePattern<OpTy>::OpRewritePattern; |
172 | |
173 | LogicalResult matchAndRewrite(OpTy op, |
174 | PatternRewriter &rewriter) const final { |
175 | Value lhs = op.getLhs(); |
176 | Value rhs = op.getRhs(); |
177 | |
178 | Location loc = op.getLoc(); |
179 | // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). |
180 | static_assert(pred == arith::CmpFPredicate::UGT || |
181 | pred == arith::CmpFPredicate::ULT, |
182 | "pred must be either UGT or ULT"); |
183 | Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); |
184 | Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); |
185 | |
186 | // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. |
187 | Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, |
188 | rhs, rhs); |
189 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); |
190 | return success(); |
191 | } |
192 | }; |
193 | |
194 | template <typename OpTy, arith::CmpFPredicate pred> |
195 | struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> { |
196 | public: |
197 | using OpRewritePattern<OpTy>::OpRewritePattern; |
198 | |
199 | LogicalResult matchAndRewrite(OpTy op, |
200 | PatternRewriter &rewriter) const final { |
201 | Value lhs = op.getLhs(); |
202 | Value rhs = op.getRhs(); |
203 | |
204 | Location loc = op.getLoc(); |
205 | // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). |
206 | static_assert(pred == arith::CmpFPredicate::UGT || |
207 | pred == arith::CmpFPredicate::ULT, |
208 | "pred must be either UGT or ULT"); |
209 | Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); |
210 | Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); |
211 | |
212 | // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'. |
213 | Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, |
214 | lhs, lhs); |
215 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); |
216 | return success(); |
217 | } |
218 | }; |
219 | |
220 | struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { |
221 | using OpRewritePattern::OpRewritePattern; |
222 | LogicalResult matchAndRewrite(arith::ExtFOp op, |
223 | PatternRewriter &rewriter) const final { |
224 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
225 | auto operand = op.getOperand(); |
226 | Type operandTy = operand.getType(); |
227 | Type resultTy = op.getType(); |
228 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
229 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
230 | |
231 | if (!operandETy.isBF16() || !resultETy.isF32()) { |
232 | return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32."); |
233 | } |
234 | |
235 | Type i16Ty = cloneToShapedType(operandTy, b.getI16Type()); |
236 | Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); |
237 | |
238 | Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand); |
239 | Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast); |
240 | |
241 | Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); |
242 | Value shl = b.create<arith::ShLIOp>(exti, c16); |
243 | Value result = b.create<arith::BitcastOp>(resultTy, shl); |
244 | |
245 | rewriter.replaceOp(op, result); |
246 | return success(); |
247 | } |
248 | }; |
249 | |
250 | struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { |
251 | using OpRewritePattern::OpRewritePattern; |
252 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
253 | PatternRewriter &rewriter) const final { |
254 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
255 | auto operand = op.getOperand(); |
256 | Type operandTy = operand.getType(); |
257 | Type resultTy = op.getType(); |
258 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
259 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
260 | |
261 | if (!operandETy.isF32() || !resultETy.isBF16()) { |
262 | return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16."); |
263 | } |
264 | |
265 | if (op.getRoundingmodeAttr()) { |
266 | return rewriter.notifyMatchFailure( |
267 | op, "only applicable to default rounding mode."); |
268 | } |
269 | |
270 | Type i16Ty = cloneToShapedType(operandTy, b.getI16Type()); |
271 | Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); |
272 | |
273 | // Algorithm borrowed from this excellent code: |
274 | // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79 |
275 | // There is a magic idea there, to let the addition of the rounding_bias to |
276 | // the mantissa simply overflow into the exponent bits. It's a bit of an |
277 | // aggressive, obfuscating optimization, but it is well-tested code, and it |
278 | // results in more concise and efficient IR. |
279 | // The case of NaN is handled separately (see isNaN and the final select). |
280 | // The case of infinities is NOT handled separately, which deserves an |
281 | // explanation. As the encoding of infinities has zero mantissa, the |
282 | // rounding-bias addition never carries into the exponent so that just gets |
283 | // truncated away, and as bfloat16 and float32 have the same number of |
284 | // exponent bits, that simple truncation is the desired outcome for |
285 | // infinities. |
286 | Value isNan = |
287 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand); |
288 | // Constant used to make the rounding bias. |
289 | Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); |
290 | // Constant used to generate a quiet NaN. |
291 | Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter); |
292 | // Small constants used to address bits. |
293 | Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); |
294 | Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter); |
295 | // Reinterpret the input f32 value as bits. |
296 | Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand); |
297 | // Read bit 16 as a value in {0,1}. |
298 | Value bit16 = |
299 | b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1); |
300 | // Determine the rounding bias to add as either 0x7fff or 0x8000 depending |
301 | // on bit 16, implementing the tie-breaking "to nearest even". |
302 | Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF); |
303 | // Add the rounding bias. Generally we want this to be added to the |
304 | // mantissa, but nothing prevents this to from carrying into the exponent |
305 | // bits, which would feel like a bug, but this is the magic trick here: |
306 | // when that happens, the mantissa gets reset to zero and the exponent |
307 | // gets incremented by the carry... which is actually exactly what we |
308 | // want. |
309 | Value biased = b.create<arith::AddIOp>(bitcast, roundingBias); |
310 | // Now that the rounding-bias has been added, truncating the low bits |
311 | // yields the correctly rounded result. |
312 | Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16); |
313 | Value normalCaseResultI16 = |
314 | b.create<arith::TruncIOp>(i16Ty, biasedAndShifted); |
315 | // Select either the above-computed result, or a quiet NaN constant |
316 | // if the input was NaN. |
317 | Value select = |
318 | b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16); |
319 | Value result = b.create<arith::BitcastOp>(resultTy, select); |
320 | rewriter.replaceOp(op, result); |
321 | return success(); |
322 | } |
323 | }; |
324 | |
325 | struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { |
326 | using OpRewritePattern::OpRewritePattern; |
327 | LogicalResult matchAndRewrite(arith::ExtFOp op, |
328 | PatternRewriter &rewriter) const final { |
329 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
330 | Value operand = op.getOperand(); |
331 | Type operandTy = operand.getType(); |
332 | Type resultTy = op.getType(); |
333 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
334 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
335 | |
336 | if (!llvm::isa<Float8E8M0FNUType>(operandETy)) { |
337 | return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU"); |
338 | } |
339 | |
340 | Type i8Ty = cloneToShapedType(operandTy, b.getI8Type()); |
341 | Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); |
342 | Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); |
343 | |
344 | Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand); |
345 | // create constants for NaNs |
346 | Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter); |
347 | Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter); |
348 | Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); |
349 | |
350 | Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast); |
351 | Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth); |
352 | |
353 | Value isNan = |
354 | b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN); |
355 | // select for NaNs |
356 | f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits); |
357 | Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits); |
358 | if (resultETy.getIntOrFloatBitWidth() < 32) { |
359 | result = b.create<arith::TruncFOp>(resultTy, result, nullptr, |
360 | op.getFastmathAttr()); |
361 | } else if (resultETy.getIntOrFloatBitWidth() > 32) { |
362 | result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr()); |
363 | } |
364 | rewriter.replaceOp(op, result); |
365 | return success(); |
366 | } |
367 | }; |
368 | |
369 | /* |
370 | TruncF to F8E8M0 is expected to extract exponent bits out of F32 type |
371 | Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type, |
372 | they all map to NaN in F8E8M0 Type. |
373 | */ |
374 | struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { |
375 | using OpRewritePattern::OpRewritePattern; |
376 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
377 | PatternRewriter &rewriter) const final { |
378 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
379 | Value operand = op.getOperand(); |
380 | Type operandTy = operand.getType(); |
381 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
382 | Type resultTy = op.getType(); |
383 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
384 | if (!llvm::isa<Float8E8M0FNUType>(resultETy)) { |
385 | return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU"); |
386 | } |
387 | |
388 | if (op.getRoundingmodeAttr()) { |
389 | return rewriter.notifyMatchFailure( |
390 | op, "only applicable to default rounding mode."); |
391 | } |
392 | |
393 | Type i8Ty = cloneToShapedType(operandTy, b.getI8Type()); |
394 | Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); |
395 | Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); |
396 | |
397 | if (operandETy.getIntOrFloatBitWidth() < 32) { |
398 | operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr()); |
399 | } else if (operandETy.getIntOrFloatBitWidth() > 32) { |
400 | operand = b.create<arith::TruncFOp>( |
401 | f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); |
402 | } |
403 | Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand); |
404 | Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); |
405 | Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth); |
406 | Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp); |
407 | Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits); |
408 | rewriter.replaceOp(op, result); |
409 | return success(); |
410 | } |
411 | }; |
412 | |
413 | struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { |
414 | using OpRewritePattern::OpRewritePattern; |
415 | LogicalResult matchAndRewrite(arith::ScalingExtFOp op, |
416 | PatternRewriter &rewriter) const final { |
417 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
418 | Value inputOperand = op.getIn(); |
419 | Value scaleOperand = op.getScale(); |
420 | Type scaleTy = scaleOperand.getType(); |
421 | Type scaleETy = getElementTypeOrSelf(val: scaleOperand); |
422 | // allow implicit exponent extraction from 16/32 bits floats |
423 | if (scaleETy.getIntOrFloatBitWidth() >= 16) { |
424 | scaleETy = b.getF8E8M0Type(); |
425 | scaleTy = cloneToShapedType(cloneFrom: scaleTy, cloneTo: scaleETy); |
426 | scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, |
427 | op.getFastmathAttr()); |
428 | } |
429 | if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { |
430 | return rewriter.notifyMatchFailure( |
431 | op, "scaling_extf is using scales of type which can not be converted " |
432 | "to f8E8M0FNU"); |
433 | } |
434 | Type resultTy = op.getType(); |
435 | // extf on scale will essentially create floating point number |
436 | // of type resulTy that is 2^scale and will also propagate NaNs |
437 | Value scaleExt = |
438 | b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr()); |
439 | Value inputExt = |
440 | b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr()); |
441 | Value result = |
442 | b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr()); |
443 | rewriter.replaceOp(op, result); |
444 | return success(); |
445 | } |
446 | }; |
447 | |
448 | /* |
449 | Expands arith.ScalingTruncFOp(in, scale) into |
450 | scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU |
451 | result = arith.truncf(in / (2^scale)) |
452 | */ |
453 | struct ScalingTruncFOpConverter |
454 | : public OpRewritePattern<arith::ScalingTruncFOp> { |
455 | using OpRewritePattern::OpRewritePattern; |
456 | LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, |
457 | PatternRewriter &rewriter) const final { |
458 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
459 | Value inputOperand = op.getIn(); |
460 | Value scaleOperand = op.getScale(); |
461 | Type scaleTy = scaleOperand.getType(); |
462 | Type scaleETy = getElementTypeOrSelf(val: scaleOperand); |
463 | // allow implicit exponent extraction from 16/32 bits floats |
464 | if (scaleETy.getIntOrFloatBitWidth() >= 16) { |
465 | scaleETy = b.getF8E8M0Type(); |
466 | scaleTy = cloneToShapedType(cloneFrom: scaleTy, cloneTo: scaleETy); |
467 | scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, |
468 | op.getFastmathAttr()); |
469 | } |
470 | if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { |
471 | return rewriter.notifyMatchFailure( |
472 | op, "scaling_truncf is using scales type which can not be converted " |
473 | "to f8E8M0FNU"); |
474 | } |
475 | Type resultTy = op.getType(); |
476 | Type inputTy = inputOperand.getType(); |
477 | // this will create a floating point number of type |
478 | // inputTy that is 2^scale and will also propagate NaNs |
479 | scaleOperand = |
480 | b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr()); |
481 | Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand, |
482 | op.getFastmathAttr()); |
483 | Value resultCast = b.create<arith::TruncFOp>( |
484 | resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); |
485 | rewriter.replaceOp(op, resultCast); |
486 | return success(); |
487 | } |
488 | }; |
489 | |
490 | struct ArithExpandOpsPass |
491 | : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> { |
492 | using ArithExpandOpsPassBase::ArithExpandOpsPassBase; |
493 | |
494 | void runOnOperation() override { |
495 | RewritePatternSet patterns(&getContext()); |
496 | ConversionTarget target(getContext()); |
497 | |
498 | arith::populateArithExpandOpsPatterns(patterns); |
499 | |
500 | target.addLegalDialect<arith::ArithDialect>(); |
501 | // clang-format off |
502 | target.addIllegalOp< |
503 | arith::CeilDivSIOp, |
504 | arith::CeilDivUIOp, |
505 | arith::FloorDivSIOp, |
506 | arith::MaxSIOp, |
507 | arith::MaxUIOp, |
508 | arith::MinSIOp, |
509 | arith::MinUIOp, |
510 | arith::MaximumFOp, |
511 | arith::MinimumFOp, |
512 | arith::MaxNumFOp, |
513 | arith::MinNumFOp, |
514 | arith::ScalingExtFOp, |
515 | arith::ScalingTruncFOp |
516 | >(); |
517 | |
518 | if (includeBf16) { |
519 | arith::populateExpandBFloat16Patterns(patterns); |
520 | } |
521 | if (includeF8E8M0) { |
522 | arith::populateExpandF8E8M0Patterns(patterns); |
523 | } |
524 | |
525 | target.addDynamicallyLegalOp<arith::ExtFOp>( |
526 | [=](arith::ExtFOp op) { |
527 | Type inETy = getElementTypeOrSelf(op.getOperand().getType()); |
528 | Type outETy = getElementTypeOrSelf(op.getType()); |
529 | bool legalTypes = true; |
530 | if (includeBf16) |
531 | legalTypes &= !(inETy.isBF16() && outETy.isF32()); |
532 | if (includeF8E8M0) |
533 | legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy); |
534 | return legalTypes; |
535 | }); |
536 | |
537 | target.addDynamicallyLegalOp<arith::TruncFOp>( |
538 | [=](arith::TruncFOp op) { |
539 | Type inETy = getElementTypeOrSelf(op.getOperand().getType()); |
540 | Type outETy = getElementTypeOrSelf(op.getType()); |
541 | bool legalTypes = true; |
542 | if (includeBf16) |
543 | legalTypes &= !(inETy.isF32() && outETy.isBF16()); |
544 | if (includeF8E8M0) |
545 | legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy)); |
546 | return legalTypes; |
547 | }); |
548 | |
549 | // clang-format on |
550 | if (failed(applyPartialConversion(getOperation(), target, |
551 | std::move(patterns)))) |
552 | signalPassFailure(); |
553 | } |
554 | }; |
555 | |
556 | } // namespace |
557 | |
558 | void mlir::arith::populateCeilFloorDivExpandOpsPatterns( |
559 | RewritePatternSet &patterns) { |
560 | patterns |
561 | .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>( |
562 | arg: patterns.getContext()); |
563 | } |
564 | |
565 | void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) { |
566 | patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>( |
567 | arg: patterns.getContext()); |
568 | } |
569 | |
570 | void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { |
571 | patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>( |
572 | arg: patterns.getContext()); |
573 | } |
574 | |
575 | void mlir::arith::populateExpandScalingExtTruncPatterns( |
576 | RewritePatternSet &patterns) { |
577 | patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>( |
578 | arg: patterns.getContext()); |
579 | } |
580 | |
581 | void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { |
582 | populateCeilFloorDivExpandOpsPatterns(patterns); |
583 | populateExpandScalingExtTruncPatterns(patterns); |
584 | // clang-format off |
585 | patterns.add< |
586 | MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, |
587 | MaxMinIOpConverter<MaxUIOp, arith::CmpIPredicate::ugt>, |
588 | MaxMinIOpConverter<MinSIOp, arith::CmpIPredicate::slt>, |
589 | MaxMinIOpConverter<MinUIOp, arith::CmpIPredicate::ult>, |
590 | MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>, |
591 | MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>, |
592 | MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>, |
593 | MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> |
594 | >(patterns.getContext()); |
595 | // clang-format on |
596 | } |
597 |
Definitions
- createConst
- cloneToShapedType
- CeilDivUIOpConverter
- matchAndRewrite
- CeilDivSIOpConverter
- matchAndRewrite
- FloorDivSIOpConverter
- matchAndRewrite
- MaxMinIOpConverter
- matchAndRewrite
- MaximumMinimumFOpConverter
- matchAndRewrite
- MaxNumMinNumFOpConverter
- matchAndRewrite
- BFloat16ExtFOpConverter
- matchAndRewrite
- BFloat16TruncFOpConverter
- matchAndRewrite
- F8E8M0ExtFOpConverter
- matchAndRewrite
- F8E8M0TruncFOpConverter
- matchAndRewrite
- ScalingExtFOpConverter
- matchAndRewrite
- ScalingTruncFOpConverter
- matchAndRewrite
- ArithExpandOpsPass
- runOnOperation
- populateCeilFloorDivExpandOpsPatterns
- populateExpandBFloat16Patterns
- populateExpandF8E8M0Patterns
- populateExpandScalingExtTruncPatterns
Learn to use CMake with our Intro Training
Find out more