1//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/Transforms/Passes.h"
10
11#include "mlir/Analysis/Presburger/IntegerRelation.h"
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/Arith/Transforms/Transforms.h"
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15#include "mlir/IR/BuiltinAttributes.h"
16#include "mlir/IR/BuiltinTypeInterfaces.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/MLIRContext.h"
19#include "mlir/IR/Matchers.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/IR/TypeUtilities.h"
23#include "mlir/Interfaces/ValueBoundsOpInterface.h"
24#include "mlir/Support/LogicalResult.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include <cassert>
29#include <cstdint>
30
31namespace mlir::arith {
32#define GEN_PASS_DEF_ARITHINTNARROWING
33#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
34} // namespace mlir::arith
35
36namespace mlir::arith {
37namespace {
38//===----------------------------------------------------------------------===//
39// Common Helpers
40//===----------------------------------------------------------------------===//
41
42/// The base for integer bitwidth narrowing patterns.
43template <typename SourceOp>
44struct NarrowingPattern : OpRewritePattern<SourceOp> {
45 NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
46 PatternBenefit benefit = 1)
47 : OpRewritePattern<SourceOp>(ctx, benefit),
48 supportedBitwidths(options.bitwidthsSupported.begin(),
49 options.bitwidthsSupported.end()) {
50 assert(!supportedBitwidths.empty() && "Invalid options");
51 assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
52 llvm::sort(C&: supportedBitwidths);
53 }
54
55 FailureOr<unsigned>
56 getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
57 for (unsigned candidate : supportedBitwidths)
58 if (candidate >= bitsRequired)
59 return candidate;
60
61 return failure();
62 }
63
64 /// Returns the narrowest supported type that fits `bitsRequired`.
65 FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
66 assert(origTy);
67 FailureOr<unsigned> bestBitwidth =
68 getNarrowestCompatibleBitwidth(bitsRequired);
69 if (failed(result: bestBitwidth))
70 return failure();
71
72 Type elemTy = getElementTypeOrSelf(type: origTy);
73 if (!isa<IntegerType>(Val: elemTy))
74 return failure();
75
76 auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
77 if (newElemTy == elemTy)
78 return failure();
79
80 if (origTy == elemTy)
81 return newElemTy;
82
83 if (auto shapedTy = dyn_cast<ShapedType>(origTy))
84 if (dyn_cast<IntegerType>(shapedTy.getElementType()))
85 return shapedTy.clone(shapedTy.getShape(), newElemTy);
86
87 return failure();
88 }
89
90private:
91 // Supported integer bitwidths in the ascending order.
92 llvm::SmallVector<unsigned, 6> supportedBitwidths;
93};
94
95/// Returns the integer bitwidth required to represent `type`.
96FailureOr<unsigned> calculateBitsRequired(Type type) {
97 assert(type);
98 if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
99 return intTy.getWidth();
100
101 return failure();
102}
103
/// Kind of integer extension handled by the narrowing patterns: `Sign` is used
/// for `arith.extsi` and `Zero` for `arith.extui`.
enum class ExtensionKind { Sign, Zero };
105
106/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away
107/// the exact op type. Exposes helper functions to query the types, operands,
108/// and the result. This is so that we can handle both extension kinds without
109/// needing to use templates or branching.
110class ExtensionOp {
111public:
112 /// Attemps to create a new extension op from `op`. Returns an extension op
113 /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
114 /// otherwise.
115 static FailureOr<ExtensionOp> from(Operation *op) {
116 if (dyn_cast_or_null<arith::ExtSIOp>(op))
117 return ExtensionOp{op, ExtensionKind::Sign};
118 if (dyn_cast_or_null<arith::ExtUIOp>(op))
119 return ExtensionOp{op, ExtensionKind::Zero};
120
121 return failure();
122 }
123
124 ExtensionOp(const ExtensionOp &) = default;
125 ExtensionOp &operator=(const ExtensionOp &) = default;
126
127 /// Creates a new extension op of the same kind.
128 Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
129 Value in) {
130 if (kind == ExtensionKind::Sign)
131 return rewriter.create<arith::ExtSIOp>(loc, newType, in);
132
133 return rewriter.create<arith::ExtUIOp>(loc, newType, in);
134 }
135
136 /// Replaces `toReplace` with a new extension op of the same kind.
137 void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
138 Value in) {
139 assert(toReplace->getNumResults() == 1);
140 Type newType = toReplace->getResult(idx: 0).getType();
141 Operation *newOp = recreate(rewriter, loc: toReplace->getLoc(), newType, in);
142 rewriter.replaceOp(op: toReplace, newValues: newOp->getResult(idx: 0));
143 }
144
145 ExtensionKind getKind() { return kind; }
146
147 Value getResult() { return op->getResult(idx: 0); }
148 Value getIn() { return op->getOperand(idx: 0); }
149
150 Type getType() { return getResult().getType(); }
151 Type getElementType() { return getElementTypeOrSelf(type: getType()); }
152 Type getInType() { return getIn().getType(); }
153 Type getInElementType() { return getElementTypeOrSelf(type: getInType()); }
154
155private:
156 ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
157 assert(op);
158 assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
159 }
160 Operation *op = nullptr;
161 ExtensionKind kind = {};
162};
163
164/// Returns the integer bitwidth required to represent `value`.
165unsigned calculateBitsRequired(const APInt &value,
166 ExtensionKind lookThroughExtension) {
167 // For unsigned values, we only need the active bits. As a special case, zero
168 // requires one bit.
169 if (lookThroughExtension == ExtensionKind::Zero)
170 return std::max(a: value.getActiveBits(), b: 1u);
171
172 // If a signed value is nonnegative, we need one extra bit for the sign.
173 if (value.isNonNegative())
174 return value.getActiveBits() + 1;
175
176 // For the signed min, we need all the bits.
177 if (value.isMinSignedValue())
178 return value.getBitWidth();
179
180 // For negative values, we need all the non-sign bits and one extra bit for
181 // the sign.
182 return value.getBitWidth() - value.getNumSignBits() + 1;
183}
184
185/// Returns the integer bitwidth required to represent `value`.
186/// Looks through either sign- or zero-extension as specified by
187/// `lookThroughExtension`.
188FailureOr<unsigned> calculateBitsRequired(Value value,
189 ExtensionKind lookThroughExtension) {
190 // Handle constants.
191 if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
192 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
193 return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);
194
195 if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
196 if (elemsAttr.getElementType().isIntOrIndex()) {
197 if (elemsAttr.isSplat())
198 return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
199 lookThroughExtension);
200
201 unsigned maxBits = 1;
202 for (const APInt &elemValue : elemsAttr.getValues<APInt>())
203 maxBits = std::max(
204 maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
205 return maxBits;
206 }
207 }
208 }
209
210 if (lookThroughExtension == ExtensionKind::Sign) {
211 if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
212 return calculateBitsRequired(sext.getIn().getType());
213 } else if (lookThroughExtension == ExtensionKind::Zero) {
214 if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
215 return calculateBitsRequired(zext.getIn().getType());
216 }
217
218 // If nothing else worked, return the type requirements for this element type.
219 return calculateBitsRequired(type: value.getType());
220}
221
222/// Base pattern for arith binary ops.
223/// Example:
224/// ```
225/// %lhs = arith.extsi %a : i8 to i32
226/// %rhs = arith.extsi %b : i8 to i32
227/// %r = arith.addi %lhs, %rhs : i32
228/// ==>
229/// %lhs = arith.extsi %a : i8 to i16
230/// %rhs = arith.extsi %b : i8 to i16
231/// %add = arith.addi %lhs, %rhs : i16
232/// %r = arith.extsi %add : i16 to i32
233/// ```
234template <typename BinaryOp>
235struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
236 using NarrowingPattern<BinaryOp>::NarrowingPattern;
237
238 /// Returns the number of bits required to represent the full result, assuming
239 /// that both operands are `operandBits`-wide. Derived classes must implement
240 /// this, taking into account `BinaryOp` semantics.
241 virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;
242
243 /// Customization point for patterns that should only apply with
244 /// zero/sign-extension ops as arguments.
245 virtual bool isSupported(ExtensionOp) const { return true; }
246
247 LogicalResult matchAndRewrite(BinaryOp op,
248 PatternRewriter &rewriter) const final {
249 Type origTy = op.getType();
250 FailureOr<unsigned> resultBits = calculateBitsRequired(type: origTy);
251 if (failed(result: resultBits))
252 return failure();
253
254 // For the optimization to apply, we expect the lhs to be an extension op,
255 // and for the rhs to either be the same extension op or a constant.
256 FailureOr<ExtensionOp> ext = ExtensionOp::from(op: op.getLhs().getDefiningOp());
257 if (failed(result: ext) || !isSupported(*ext))
258 return failure();
259
260 FailureOr<unsigned> lhsBitsRequired =
261 calculateBitsRequired(value: ext->getIn(), lookThroughExtension: ext->getKind());
262 if (failed(result: lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
263 return failure();
264
265 FailureOr<unsigned> rhsBitsRequired =
266 calculateBitsRequired(op.getRhs(), ext->getKind());
267 if (failed(result: rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
268 return failure();
269
270 // Negotiate a common bit requirements for both lhs and rhs, accounting for
271 // the result requiring more bits than the operands.
272 unsigned commonBitsRequired =
273 getResultBitsProduced(operandBits: std::max(a: *lhsBitsRequired, b: *rhsBitsRequired));
274 FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
275 if (failed(result: narrowTy) || calculateBitsRequired(type: *narrowTy) >= *resultBits)
276 return failure();
277
278 Location loc = op.getLoc();
279 Value newLhs =
280 rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
281 Value newRhs =
282 rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
283 Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
284 ext->recreateAndReplace(rewriter, toReplace: op, in: newAdd);
285 return success();
286 }
287};
288
289//===----------------------------------------------------------------------===//
290// AddIOp Pattern
291//===----------------------------------------------------------------------===//
292
293struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
294 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
295
296 // Addition may require one extra bit for the result.
297 // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
298 unsigned getResultBitsProduced(unsigned operandBits) const override {
299 return operandBits + 1;
300 }
301};
302
303//===----------------------------------------------------------------------===//
304// SubIOp Pattern
305//===----------------------------------------------------------------------===//
306
307struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
308 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
309
310 // This optimization only applies to signed arguments.
311 bool isSupported(ExtensionOp ext) const override {
312 return ext.getKind() == ExtensionKind::Sign;
313 }
314
315 // Subtraction may require one extra bit for the result.
316 // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
317 unsigned getResultBitsProduced(unsigned operandBits) const override {
318 return operandBits + 1;
319 }
320};
321
322//===----------------------------------------------------------------------===//
323// MulIOp Pattern
324//===----------------------------------------------------------------------===//
325
326struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
327 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
328
329 // Multiplication may require up double the operand bits.
330 // Example: `UNT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
331 unsigned getResultBitsProduced(unsigned operandBits) const override {
332 return 2 * operandBits;
333 }
334};
335
336//===----------------------------------------------------------------------===//
337// DivSIOp Pattern
338//===----------------------------------------------------------------------===//
339
340struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
341 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
342
343 // This optimization only applies to signed arguments.
344 bool isSupported(ExtensionOp ext) const override {
345 return ext.getKind() == ExtensionKind::Sign;
346 }
347
348 // Unlike multiplication, signed division requires only one more result bit.
349 // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
350 unsigned getResultBitsProduced(unsigned operandBits) const override {
351 return operandBits + 1;
352 }
353};
354
355//===----------------------------------------------------------------------===//
356// DivUIOp Pattern
357//===----------------------------------------------------------------------===//
358
359struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
360 using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;
361
362 // This optimization only applies to unsigned arguments.
363 bool isSupported(ExtensionOp ext) const override {
364 return ext.getKind() == ExtensionKind::Zero;
365 }
366
367 // Unsigned division does not require any extra result bits.
368 unsigned getResultBitsProduced(unsigned operandBits) const override {
369 return operandBits;
370 }
371};
372
373//===----------------------------------------------------------------------===//
374// Min/Max Patterns
375//===----------------------------------------------------------------------===//
376
377template <typename MinMaxOp, ExtensionKind Kind>
378struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
379 using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;
380
381 bool isSupported(ExtensionOp ext) const override {
382 return ext.getKind() == Kind;
383 }
384
385 // Min/max returns one of the arguments and does not require any extra result
386 // bits.
387 unsigned getResultBitsProduced(unsigned operandBits) const override {
388 return operandBits;
389 }
390};
391using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
392using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
393using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
394using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;
395
396//===----------------------------------------------------------------------===//
397// *IToFPOp Patterns
398//===----------------------------------------------------------------------===//
399
400template <typename IToFPOp, ExtensionKind Extension>
401struct IToFPPattern final : NarrowingPattern<IToFPOp> {
402 using NarrowingPattern<IToFPOp>::NarrowingPattern;
403
404 LogicalResult matchAndRewrite(IToFPOp op,
405 PatternRewriter &rewriter) const override {
406 FailureOr<unsigned> narrowestWidth =
407 calculateBitsRequired(op.getIn(), Extension);
408 if (failed(result: narrowestWidth))
409 return failure();
410
411 FailureOr<Type> narrowTy =
412 this->getNarrowType(*narrowestWidth, op.getIn().getType());
413 if (failed(result: narrowTy))
414 return failure();
415
416 Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
417 op.getIn());
418 rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
419 return success();
420 }
421};
422using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
423using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
424
425//===----------------------------------------------------------------------===//
426// Index Cast Patterns
427//===----------------------------------------------------------------------===//
428
429// These rely on the `ValueBounds` interface for index values. For example, we
430// can often statically tell index value bounds of loop induction variables.
431
432template <typename CastOp, ExtensionKind Kind>
433struct IndexCastPattern final : NarrowingPattern<CastOp> {
434 using NarrowingPattern<CastOp>::NarrowingPattern;
435
436 LogicalResult matchAndRewrite(CastOp op,
437 PatternRewriter &rewriter) const override {
438 Value in = op.getIn();
439 // We only support scalar index -> integer casts.
440 if (!isa<IndexType>(Val: in.getType()))
441 return failure();
442
443 // Check the lower bound in both the signed and unsigned cast case. We
444 // conservatively assume that even unsigned casts may be performed on
445 // negative indices.
446 FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
447 type: presburger::BoundType::LB, var: in);
448 if (failed(result: lb))
449 return failure();
450
451 FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
452 type: presburger::BoundType::UB, var: in,
453 /*stopCondition=*/nullptr, /*closedUB=*/true);
454 if (failed(result: ub))
455 return failure();
456
457 assert(*lb <= *ub && "Invalid bounds");
458 unsigned lbBitsRequired = calculateBitsRequired(value: APInt(64, *lb), lookThroughExtension: Kind);
459 unsigned ubBitsRequired = calculateBitsRequired(value: APInt(64, *ub), lookThroughExtension: Kind);
460 unsigned bitsRequired = std::max(a: lbBitsRequired, b: ubBitsRequired);
461
462 IntegerType resultTy = cast<IntegerType>(op.getType());
463 if (resultTy.getWidth() <= bitsRequired)
464 return failure();
465
466 FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
467 if (failed(result: narrowTy))
468 return failure();
469
470 Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());
471
472 if (Kind == ExtensionKind::Sign)
473 rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
474 else
475 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
476 return success();
477 }
478};
479using IndexCastSIPattern =
480 IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
481using IndexCastUIPattern =
482 IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;
483
484//===----------------------------------------------------------------------===//
485// Patterns to Commute Extension Ops
486//===----------------------------------------------------------------------===//
487
488struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
489 using NarrowingPattern::NarrowingPattern;
490
491 LogicalResult matchAndRewrite(vector::BroadcastOp op,
492 PatternRewriter &rewriter) const override {
493 FailureOr<ExtensionOp> ext =
494 ExtensionOp::from(op: op.getSource().getDefiningOp());
495 if (failed(result: ext))
496 return failure();
497
498 VectorType origTy = op.getResultVectorType();
499 VectorType newTy =
500 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
501 Value newBroadcast =
502 rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
503 ext->recreateAndReplace(rewriter, toReplace: op, in: newBroadcast);
504 return success();
505 }
506};
507
508struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
509 using NarrowingPattern::NarrowingPattern;
510
511 LogicalResult matchAndRewrite(vector::ExtractOp op,
512 PatternRewriter &rewriter) const override {
513 FailureOr<ExtensionOp> ext =
514 ExtensionOp::from(op: op.getVector().getDefiningOp());
515 if (failed(result: ext))
516 return failure();
517
518 Value newExtract = rewriter.create<vector::ExtractOp>(
519 op.getLoc(), ext->getIn(), op.getMixedPosition());
520 ext->recreateAndReplace(rewriter, toReplace: op, in: newExtract);
521 return success();
522 }
523};
524
525struct ExtensionOverExtractElement final
526 : NarrowingPattern<vector::ExtractElementOp> {
527 using NarrowingPattern::NarrowingPattern;
528
529 LogicalResult matchAndRewrite(vector::ExtractElementOp op,
530 PatternRewriter &rewriter) const override {
531 FailureOr<ExtensionOp> ext =
532 ExtensionOp::from(op: op.getVector().getDefiningOp());
533 if (failed(result: ext))
534 return failure();
535
536 Value newExtract = rewriter.create<vector::ExtractElementOp>(
537 op.getLoc(), ext->getIn(), op.getPosition());
538 ext->recreateAndReplace(rewriter, toReplace: op, in: newExtract);
539 return success();
540 }
541};
542
543struct ExtensionOverExtractStridedSlice final
544 : NarrowingPattern<vector::ExtractStridedSliceOp> {
545 using NarrowingPattern::NarrowingPattern;
546
547 LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
548 PatternRewriter &rewriter) const override {
549 FailureOr<ExtensionOp> ext =
550 ExtensionOp::from(op: op.getVector().getDefiningOp());
551 if (failed(result: ext))
552 return failure();
553
554 VectorType origTy = op.getType();
555 VectorType extractTy =
556 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
557 Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
558 op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
559 op.getStrides());
560 ext->recreateAndReplace(rewriter, toReplace: op, in: newExtract);
561 return success();
562 }
563};
564
565/// Base pattern for `vector.insert` narrowing patterns.
566template <typename InsertionOp>
567struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
568 using NarrowingPattern<InsertionOp>::NarrowingPattern;
569
570 /// Derived classes must provide a function to create the matching insertion
571 /// op based on the original op and new arguments.
572 virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
573 InsertionOp origInsert,
574 Value narrowValue,
575 Value narrowDest) const = 0;
576
577 LogicalResult matchAndRewrite(InsertionOp op,
578 PatternRewriter &rewriter) const final {
579 FailureOr<ExtensionOp> ext =
580 ExtensionOp::from(op: op.getSource().getDefiningOp());
581 if (failed(result: ext))
582 return failure();
583
584 FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, insValue: *ext);
585 if (failed(newInsert))
586 return failure();
587 ext->recreateAndReplace(rewriter, toReplace: op, in: *newInsert);
588 return success();
589 }
590
591 FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
592 PatternRewriter &rewriter,
593 ExtensionOp insValue) const {
594 // Calculate the operand and result bitwidths. We can only apply narrowing
595 // when the inserted source value and destination vector require fewer bits
596 // than the result. Because the source and destination may have different
597 // bitwidths requirements, we have to find the common narrow bitwidth that
598 // is greater equal to the operand bitwidth requirements and still narrower
599 // than the result.
600 FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
601 if (failed(result: origBitsRequired))
602 return failure();
603
604 // TODO: We could relax this check by disregarding bitwidth requirements of
605 // elements that we know will be replaced by the insertion.
606 FailureOr<unsigned> destBitsRequired =
607 calculateBitsRequired(op.getDest(), insValue.getKind());
608 if (failed(result: destBitsRequired) || *destBitsRequired >= *origBitsRequired)
609 return failure();
610
611 FailureOr<unsigned> insertedBitsRequired =
612 calculateBitsRequired(value: insValue.getIn(), lookThroughExtension: insValue.getKind());
613 if (failed(result: insertedBitsRequired) ||
614 *insertedBitsRequired >= *origBitsRequired)
615 return failure();
616
617 // Find a narrower element type that satisfies the bitwidth requirements of
618 // both the source and the destination values.
619 unsigned newInsertionBits =
620 std::max(a: *destBitsRequired, b: *insertedBitsRequired);
621 FailureOr<Type> newVecTy =
622 this->getNarrowType(newInsertionBits, op.getType());
623 if (failed(result: newVecTy) || *newVecTy == op.getType())
624 return failure();
625
626 FailureOr<Type> newInsertedValueTy =
627 this->getNarrowType(newInsertionBits, insValue.getType());
628 if (failed(result: newInsertedValueTy))
629 return failure();
630
631 Location loc = op.getLoc();
632 Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
633 loc, *newInsertedValueTy, insValue.getResult());
634 Value narrowDest =
635 rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
636 return createInsertionOp(rewriter, origInsert: op, narrowValue, narrowDest);
637 }
638};
639
640struct ExtensionOverInsert final
641 : ExtensionOverInsertionPattern<vector::InsertOp> {
642 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
643
644 vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
645 vector::InsertOp origInsert,
646 Value narrowValue,
647 Value narrowDest) const override {
648 return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
649 narrowDest,
650 origInsert.getMixedPosition());
651 }
652};
653
654struct ExtensionOverInsertElement final
655 : ExtensionOverInsertionPattern<vector::InsertElementOp> {
656 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
657
658 vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
659 vector::InsertElementOp origInsert,
660 Value narrowValue,
661 Value narrowDest) const override {
662 return rewriter.create<vector::InsertElementOp>(
663 origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
664 }
665};
666
667struct ExtensionOverInsertStridedSlice final
668 : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
669 using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
670
671 vector::InsertStridedSliceOp
672 createInsertionOp(PatternRewriter &rewriter,
673 vector::InsertStridedSliceOp origInsert, Value narrowValue,
674 Value narrowDest) const override {
675 return rewriter.create<vector::InsertStridedSliceOp>(
676 origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
677 origInsert.getStrides());
678 }
679};
680
681struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
682 using NarrowingPattern::NarrowingPattern;
683
684 LogicalResult matchAndRewrite(vector::ShapeCastOp op,
685 PatternRewriter &rewriter) const override {
686 FailureOr<ExtensionOp> ext =
687 ExtensionOp::from(op: op.getSource().getDefiningOp());
688 if (failed(result: ext))
689 return failure();
690
691 VectorType origTy = op.getResultVectorType();
692 VectorType newTy =
693 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
694 Value newCast =
695 rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
696 ext->recreateAndReplace(rewriter, toReplace: op, in: newCast);
697 return success();
698 }
699};
700
701struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
702 using NarrowingPattern::NarrowingPattern;
703
704 LogicalResult matchAndRewrite(vector::TransposeOp op,
705 PatternRewriter &rewriter) const override {
706 FailureOr<ExtensionOp> ext =
707 ExtensionOp::from(op: op.getVector().getDefiningOp());
708 if (failed(result: ext))
709 return failure();
710
711 VectorType origTy = op.getResultVectorType();
712 VectorType newTy =
713 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
714 Value newTranspose = rewriter.create<vector::TransposeOp>(
715 op.getLoc(), newTy, ext->getIn(), op.getPermutation());
716 ext->recreateAndReplace(rewriter, toReplace: op, in: newTranspose);
717 return success();
718 }
719};
720
721struct ExtensionOverFlatTranspose final
722 : NarrowingPattern<vector::FlatTransposeOp> {
723 using NarrowingPattern::NarrowingPattern;
724
725 LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
726 PatternRewriter &rewriter) const override {
727 FailureOr<ExtensionOp> ext =
728 ExtensionOp::from(op: op.getMatrix().getDefiningOp());
729 if (failed(result: ext))
730 return failure();
731
732 VectorType origTy = op.getType();
733 VectorType newTy =
734 origTy.cloneWith(origTy.getShape(), ext->getInElementType());
735 Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
736 op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
737 op.getColumnsAttr());
738 ext->recreateAndReplace(rewriter, toReplace: op, in: newTranspose);
739 return success();
740 }
741};
742
743//===----------------------------------------------------------------------===//
744// Pass Definitions
745//===----------------------------------------------------------------------===//
746
747struct ArithIntNarrowingPass final
748 : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
749 using ArithIntNarrowingBase::ArithIntNarrowingBase;
750
751 void runOnOperation() override {
752 if (bitwidthsSupported.empty() ||
753 llvm::is_contained(bitwidthsSupported, 0)) {
754 // Invalid pass options.
755 return signalPassFailure();
756 }
757
758 Operation *op = getOperation();
759 MLIRContext *ctx = op->getContext();
760 RewritePatternSet patterns(ctx);
761 populateArithIntNarrowingPatterns(
762 patterns, ArithIntNarrowingOptions{bitwidthsSupported});
763 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
764 signalPassFailure();
765 }
766};
767} // namespace
768
769//===----------------------------------------------------------------------===//
770// Public API
771//===----------------------------------------------------------------------===//
772
773void populateArithIntNarrowingPatterns(
774 RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
775 // Add commute patterns with a higher benefit. This is to expose more
776 // optimization opportunities to narrowing patterns.
777 patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
778 ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
779 ExtensionOverInsert, ExtensionOverInsertElement,
780 ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
781 ExtensionOverTranspose, ExtensionOverFlatTranspose>(
782 patterns.getContext(), options, PatternBenefit(2));
783
784 patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
785 DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
786 MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
787 IndexCastUIPattern>(patterns.getContext(), options);
788}
789
790} // namespace mlir::arith
791

source code of mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp