1 | //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
13 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
14 | #include "mlir/IR/TypeUtilities.h" |
15 | #include "mlir/Transforms/DialectConversion.h" |
16 | |
17 | namespace mlir { |
18 | namespace arith { |
19 | #define GEN_PASS_DEF_ARITHEXPANDOPSPASS |
20 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
21 | } // namespace arith |
22 | } // namespace mlir |
23 | |
24 | using namespace mlir; |
25 | |
26 | /// Create an integer or index constant. |
27 | static Value createConst(Location loc, Type type, int value, |
28 | PatternRewriter &rewriter) { |
29 | auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value); |
30 | if (auto shapedTy = dyn_cast<ShapedType>(type)) { |
31 | return rewriter.create<arith::ConstantOp>( |
32 | loc, DenseElementsAttr::get(shapedTy, attr)); |
33 | } |
34 | |
35 | return rewriter.create<arith::ConstantOp>(loc, attr); |
36 | } |
37 | |
38 | namespace { |
39 | |
40 | /// Expands CeilDivUIOp (n, m) into |
41 | /// n == 0 ? 0 : ((n-1) / m) + 1 |
42 | struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { |
43 | using OpRewritePattern::OpRewritePattern; |
44 | LogicalResult matchAndRewrite(arith::CeilDivUIOp op, |
45 | PatternRewriter &rewriter) const final { |
46 | Location loc = op.getLoc(); |
47 | Value a = op.getLhs(); |
48 | Value b = op.getRhs(); |
49 | Value zero = createConst(loc, type: a.getType(), value: 0, rewriter); |
50 | Value compare = |
51 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); |
52 | Value one = createConst(loc, type: a.getType(), value: 1, rewriter); |
53 | Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); |
54 | Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); |
55 | Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); |
56 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne); |
57 | return success(); |
58 | } |
59 | }; |
60 | |
61 | /// Expands CeilDivSIOp (n, m) into |
62 | /// 1) x = (m > 0) ? -1 : 1 |
63 | /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) |
64 | struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { |
65 | using OpRewritePattern::OpRewritePattern; |
66 | LogicalResult matchAndRewrite(arith::CeilDivSIOp op, |
67 | PatternRewriter &rewriter) const final { |
68 | Location loc = op.getLoc(); |
69 | Type type = op.getType(); |
70 | Value a = op.getLhs(); |
71 | Value b = op.getRhs(); |
72 | Value plusOne = createConst(loc, type, value: 1, rewriter); |
73 | Value zero = createConst(loc, type, value: 0, rewriter); |
74 | Value minusOne = createConst(loc, type, value: -1, rewriter); |
75 | // Compute x = (b>0) ? -1 : 1. |
76 | Value compare = |
77 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); |
78 | Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne); |
79 | // Compute positive res: 1 + ((x+a)/b). |
80 | Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a); |
81 | Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b); |
82 | Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB); |
83 | // Compute negative res: - ((-a)/b). |
84 | Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a); |
85 | Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b); |
86 | Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB); |
87 | // Result is (a*b>0) ? pos result : neg result. |
88 | // Note, we want to avoid using a*b because of possible overflow. |
89 | // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do |
90 | // not particuliarly care if a*b<0 is true or false when b is zero |
91 | // as this will result in an illegal divide. So `a*b<0` can be reformulated |
92 | // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'. |
93 | // We pick the first expression here. |
94 | Value aNeg = |
95 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); |
96 | Value aPos = |
97 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero); |
98 | Value bNeg = |
99 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); |
100 | Value bPos = |
101 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero); |
102 | Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg); |
103 | Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos); |
104 | Value compareRes = |
105 | rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm); |
106 | // Perform substitution and return success. |
107 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes, |
108 | negRes); |
109 | return success(); |
110 | } |
111 | }; |
112 | |
113 | /// Expands FloorDivSIOp (x, y) into |
114 | /// z = x / y |
115 | /// if (z * y != x && (x < 0) != (y < 0)) { |
116 | /// return z - 1; |
117 | /// } else { |
118 | /// return z; |
119 | /// } |
120 | struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { |
121 | using OpRewritePattern::OpRewritePattern; |
122 | LogicalResult matchAndRewrite(arith::FloorDivSIOp op, |
123 | PatternRewriter &rewriter) const final { |
124 | Location loc = op.getLoc(); |
125 | Type type = op.getType(); |
126 | Value a = op.getLhs(); |
127 | Value b = op.getRhs(); |
128 | |
129 | Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b); |
130 | Value product = rewriter.create<arith::MulIOp>(loc, quotient, b); |
131 | Value notEqualDivisor = rewriter.create<arith::CmpIOp>( |
132 | loc, arith::CmpIPredicate::ne, a, product); |
133 | Value zero = createConst(loc, type, value: 0, rewriter); |
134 | |
135 | Value aNeg = |
136 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); |
137 | Value bNeg = |
138 | rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); |
139 | |
140 | Value signOpposite = rewriter.create<arith::CmpIOp>( |
141 | loc, arith::CmpIPredicate::ne, aNeg, bNeg); |
142 | Value cond = |
143 | rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite); |
144 | |
145 | Value minusOne = createConst(loc, type, value: -1, rewriter); |
146 | Value quotientMinusOne = |
147 | rewriter.create<arith::AddIOp>(loc, quotient, minusOne); |
148 | |
149 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne, |
150 | quotient); |
151 | return success(); |
152 | } |
153 | }; |
154 | |
155 | template <typename OpTy, arith::CmpFPredicate pred> |
156 | struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> { |
157 | public: |
158 | using OpRewritePattern<OpTy>::OpRewritePattern; |
159 | |
160 | LogicalResult matchAndRewrite(OpTy op, |
161 | PatternRewriter &rewriter) const final { |
162 | Value lhs = op.getLhs(); |
163 | Value rhs = op.getRhs(); |
164 | |
165 | Location loc = op.getLoc(); |
166 | // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). |
167 | static_assert(pred == arith::CmpFPredicate::UGT || |
168 | pred == arith::CmpFPredicate::ULT, |
169 | "pred must be either UGT or ULT" ); |
170 | Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); |
171 | Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); |
172 | |
173 | // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. |
174 | Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, |
175 | rhs, rhs); |
176 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); |
177 | return success(); |
178 | } |
179 | }; |
180 | |
181 | template <typename OpTy, arith::CmpFPredicate pred> |
182 | struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> { |
183 | public: |
184 | using OpRewritePattern<OpTy>::OpRewritePattern; |
185 | |
186 | LogicalResult matchAndRewrite(OpTy op, |
187 | PatternRewriter &rewriter) const final { |
188 | Value lhs = op.getLhs(); |
189 | Value rhs = op.getRhs(); |
190 | |
191 | Location loc = op.getLoc(); |
192 | // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs'). |
193 | static_assert(pred == arith::CmpFPredicate::UGT || |
194 | pred == arith::CmpFPredicate::ULT, |
195 | "pred must be either UGT or ULT" ); |
196 | Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); |
197 | Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); |
198 | |
199 | // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'. |
200 | Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, |
201 | lhs, lhs); |
202 | rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); |
203 | return success(); |
204 | } |
205 | }; |
206 | |
207 | struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { |
208 | using OpRewritePattern::OpRewritePattern; |
209 | LogicalResult matchAndRewrite(arith::ExtFOp op, |
210 | PatternRewriter &rewriter) const final { |
211 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
212 | auto operand = op.getOperand(); |
213 | Type operandTy = operand.getType(); |
214 | Type resultTy = op.getType(); |
215 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
216 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
217 | |
218 | if (!operandETy.isBF16() || !resultETy.isF32()) { |
219 | return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32." ); |
220 | } |
221 | |
222 | Type i16Ty = b.getI16Type(); |
223 | Type i32Ty = b.getI32Type(); |
224 | if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) { |
225 | i16Ty = shapedTy.clone(i16Ty); |
226 | i32Ty = shapedTy.clone(i32Ty); |
227 | } |
228 | |
229 | Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand); |
230 | Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast); |
231 | |
232 | Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); |
233 | Value shl = b.create<arith::ShLIOp>(exti, c16); |
234 | Value result = b.create<arith::BitcastOp>(resultTy, shl); |
235 | |
236 | rewriter.replaceOp(op, result); |
237 | return success(); |
238 | } |
239 | }; |
240 | |
241 | struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { |
242 | using OpRewritePattern::OpRewritePattern; |
243 | LogicalResult matchAndRewrite(arith::TruncFOp op, |
244 | PatternRewriter &rewriter) const final { |
245 | ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
246 | auto operand = op.getOperand(); |
247 | Type operandTy = operand.getType(); |
248 | Type resultTy = op.getType(); |
249 | Type operandETy = getElementTypeOrSelf(type: operandTy); |
250 | Type resultETy = getElementTypeOrSelf(type: resultTy); |
251 | |
252 | if (!operandETy.isF32() || !resultETy.isBF16()) { |
253 | return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16." ); |
254 | } |
255 | |
256 | if (op.getRoundingmodeAttr()) { |
257 | return rewriter.notifyMatchFailure( |
258 | op, "only applicable to default rounding mode." ); |
259 | } |
260 | |
261 | Type i16Ty = b.getI16Type(); |
262 | Type i32Ty = b.getI32Type(); |
263 | Type f32Ty = b.getF32Type(); |
264 | if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) { |
265 | i16Ty = shapedTy.clone(i16Ty); |
266 | i32Ty = shapedTy.clone(i32Ty); |
267 | f32Ty = shapedTy.clone(f32Ty); |
268 | } |
269 | |
270 | // Algorithm borrowed from this excellent code: |
271 | // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79 |
272 | // There is a magic idea there, to let the addition of the rounding_bias to |
273 | // the mantissa simply overflow into the exponent bits. It's a bit of an |
274 | // aggressive, obfuscating optimization, but it is well-tested code, and it |
275 | // results in more concise and efficient IR. |
276 | // The case of NaN is handled separately (see isNaN and the final select). |
277 | // The case of infinities is NOT handled separately, which deserves an |
278 | // explanation. As the encoding of infinities has zero mantissa, the |
279 | // rounding-bias addition never carries into the exponent so that just gets |
280 | // truncated away, and as bfloat16 and float32 have the same number of |
281 | // exponent bits, that simple truncation is the desired outcome for |
282 | // infinities. |
283 | Value isNan = |
284 | b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand); |
285 | // Constant used to make the rounding bias. |
286 | Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); |
287 | // Constant used to generate a quiet NaN. |
288 | Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter); |
289 | // Small constants used to address bits. |
290 | Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); |
291 | Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter); |
292 | // Reinterpret the input f32 value as bits. |
293 | Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand); |
294 | // Read bit 16 as a value in {0,1}. |
295 | Value bit16 = |
296 | b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1); |
297 | // Determine the rounding bias to add as either 0x7fff or 0x8000 depending |
298 | // on bit 16, implementing the tie-breaking "to nearest even". |
299 | Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF); |
300 | // Add the rounding bias. Generally we want this to be added to the |
301 | // mantissa, but nothing prevents this to from carrying into the exponent |
302 | // bits, which would feel like a bug, but this is the magic trick here: |
303 | // when that happens, the mantissa gets reset to zero and the exponent |
304 | // gets incremented by the carry... which is actually exactly what we |
305 | // want. |
306 | Value biased = b.create<arith::AddIOp>(bitcast, roundingBias); |
307 | // Now that the rounding-bias has been added, truncating the low bits |
308 | // yields the correctly rounded result. |
309 | Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16); |
310 | Value normalCaseResult_i16 = |
311 | b.create<arith::TruncIOp>(i16Ty, biasedAndShifted); |
312 | // Select either the above-computed result, or a quiet NaN constant |
313 | // if the input was NaN. |
314 | Value select = |
315 | b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16); |
316 | Value result = b.create<arith::BitcastOp>(resultTy, select); |
317 | rewriter.replaceOp(op, result); |
318 | return success(); |
319 | } |
320 | }; |
321 | |
322 | struct ArithExpandOpsPass |
323 | : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> { |
324 | using ArithExpandOpsPassBase::ArithExpandOpsPassBase; |
325 | |
326 | void runOnOperation() override { |
327 | RewritePatternSet patterns(&getContext()); |
328 | ConversionTarget target(getContext()); |
329 | |
330 | arith::populateArithExpandOpsPatterns(patterns); |
331 | |
332 | target.addLegalDialect<arith::ArithDialect>(); |
333 | // clang-format off |
334 | target.addIllegalOp< |
335 | arith::CeilDivSIOp, |
336 | arith::CeilDivUIOp, |
337 | arith::FloorDivSIOp, |
338 | arith::MaximumFOp, |
339 | arith::MinimumFOp, |
340 | arith::MaxNumFOp, |
341 | arith::MinNumFOp |
342 | >(); |
343 | |
344 | if (includeBf16) { |
345 | arith::populateExpandBFloat16Patterns(patterns); |
346 | target.addDynamicallyLegalOp<arith::ExtFOp>( |
347 | [](arith::ExtFOp op) { |
348 | Type inETy = getElementTypeOrSelf(op.getOperand().getType()); |
349 | Type outETy = getElementTypeOrSelf(op.getType()); |
350 | return !(inETy.isBF16() && outETy.isF32()); |
351 | }); |
352 | |
353 | target.addDynamicallyLegalOp<arith::TruncFOp>( |
354 | [](arith::TruncFOp op) { |
355 | Type inETy = getElementTypeOrSelf(op.getOperand().getType()); |
356 | Type outETy = getElementTypeOrSelf(op.getType()); |
357 | return !(inETy.isF32() && outETy.isBF16()); |
358 | }); |
359 | } |
360 | |
361 | // clang-format on |
362 | if (failed(applyPartialConversion(getOperation(), target, |
363 | std::move(patterns)))) |
364 | signalPassFailure(); |
365 | } |
366 | }; |
367 | |
368 | } // namespace |
369 | |
370 | void mlir::arith::populateCeilFloorDivExpandOpsPatterns( |
371 | RewritePatternSet &patterns) { |
372 | patterns |
373 | .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>( |
374 | arg: patterns.getContext()); |
375 | } |
376 | |
377 | void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) { |
378 | patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>( |
379 | arg: patterns.getContext()); |
380 | } |
381 | |
382 | void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { |
383 | populateCeilFloorDivExpandOpsPatterns(patterns); |
384 | // clang-format off |
385 | patterns.add< |
386 | MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>, |
387 | MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>, |
388 | MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>, |
389 | MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> |
390 | >(patterns.getContext()); |
391 | // clang-format on |
392 | } |
393 | |