1 | //===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Analysis/Presburger/IntegerRelation.h" |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/Arith/Transforms/Transforms.h" |
14 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
15 | #include "mlir/IR/BuiltinAttributes.h" |
16 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/MLIRContext.h" |
19 | #include "mlir/IR/Matchers.h" |
20 | #include "mlir/IR/Operation.h" |
21 | #include "mlir/IR/PatternMatch.h" |
22 | #include "mlir/IR/TypeUtilities.h" |
23 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
24 | #include "mlir/Support/LogicalResult.h" |
25 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
26 | #include "llvm/ADT/STLExtras.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include <cassert> |
29 | #include <cstdint> |
30 | |
31 | namespace mlir::arith { |
32 | #define GEN_PASS_DEF_ARITHINTNARROWING |
33 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
34 | } // namespace mlir::arith |
35 | |
36 | namespace mlir::arith { |
37 | namespace { |
38 | //===----------------------------------------------------------------------===// |
39 | // Common Helpers |
40 | //===----------------------------------------------------------------------===// |
41 | |
42 | /// The base for integer bitwidth narrowing patterns. |
43 | template <typename SourceOp> |
44 | struct NarrowingPattern : OpRewritePattern<SourceOp> { |
45 | NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options, |
46 | PatternBenefit benefit = 1) |
47 | : OpRewritePattern<SourceOp>(ctx, benefit), |
48 | supportedBitwidths(options.bitwidthsSupported.begin(), |
49 | options.bitwidthsSupported.end()) { |
50 | assert(!supportedBitwidths.empty() && "Invalid options" ); |
51 | assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth" ); |
52 | llvm::sort(C&: supportedBitwidths); |
53 | } |
54 | |
55 | FailureOr<unsigned> |
56 | getNarrowestCompatibleBitwidth(unsigned bitsRequired) const { |
57 | for (unsigned candidate : supportedBitwidths) |
58 | if (candidate >= bitsRequired) |
59 | return candidate; |
60 | |
61 | return failure(); |
62 | } |
63 | |
64 | /// Returns the narrowest supported type that fits `bitsRequired`. |
65 | FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const { |
66 | assert(origTy); |
67 | FailureOr<unsigned> bestBitwidth = |
68 | getNarrowestCompatibleBitwidth(bitsRequired); |
69 | if (failed(result: bestBitwidth)) |
70 | return failure(); |
71 | |
72 | Type elemTy = getElementTypeOrSelf(type: origTy); |
73 | if (!isa<IntegerType>(Val: elemTy)) |
74 | return failure(); |
75 | |
76 | auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth); |
77 | if (newElemTy == elemTy) |
78 | return failure(); |
79 | |
80 | if (origTy == elemTy) |
81 | return newElemTy; |
82 | |
83 | if (auto shapedTy = dyn_cast<ShapedType>(origTy)) |
84 | if (dyn_cast<IntegerType>(shapedTy.getElementType())) |
85 | return shapedTy.clone(shapedTy.getShape(), newElemTy); |
86 | |
87 | return failure(); |
88 | } |
89 | |
90 | private: |
91 | // Supported integer bitwidths in the ascending order. |
92 | llvm::SmallVector<unsigned, 6> supportedBitwidths; |
93 | }; |
94 | |
95 | /// Returns the integer bitwidth required to represent `type`. |
96 | FailureOr<unsigned> calculateBitsRequired(Type type) { |
97 | assert(type); |
98 | if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type))) |
99 | return intTy.getWidth(); |
100 | |
101 | return failure(); |
102 | } |
103 | |
/// The flavor of integer extension: sign (`arith.extsi`) or zero
/// (`arith.extui`).
enum class ExtensionKind { Sign, Zero };
105 | |
106 | /// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away |
107 | /// the exact op type. Exposes helper functions to query the types, operands, |
108 | /// and the result. This is so that we can handle both extension kinds without |
109 | /// needing to use templates or branching. |
110 | class ExtensionOp { |
111 | public: |
112 | /// Attemps to create a new extension op from `op`. Returns an extension op |
113 | /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure |
114 | /// otherwise. |
115 | static FailureOr<ExtensionOp> from(Operation *op) { |
116 | if (dyn_cast_or_null<arith::ExtSIOp>(op)) |
117 | return ExtensionOp{op, ExtensionKind::Sign}; |
118 | if (dyn_cast_or_null<arith::ExtUIOp>(op)) |
119 | return ExtensionOp{op, ExtensionKind::Zero}; |
120 | |
121 | return failure(); |
122 | } |
123 | |
124 | ExtensionOp(const ExtensionOp &) = default; |
125 | ExtensionOp &operator=(const ExtensionOp &) = default; |
126 | |
127 | /// Creates a new extension op of the same kind. |
128 | Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType, |
129 | Value in) { |
130 | if (kind == ExtensionKind::Sign) |
131 | return rewriter.create<arith::ExtSIOp>(loc, newType, in); |
132 | |
133 | return rewriter.create<arith::ExtUIOp>(loc, newType, in); |
134 | } |
135 | |
136 | /// Replaces `toReplace` with a new extension op of the same kind. |
137 | void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace, |
138 | Value in) { |
139 | assert(toReplace->getNumResults() == 1); |
140 | Type newType = toReplace->getResult(idx: 0).getType(); |
141 | Operation *newOp = recreate(rewriter, loc: toReplace->getLoc(), newType, in); |
142 | rewriter.replaceOp(op: toReplace, newValues: newOp->getResult(idx: 0)); |
143 | } |
144 | |
145 | ExtensionKind getKind() { return kind; } |
146 | |
147 | Value getResult() { return op->getResult(idx: 0); } |
148 | Value getIn() { return op->getOperand(idx: 0); } |
149 | |
150 | Type getType() { return getResult().getType(); } |
151 | Type getElementType() { return getElementTypeOrSelf(type: getType()); } |
152 | Type getInType() { return getIn().getType(); } |
153 | Type getInElementType() { return getElementTypeOrSelf(type: getInType()); } |
154 | |
155 | private: |
156 | ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) { |
157 | assert(op); |
158 | assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op" ); |
159 | } |
160 | Operation *op = nullptr; |
161 | ExtensionKind kind = {}; |
162 | }; |
163 | |
164 | /// Returns the integer bitwidth required to represent `value`. |
165 | unsigned calculateBitsRequired(const APInt &value, |
166 | ExtensionKind lookThroughExtension) { |
167 | // For unsigned values, we only need the active bits. As a special case, zero |
168 | // requires one bit. |
169 | if (lookThroughExtension == ExtensionKind::Zero) |
170 | return std::max(a: value.getActiveBits(), b: 1u); |
171 | |
172 | // If a signed value is nonnegative, we need one extra bit for the sign. |
173 | if (value.isNonNegative()) |
174 | return value.getActiveBits() + 1; |
175 | |
176 | // For the signed min, we need all the bits. |
177 | if (value.isMinSignedValue()) |
178 | return value.getBitWidth(); |
179 | |
180 | // For negative values, we need all the non-sign bits and one extra bit for |
181 | // the sign. |
182 | return value.getBitWidth() - value.getNumSignBits() + 1; |
183 | } |
184 | |
185 | /// Returns the integer bitwidth required to represent `value`. |
186 | /// Looks through either sign- or zero-extension as specified by |
187 | /// `lookThroughExtension`. |
188 | FailureOr<unsigned> calculateBitsRequired(Value value, |
189 | ExtensionKind lookThroughExtension) { |
190 | // Handle constants. |
191 | if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) { |
192 | if (auto intAttr = dyn_cast<IntegerAttr>(attr)) |
193 | return calculateBitsRequired(intAttr.getValue(), lookThroughExtension); |
194 | |
195 | if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) { |
196 | if (elemsAttr.getElementType().isIntOrIndex()) { |
197 | if (elemsAttr.isSplat()) |
198 | return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(), |
199 | lookThroughExtension); |
200 | |
201 | unsigned maxBits = 1; |
202 | for (const APInt &elemValue : elemsAttr.getValues<APInt>()) |
203 | maxBits = std::max( |
204 | maxBits, calculateBitsRequired(elemValue, lookThroughExtension)); |
205 | return maxBits; |
206 | } |
207 | } |
208 | } |
209 | |
210 | if (lookThroughExtension == ExtensionKind::Sign) { |
211 | if (auto sext = value.getDefiningOp<arith::ExtSIOp>()) |
212 | return calculateBitsRequired(sext.getIn().getType()); |
213 | } else if (lookThroughExtension == ExtensionKind::Zero) { |
214 | if (auto zext = value.getDefiningOp<arith::ExtUIOp>()) |
215 | return calculateBitsRequired(zext.getIn().getType()); |
216 | } |
217 | |
218 | // If nothing else worked, return the type requirements for this element type. |
219 | return calculateBitsRequired(type: value.getType()); |
220 | } |
221 | |
222 | /// Base pattern for arith binary ops. |
223 | /// Example: |
224 | /// ``` |
225 | /// %lhs = arith.extsi %a : i8 to i32 |
226 | /// %rhs = arith.extsi %b : i8 to i32 |
227 | /// %r = arith.addi %lhs, %rhs : i32 |
228 | /// ==> |
229 | /// %lhs = arith.extsi %a : i8 to i16 |
230 | /// %rhs = arith.extsi %b : i8 to i16 |
231 | /// %add = arith.addi %lhs, %rhs : i16 |
232 | /// %r = arith.extsi %add : i16 to i32 |
233 | /// ``` |
234 | template <typename BinaryOp> |
235 | struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> { |
236 | using NarrowingPattern<BinaryOp>::NarrowingPattern; |
237 | |
238 | /// Returns the number of bits required to represent the full result, assuming |
239 | /// that both operands are `operandBits`-wide. Derived classes must implement |
240 | /// this, taking into account `BinaryOp` semantics. |
241 | virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0; |
242 | |
243 | /// Customization point for patterns that should only apply with |
244 | /// zero/sign-extension ops as arguments. |
245 | virtual bool isSupported(ExtensionOp) const { return true; } |
246 | |
247 | LogicalResult matchAndRewrite(BinaryOp op, |
248 | PatternRewriter &rewriter) const final { |
249 | Type origTy = op.getType(); |
250 | FailureOr<unsigned> resultBits = calculateBitsRequired(type: origTy); |
251 | if (failed(result: resultBits)) |
252 | return failure(); |
253 | |
254 | // For the optimization to apply, we expect the lhs to be an extension op, |
255 | // and for the rhs to either be the same extension op or a constant. |
256 | FailureOr<ExtensionOp> ext = ExtensionOp::from(op: op.getLhs().getDefiningOp()); |
257 | if (failed(result: ext) || !isSupported(*ext)) |
258 | return failure(); |
259 | |
260 | FailureOr<unsigned> lhsBitsRequired = |
261 | calculateBitsRequired(value: ext->getIn(), lookThroughExtension: ext->getKind()); |
262 | if (failed(result: lhsBitsRequired) || *lhsBitsRequired >= *resultBits) |
263 | return failure(); |
264 | |
265 | FailureOr<unsigned> rhsBitsRequired = |
266 | calculateBitsRequired(op.getRhs(), ext->getKind()); |
267 | if (failed(result: rhsBitsRequired) || *rhsBitsRequired >= *resultBits) |
268 | return failure(); |
269 | |
270 | // Negotiate a common bit requirements for both lhs and rhs, accounting for |
271 | // the result requiring more bits than the operands. |
272 | unsigned commonBitsRequired = |
273 | getResultBitsProduced(operandBits: std::max(a: *lhsBitsRequired, b: *rhsBitsRequired)); |
274 | FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy); |
275 | if (failed(result: narrowTy) || calculateBitsRequired(type: *narrowTy) >= *resultBits) |
276 | return failure(); |
277 | |
278 | Location loc = op.getLoc(); |
279 | Value newLhs = |
280 | rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs()); |
281 | Value newRhs = |
282 | rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs()); |
283 | Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs); |
284 | ext->recreateAndReplace(rewriter, toReplace: op, in: newAdd); |
285 | return success(); |
286 | } |
287 | }; |
288 | |
289 | //===----------------------------------------------------------------------===// |
290 | // AddIOp Pattern |
291 | //===----------------------------------------------------------------------===// |
292 | |
293 | struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> { |
294 | using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; |
295 | |
296 | // Addition may require one extra bit for the result. |
297 | // Example: `UINT8_MAX + 1 == 255 + 1 == 256`. |
298 | unsigned getResultBitsProduced(unsigned operandBits) const override { |
299 | return operandBits + 1; |
300 | } |
301 | }; |
302 | |
303 | //===----------------------------------------------------------------------===// |
304 | // SubIOp Pattern |
305 | //===----------------------------------------------------------------------===// |
306 | |
307 | struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> { |
308 | using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; |
309 | |
310 | // This optimization only applies to signed arguments. |
311 | bool isSupported(ExtensionOp ext) const override { |
312 | return ext.getKind() == ExtensionKind::Sign; |
313 | } |
314 | |
315 | // Subtraction may require one extra bit for the result. |
316 | // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`. |
317 | unsigned getResultBitsProduced(unsigned operandBits) const override { |
318 | return operandBits + 1; |
319 | } |
320 | }; |
321 | |
322 | //===----------------------------------------------------------------------===// |
323 | // MulIOp Pattern |
324 | //===----------------------------------------------------------------------===// |
325 | |
struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Multiplication may require up to double the operand bits.
  // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 2 * operandBits;
  }
};
335 | |
336 | //===----------------------------------------------------------------------===// |
337 | // DivSIOp Pattern |
338 | //===----------------------------------------------------------------------===// |
339 | |
340 | struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> { |
341 | using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; |
342 | |
343 | // This optimization only applies to signed arguments. |
344 | bool isSupported(ExtensionOp ext) const override { |
345 | return ext.getKind() == ExtensionKind::Sign; |
346 | } |
347 | |
348 | // Unlike multiplication, signed division requires only one more result bit. |
349 | // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`. |
350 | unsigned getResultBitsProduced(unsigned operandBits) const override { |
351 | return operandBits + 1; |
352 | } |
353 | }; |
354 | |
355 | //===----------------------------------------------------------------------===// |
356 | // DivUIOp Pattern |
357 | //===----------------------------------------------------------------------===// |
358 | |
359 | struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> { |
360 | using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; |
361 | |
362 | // This optimization only applies to unsigned arguments. |
363 | bool isSupported(ExtensionOp ext) const override { |
364 | return ext.getKind() == ExtensionKind::Zero; |
365 | } |
366 | |
367 | // Unsigned division does not require any extra result bits. |
368 | unsigned getResultBitsProduced(unsigned operandBits) const override { |
369 | return operandBits; |
370 | } |
371 | }; |
372 | |
373 | //===----------------------------------------------------------------------===// |
374 | // Min/Max Patterns |
375 | //===----------------------------------------------------------------------===// |
376 | |
377 | template <typename MinMaxOp, ExtensionKind Kind> |
378 | struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> { |
379 | using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern; |
380 | |
381 | bool isSupported(ExtensionOp ext) const override { |
382 | return ext.getKind() == Kind; |
383 | } |
384 | |
385 | // Min/max returns one of the arguments and does not require any extra result |
386 | // bits. |
387 | unsigned getResultBitsProduced(unsigned operandBits) const override { |
388 | return operandBits; |
389 | } |
390 | }; |
391 | using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>; |
392 | using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>; |
393 | using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>; |
394 | using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>; |
395 | |
396 | //===----------------------------------------------------------------------===// |
397 | // *IToFPOp Patterns |
398 | //===----------------------------------------------------------------------===// |
399 | |
400 | template <typename IToFPOp, ExtensionKind Extension> |
401 | struct IToFPPattern final : NarrowingPattern<IToFPOp> { |
402 | using NarrowingPattern<IToFPOp>::NarrowingPattern; |
403 | |
404 | LogicalResult matchAndRewrite(IToFPOp op, |
405 | PatternRewriter &rewriter) const override { |
406 | FailureOr<unsigned> narrowestWidth = |
407 | calculateBitsRequired(op.getIn(), Extension); |
408 | if (failed(result: narrowestWidth)) |
409 | return failure(); |
410 | |
411 | FailureOr<Type> narrowTy = |
412 | this->getNarrowType(*narrowestWidth, op.getIn().getType()); |
413 | if (failed(result: narrowTy)) |
414 | return failure(); |
415 | |
416 | Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy, |
417 | op.getIn()); |
418 | rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn); |
419 | return success(); |
420 | } |
421 | }; |
422 | using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>; |
423 | using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>; |
424 | |
425 | //===----------------------------------------------------------------------===// |
426 | // Index Cast Patterns |
427 | //===----------------------------------------------------------------------===// |
428 | |
429 | // These rely on the `ValueBounds` interface for index values. For example, we |
430 | // can often statically tell index value bounds of loop induction variables. |
431 | |
432 | template <typename CastOp, ExtensionKind Kind> |
433 | struct IndexCastPattern final : NarrowingPattern<CastOp> { |
434 | using NarrowingPattern<CastOp>::NarrowingPattern; |
435 | |
436 | LogicalResult matchAndRewrite(CastOp op, |
437 | PatternRewriter &rewriter) const override { |
438 | Value in = op.getIn(); |
439 | // We only support scalar index -> integer casts. |
440 | if (!isa<IndexType>(Val: in.getType())) |
441 | return failure(); |
442 | |
443 | // Check the lower bound in both the signed and unsigned cast case. We |
444 | // conservatively assume that even unsigned casts may be performed on |
445 | // negative indices. |
446 | FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound( |
447 | type: presburger::BoundType::LB, var: in); |
448 | if (failed(result: lb)) |
449 | return failure(); |
450 | |
451 | FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound( |
452 | type: presburger::BoundType::UB, var: in, |
453 | /*stopCondition=*/nullptr, /*closedUB=*/true); |
454 | if (failed(result: ub)) |
455 | return failure(); |
456 | |
457 | assert(*lb <= *ub && "Invalid bounds" ); |
458 | unsigned lbBitsRequired = calculateBitsRequired(value: APInt(64, *lb), lookThroughExtension: Kind); |
459 | unsigned ubBitsRequired = calculateBitsRequired(value: APInt(64, *ub), lookThroughExtension: Kind); |
460 | unsigned bitsRequired = std::max(a: lbBitsRequired, b: ubBitsRequired); |
461 | |
462 | IntegerType resultTy = cast<IntegerType>(op.getType()); |
463 | if (resultTy.getWidth() <= bitsRequired) |
464 | return failure(); |
465 | |
466 | FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy); |
467 | if (failed(result: narrowTy)) |
468 | return failure(); |
469 | |
470 | Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn()); |
471 | |
472 | if (Kind == ExtensionKind::Sign) |
473 | rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast); |
474 | else |
475 | rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast); |
476 | return success(); |
477 | } |
478 | }; |
479 | using IndexCastSIPattern = |
480 | IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>; |
481 | using IndexCastUIPattern = |
482 | IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>; |
483 | |
484 | //===----------------------------------------------------------------------===// |
485 | // Patterns to Commute Extension Ops |
486 | //===----------------------------------------------------------------------===// |
487 | |
488 | struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> { |
489 | using NarrowingPattern::NarrowingPattern; |
490 | |
491 | LogicalResult matchAndRewrite(vector::BroadcastOp op, |
492 | PatternRewriter &rewriter) const override { |
493 | FailureOr<ExtensionOp> ext = |
494 | ExtensionOp::from(op: op.getSource().getDefiningOp()); |
495 | if (failed(result: ext)) |
496 | return failure(); |
497 | |
498 | VectorType origTy = op.getResultVectorType(); |
499 | VectorType newTy = |
500 | origTy.cloneWith(origTy.getShape(), ext->getInElementType()); |
501 | Value newBroadcast = |
502 | rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn()); |
503 | ext->recreateAndReplace(rewriter, toReplace: op, in: newBroadcast); |
504 | return success(); |
505 | } |
506 | }; |
507 | |
508 | struct final : NarrowingPattern<vector::ExtractOp> { |
509 | using NarrowingPattern::NarrowingPattern; |
510 | |
511 | LogicalResult matchAndRewrite(vector::ExtractOp op, |
512 | PatternRewriter &rewriter) const override { |
513 | FailureOr<ExtensionOp> ext = |
514 | ExtensionOp::from(op: op.getVector().getDefiningOp()); |
515 | if (failed(result: ext)) |
516 | return failure(); |
517 | |
518 | Value = rewriter.create<vector::ExtractOp>( |
519 | op.getLoc(), ext->getIn(), op.getMixedPosition()); |
520 | ext->recreateAndReplace(rewriter, toReplace: op, in: newExtract); |
521 | return success(); |
522 | } |
523 | }; |
524 | |
525 | struct final |
526 | : NarrowingPattern<vector::ExtractElementOp> { |
527 | using NarrowingPattern::NarrowingPattern; |
528 | |
529 | LogicalResult matchAndRewrite(vector::ExtractElementOp op, |
530 | PatternRewriter &rewriter) const override { |
531 | FailureOr<ExtensionOp> ext = |
532 | ExtensionOp::from(op: op.getVector().getDefiningOp()); |
533 | if (failed(result: ext)) |
534 | return failure(); |
535 | |
536 | Value = rewriter.create<vector::ExtractElementOp>( |
537 | op.getLoc(), ext->getIn(), op.getPosition()); |
538 | ext->recreateAndReplace(rewriter, toReplace: op, in: newExtract); |
539 | return success(); |
540 | } |
541 | }; |
542 | |
543 | struct final |
544 | : NarrowingPattern<vector::ExtractStridedSliceOp> { |
545 | using NarrowingPattern::NarrowingPattern; |
546 | |
547 | LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op, |
548 | PatternRewriter &rewriter) const override { |
549 | FailureOr<ExtensionOp> ext = |
550 | ExtensionOp::from(op: op.getVector().getDefiningOp()); |
551 | if (failed(result: ext)) |
552 | return failure(); |
553 | |
554 | VectorType origTy = op.getType(); |
555 | VectorType = |
556 | origTy.cloneWith(origTy.getShape(), ext->getInElementType()); |
557 | Value = rewriter.create<vector::ExtractStridedSliceOp>( |
558 | op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(), |
559 | op.getStrides()); |
560 | ext->recreateAndReplace(rewriter, toReplace: op, in: newExtract); |
561 | return success(); |
562 | } |
563 | }; |
564 | |
565 | /// Base pattern for `vector.insert` narrowing patterns. |
566 | template <typename InsertionOp> |
567 | struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> { |
568 | using NarrowingPattern<InsertionOp>::NarrowingPattern; |
569 | |
570 | /// Derived classes must provide a function to create the matching insertion |
571 | /// op based on the original op and new arguments. |
572 | virtual InsertionOp createInsertionOp(PatternRewriter &rewriter, |
573 | InsertionOp origInsert, |
574 | Value narrowValue, |
575 | Value narrowDest) const = 0; |
576 | |
577 | LogicalResult matchAndRewrite(InsertionOp op, |
578 | PatternRewriter &rewriter) const final { |
579 | FailureOr<ExtensionOp> ext = |
580 | ExtensionOp::from(op: op.getSource().getDefiningOp()); |
581 | if (failed(result: ext)) |
582 | return failure(); |
583 | |
584 | FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, insValue: *ext); |
585 | if (failed(newInsert)) |
586 | return failure(); |
587 | ext->recreateAndReplace(rewriter, toReplace: op, in: *newInsert); |
588 | return success(); |
589 | } |
590 | |
591 | FailureOr<InsertionOp> createNarrowInsert(InsertionOp op, |
592 | PatternRewriter &rewriter, |
593 | ExtensionOp insValue) const { |
594 | // Calculate the operand and result bitwidths. We can only apply narrowing |
595 | // when the inserted source value and destination vector require fewer bits |
596 | // than the result. Because the source and destination may have different |
597 | // bitwidths requirements, we have to find the common narrow bitwidth that |
598 | // is greater equal to the operand bitwidth requirements and still narrower |
599 | // than the result. |
600 | FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType()); |
601 | if (failed(result: origBitsRequired)) |
602 | return failure(); |
603 | |
604 | // TODO: We could relax this check by disregarding bitwidth requirements of |
605 | // elements that we know will be replaced by the insertion. |
606 | FailureOr<unsigned> destBitsRequired = |
607 | calculateBitsRequired(op.getDest(), insValue.getKind()); |
608 | if (failed(result: destBitsRequired) || *destBitsRequired >= *origBitsRequired) |
609 | return failure(); |
610 | |
611 | FailureOr<unsigned> insertedBitsRequired = |
612 | calculateBitsRequired(value: insValue.getIn(), lookThroughExtension: insValue.getKind()); |
613 | if (failed(result: insertedBitsRequired) || |
614 | *insertedBitsRequired >= *origBitsRequired) |
615 | return failure(); |
616 | |
617 | // Find a narrower element type that satisfies the bitwidth requirements of |
618 | // both the source and the destination values. |
619 | unsigned newInsertionBits = |
620 | std::max(a: *destBitsRequired, b: *insertedBitsRequired); |
621 | FailureOr<Type> newVecTy = |
622 | this->getNarrowType(newInsertionBits, op.getType()); |
623 | if (failed(result: newVecTy) || *newVecTy == op.getType()) |
624 | return failure(); |
625 | |
626 | FailureOr<Type> newInsertedValueTy = |
627 | this->getNarrowType(newInsertionBits, insValue.getType()); |
628 | if (failed(result: newInsertedValueTy)) |
629 | return failure(); |
630 | |
631 | Location loc = op.getLoc(); |
632 | Value narrowValue = rewriter.createOrFold<arith::TruncIOp>( |
633 | loc, *newInsertedValueTy, insValue.getResult()); |
634 | Value narrowDest = |
635 | rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest()); |
636 | return createInsertionOp(rewriter, origInsert: op, narrowValue, narrowDest); |
637 | } |
638 | }; |
639 | |
640 | struct ExtensionOverInsert final |
641 | : ExtensionOverInsertionPattern<vector::InsertOp> { |
642 | using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; |
643 | |
644 | vector::InsertOp createInsertionOp(PatternRewriter &rewriter, |
645 | vector::InsertOp origInsert, |
646 | Value narrowValue, |
647 | Value narrowDest) const override { |
648 | return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue, |
649 | narrowDest, |
650 | origInsert.getMixedPosition()); |
651 | } |
652 | }; |
653 | |
654 | struct ExtensionOverInsertElement final |
655 | : ExtensionOverInsertionPattern<vector::InsertElementOp> { |
656 | using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; |
657 | |
658 | vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter, |
659 | vector::InsertElementOp origInsert, |
660 | Value narrowValue, |
661 | Value narrowDest) const override { |
662 | return rewriter.create<vector::InsertElementOp>( |
663 | origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); |
664 | } |
665 | }; |
666 | |
667 | struct ExtensionOverInsertStridedSlice final |
668 | : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> { |
669 | using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; |
670 | |
671 | vector::InsertStridedSliceOp |
672 | createInsertionOp(PatternRewriter &rewriter, |
673 | vector::InsertStridedSliceOp origInsert, Value narrowValue, |
674 | Value narrowDest) const override { |
675 | return rewriter.create<vector::InsertStridedSliceOp>( |
676 | origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(), |
677 | origInsert.getStrides()); |
678 | } |
679 | }; |
680 | |
681 | struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> { |
682 | using NarrowingPattern::NarrowingPattern; |
683 | |
684 | LogicalResult matchAndRewrite(vector::ShapeCastOp op, |
685 | PatternRewriter &rewriter) const override { |
686 | FailureOr<ExtensionOp> ext = |
687 | ExtensionOp::from(op: op.getSource().getDefiningOp()); |
688 | if (failed(result: ext)) |
689 | return failure(); |
690 | |
691 | VectorType origTy = op.getResultVectorType(); |
692 | VectorType newTy = |
693 | origTy.cloneWith(origTy.getShape(), ext->getInElementType()); |
694 | Value newCast = |
695 | rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn()); |
696 | ext->recreateAndReplace(rewriter, toReplace: op, in: newCast); |
697 | return success(); |
698 | } |
699 | }; |
700 | |
701 | struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> { |
702 | using NarrowingPattern::NarrowingPattern; |
703 | |
704 | LogicalResult matchAndRewrite(vector::TransposeOp op, |
705 | PatternRewriter &rewriter) const override { |
706 | FailureOr<ExtensionOp> ext = |
707 | ExtensionOp::from(op: op.getVector().getDefiningOp()); |
708 | if (failed(result: ext)) |
709 | return failure(); |
710 | |
711 | VectorType origTy = op.getResultVectorType(); |
712 | VectorType newTy = |
713 | origTy.cloneWith(origTy.getShape(), ext->getInElementType()); |
714 | Value newTranspose = rewriter.create<vector::TransposeOp>( |
715 | op.getLoc(), newTy, ext->getIn(), op.getPermutation()); |
716 | ext->recreateAndReplace(rewriter, toReplace: op, in: newTranspose); |
717 | return success(); |
718 | } |
719 | }; |
720 | |
721 | struct ExtensionOverFlatTranspose final |
722 | : NarrowingPattern<vector::FlatTransposeOp> { |
723 | using NarrowingPattern::NarrowingPattern; |
724 | |
725 | LogicalResult matchAndRewrite(vector::FlatTransposeOp op, |
726 | PatternRewriter &rewriter) const override { |
727 | FailureOr<ExtensionOp> ext = |
728 | ExtensionOp::from(op: op.getMatrix().getDefiningOp()); |
729 | if (failed(result: ext)) |
730 | return failure(); |
731 | |
732 | VectorType origTy = op.getType(); |
733 | VectorType newTy = |
734 | origTy.cloneWith(origTy.getShape(), ext->getInElementType()); |
735 | Value newTranspose = rewriter.create<vector::FlatTransposeOp>( |
736 | op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(), |
737 | op.getColumnsAttr()); |
738 | ext->recreateAndReplace(rewriter, toReplace: op, in: newTranspose); |
739 | return success(); |
740 | } |
741 | }; |
742 | |
743 | //===----------------------------------------------------------------------===// |
744 | // Pass Definitions |
745 | //===----------------------------------------------------------------------===// |
746 | |
747 | struct ArithIntNarrowingPass final |
748 | : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> { |
749 | using ArithIntNarrowingBase::ArithIntNarrowingBase; |
750 | |
751 | void runOnOperation() override { |
752 | if (bitwidthsSupported.empty() || |
753 | llvm::is_contained(bitwidthsSupported, 0)) { |
754 | // Invalid pass options. |
755 | return signalPassFailure(); |
756 | } |
757 | |
758 | Operation *op = getOperation(); |
759 | MLIRContext *ctx = op->getContext(); |
760 | RewritePatternSet patterns(ctx); |
761 | populateArithIntNarrowingPatterns( |
762 | patterns, ArithIntNarrowingOptions{bitwidthsSupported}); |
763 | if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) |
764 | signalPassFailure(); |
765 | } |
766 | }; |
767 | } // namespace |
768 | |
769 | //===----------------------------------------------------------------------===// |
770 | // Public API |
771 | //===----------------------------------------------------------------------===// |
772 | |
/// Registers all integer narrowing rewrite patterns, configured by `options`.
void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  // Add commute patterns with a higher benefit. This is to expose more
  // optimization opportunities to narrowing patterns.
  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
               ExtensionOverInsert, ExtensionOverInsertElement,
               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
      patterns.getContext(), options, PatternBenefit(2));

  // Narrowing patterns proper, at the default benefit.
  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
               IndexCastUIPattern>(patterns.getContext(), options);
}
789 | |
790 | } // namespace mlir::arith |
791 | |