//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <utility>

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"

#define GEN_PASS_DEF_ARITHINTRANGENARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;

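/// Return the constant value of `value` if the integer range analysis has
/// narrowed its range down to a single value, std::nullopt otherwise.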
static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
                                                  Value value) {
  auto *maybeInferredRange =
      solver.lookupState<IntegerValueRangeLattice>(value);
  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
    return std::nullopt;
  const ConstantIntRanges &inferredRange =
      maybeInferredRange->getValue().getValue();
  return inferredRange.getConstantValue();
}

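/// Propagate (join) the inferred integer range of `oldVal` into the solver
/// state of `newVal`. Both values must have the same type.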
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                             Value newVal) {
  assert(oldVal.getType() == newVal.getType() &&
         "Can't copy integer ranges between different types");
  auto *oldState = solver.lookupState<IntegerValueRangeLattice>(oldVal);
  if (!oldState)
    return;
  (void)solver.getOrCreateState<IntegerValueRangeLattice>(newVal)->join(
      *oldState);
}

namespace mlir::dataflow {
/// Patterned after SCCP: if the analysis has inferred that `value` is a
/// single constant, materialize that constant and replace all uses of
/// `value` with it.
LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                       RewriterBase &rewriter, Value value) {
  if (value.use_empty())
    return failure();
  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
  if (!maybeConstValue.has_value())
    return failure();

  Type type = value.getType();
  Location loc = value.getLoc();
  Operation *maybeDefiningOp = value.getDefiningOp();
  Dialect *valueDialect =
      maybeDefiningOp ? maybeDefiningOp->getDialect()
                      : value.getParentRegion()->getParentOp()->getDialect();

  Attribute constAttr;
  if (auto shaped = dyn_cast<ShapedType>(type)) {
    constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue);
  } else {
    constAttr = rewriter.getIntegerAttr(type, *maybeConstValue);
  }
  Operation *constOp =
      valueDialect->materializeConstant(rewriter, constAttr, type, loc);
  // Fall back to arith.constant if the dialect materializer doesn't know what
  // to do with an integer constant.
  if (!constOp)
    constOp = rewriter.getContext()
                  ->getLoadedDialect<ArithDialect>()
                  ->materializeConstant(rewriter, constAttr, type, loc);
  if (!constOp)
    return failure();

  OpResult res = constOp->getResult(0);
  if (solver.lookupState<dataflow::IntegerValueRangeLattice>(res))
    solver.eraseState(res);
  copyIntegerRange(solver, value, res);
  rewriter.replaceAllUsesWith(value, res);
  return success();
}
} // namespace mlir::dataflow

namespace {
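/// Rewriter listener that keeps the dataflow solver consistent with IR
/// changes: when an operation is erased, the analysis state anchored on it
/// and its results is erased as well.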
class DataFlowListener : public RewriterBase::Listener {
public:
  DataFlowListener(DataFlowSolver &s) : s(s) {}

protected:
  void notifyOperationErased(Operation *op) override {
    s.eraseState(s.getProgramPointAfter(op));
    for (Value res : op->getResults())
      s.eraseState(res);
  }

  DataFlowSolver &s;
};

/// Rewrite any results of `op` that were inferred to be constant integers,
/// replacing their uses with that constant. Return success() if all results
/// were thus replaced and the operation was erased. Also replace any block
/// arguments with their constant values.
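///
/// Illustrative example (assuming the analysis proved %x is always 5):
///
///   %x = "some.op"() : () -> i32   // inferred range [5, 5]
///
/// is rewritten so that all uses of %x refer to `arith.constant 5 : i32`.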
struct MaterializeKnownConstantValues : public RewritePattern {
  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
      : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
                                       /*benefit=*/1, context),
        solver(s) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (matchPattern(op, m_Constant()))
      return failure();

    auto needsReplacing = [&](Value v) {
      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
    };
    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
    if (op->getNumRegions() == 0 && !hasConstantResults)
      return failure();
    bool hasConstantRegionArgs = false;
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        hasConstantRegionArgs |=
            llvm::any_of(block.getArguments(), needsReplacing);
      }
    }
    if (!hasConstantResults && !hasConstantRegionArgs)
      return failure();

    bool replacedAll = (op->getNumResults() != 0);
    for (Value v : op->getResults())
      replacedAll &=
          (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
           v.use_empty());
    if (replacedAll && isOpTriviallyDead(op)) {
      rewriter.eraseOp(op);
      return success();
    }

    PatternRewriter::InsertionGuard guard(rewriter);
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
        rewriter.setInsertionPointToStart(&block);
        for (BlockArgument &arg : block.getArguments()) {
          (void)maybeReplaceWithConstant(solver, rewriter, arg);
        }
      }
    }

    return success();
  }

private:
  DataFlowSolver &solver;
};

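/// Delete a remainder operation that is provably a no-op because its
/// left-hand side is already known to lie in [0, modulus). Illustrative
/// example: if %x has the inferred range [0, 4), then
///
///   %r = arith.remui %x, %c4 : i32
///
/// can be replaced by %x directly.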
template <typename RemOp>
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<RemOp>(context), solver(s) {}

  LogicalResult matchAndRewrite(RemOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getOperand(0);
    Value rhs = op.getOperand(1);
    auto maybeModulus = getConstantIntValue(rhs);
    if (!maybeModulus.has_value())
      return failure();
    int64_t modulus = *maybeModulus;
    if (modulus <= 0)
      return failure();
    auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
    if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
      return failure();
    const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
    const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
    const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
    // The minimum and maximum here are inclusive bounds, so both must be
    // strictly less than the modulus.
    if (min.isNegative() || min.uge(modulus))
      return failure();
    if (max.isNegative() || max.uge(modulus))
      return failure();
    if (!min.ule(max))
      return failure();

    // With all those conditions out of the way, we know that this remainder
    // is a no-op because the input lies strictly within the range
    // [0, modulus), so get rid of it.
    rewriter.replaceOp(op, ValueRange{lhs});
    return success();
  }

private:
  DataFlowSolver &solver;
};

/// Gather ranges for all the values in `values`. Appends to the existing
/// vector.
static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values,
                                   SmallVectorImpl<ConstantIntRanges> &ranges) {
  for (Value val : values) {
    auto *maybeInferredRange =
        solver.lookupState<IntegerValueRangeLattice>(val);
    if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
      return failure();

    const ConstantIntRanges &inferredRange =
        maybeInferredRange->getValue().getValue();
    ranges.push_back(inferredRange);
  }
  return success();
}

/// Return an integer type truncated to `targetBitwidth`. If `srcType` is
/// shaped, return a shaped type with the same shape and the truncated element
/// type.
static Type getTargetType(Type srcType, unsigned targetBitwidth) {
  auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth);
  if (auto shaped = dyn_cast<ShapedType>(srcType))
    return shaped.clone(dstType);

  assert(srcType.isIntOrIndex() && "Invalid src type");
  return dstType;
}

namespace {
// Enum for tracking which type of truncation should be performed
// to narrow an operation, if any.
enum class CastKind : uint8_t { None, Signed, Unsigned, Both };
} // namespace

/// If the values within `range` can be represented using only `targetWidth`
/// bits, return the kind of truncation needed to preserve that property.
///
/// This check relies on the fact that the signed and unsigned ranges are both
/// always correct, but one might be an approximation of the other, so we want
/// to use the matching truncation operation.
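///
/// Worked example (illustrative): take an i64 value whose signed range is
/// [-1, 1] and whose unsigned range is the full-range approximation
/// [0, 2^64 - 1]. Truncating to i32 removes 32 bits, so 33 sign bits are
/// required: smin = -1 has 64 sign bits and smax = 1 has 63, so the signed
/// check passes; umax = 2^64 - 1 has no leading zeros, so the unsigned check
/// fails, and the result is CastKind::Signed.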
static CastKind checkTruncatability(const ConstantIntRanges &range,
                                    unsigned targetWidth) {
  unsigned srcWidth = range.smin().getBitWidth();
  if (srcWidth <= targetWidth)
    return CastKind::None;
  unsigned removedWidth = srcWidth - targetWidth;
  // The sign bits need to extend into the sign bit of the target width. For
  // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign
  // bits.
  bool canTruncateSigned =
      range.smin().getNumSignBits() >= (removedWidth + 1) &&
      range.smax().getNumSignBits() >= (removedWidth + 1);
  bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth &&
                             range.umax().countLeadingZeros() >= removedWidth;
  if (canTruncateSigned && canTruncateUnsigned)
    return CastKind::Both;
  if (canTruncateSigned)
    return CastKind::Signed;
  if (canTruncateUnsigned)
    return CastKind::Unsigned;
  return CastKind::None;
}

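/// Merge the truncation requirements of two values: `Both` acts as the
/// identity, `None` is absorbing, and conflicting `Signed`/`Unsigned`
/// requirements merge to `None` since no single cast kind satisfies both.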
static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) {
  if (lhs == CastKind::None || rhs == CastKind::None)
    return CastKind::None;
  if (lhs == CastKind::Both)
    return rhs;
  if (rhs == CastKind::Both)
    return lhs;
  if (lhs == rhs)
    return lhs;
  return CastKind::None;
}

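/// Cast `src` to `dstType`. For `index` element types this emits index_cast
/// or index_castui; for integer types it emits trunci when narrowing and
/// extsi or extui when widening, with `castKind` selecting between the
/// signed and unsigned variants.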
static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType,
                    CastKind castKind) {
  Type srcType = src.getType();
  assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) &&
         "Mixing vector and non-vector types");
  assert(castKind != CastKind::None && "Can't cast when casting isn't allowed");
  Type srcElemType = getElementTypeOrSelf(srcType);
  Type dstElemType = getElementTypeOrSelf(dstType);
  assert(srcElemType.isIntOrIndex() && "Invalid src type");
  assert(dstElemType.isIntOrIndex() && "Invalid dst type");
  if (srcType == dstType)
    return src;

  if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) {
    if (castKind == CastKind::Signed)
      return builder.create<arith::IndexCastOp>(loc, dstType, src);
    return builder.create<arith::IndexCastUIOp>(loc, dstType, src);
  }

  auto srcInt = cast<IntegerType>(srcElemType);
  auto dstInt = cast<IntegerType>(dstElemType);
  if (dstInt.getWidth() < srcInt.getWidth())
    return builder.create<arith::TruncIOp>(loc, dstType, src);

  if (castKind == CastKind::Signed)
    return builder.create<arith::ExtSIOp>(loc, dstType, src);
  return builder.create<arith::ExtUIOp>(loc, dstType, src);
}

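/// Narrow an elementwise operation to a smaller bitwidth when the inferred
/// ranges of all its operands and results fit in that width. Illustrative
/// example: if %a, %b, and their sum are all known to fit in 8 bits, then
///
///   %r = arith.addi %a, %b : i32
///
/// becomes, roughly,
///
///   %a8 = arith.trunci %a : i32 to i8
///   %b8 = arith.trunci %b : i32 to i8
///   %r8 = arith.addi %a8, %b8 : i8
///   %r  = arith.extui %r8 : i8 to i32
///
/// with the signed or unsigned casts chosen by the computed CastKind.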
struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> {
  NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
                    ArrayRef<unsigned> target)
      : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {}

  using OpTraitRewritePattern::OpTraitRewritePattern;
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumResults() == 0)
      return rewriter.notifyMatchFailure(op, "can't narrow resultless op");

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op->getOperands(), ranges)))
      return rewriter.notifyMatchFailure(op, "input without specified range");
    if (failed(collectRanges(solver, op->getResults(), ranges)))
      return rewriter.notifyMatchFailure(op, "output without specified range");

    Type srcType = op->getResult(0).getType();
    if (!llvm::all_equal(op->getResultTypes()))
      return rewriter.notifyMatchFailure(op, "mismatched result types");
    if (op->getNumOperands() == 0 ||
        !llvm::all_of(op->getOperandTypes(),
                      [=](Type t) { return t == srcType; }))
      return rewriter.notifyMatchFailure(
          op, "no operands or operand types don't match result type");

    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind castKind = CastKind::Both;
      for (const ConstantIntRanges &range : ranges) {
        castKind = mergeCastKinds(castKind,
                                  checkTruncatability(range, targetBitwidth));
        if (castKind == CastKind::None)
          break;
      }
      if (castKind == CastKind::None)
        continue;
      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) {
        CastKind argCastKind = castKind;
        // For `index` operands that are known non-negative, prefer the
        // unsigned cast: `index_castui` records the non-negativity, which a
        // signed `index_cast` would lose.
        if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative())
          argCastKind = CastKind::Both;
        Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind);
        mapping.map(arg, newArg);
      }

      Operation *newOp = rewriter.clone(*op, mapping);
      rewriter.modifyOpInPlace(newOp, [&]() {
        for (OpResult res : newOp->getResults()) {
          res.setType(targetType);
        }
      });
      SmallVector<Value> newResults;
      for (auto [newRes, oldRes] :
           llvm::zip_equal(newOp->getResults(), op->getResults())) {
        Value castBack = doCast(rewriter, loc, newRes, srcType, castKind);
        copyIntegerRange(solver, oldRes, castBack);
        newResults.push_back(castBack);
      }

      rewriter.replaceOp(op, newResults);
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};

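/// Narrow the operands of an arith.cmpi when their inferred ranges fit in a
/// smaller bitwidth. Unlike NarrowElementwise, only the operands are cast;
/// the i1 result is used as-is.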
struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> {
  NarrowCmpI(MLIRContext *context, DataFlowSolver &s,
             ArrayRef<unsigned> target)
      : OpRewritePattern(context), solver(s), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(arith::CmpIOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();

    SmallVector<ConstantIntRanges> ranges;
    if (failed(collectRanges(solver, op.getOperands(), ranges)))
      return failure();
    const ConstantIntRanges &lhsRange = ranges[0];
    const ConstantIntRanges &rhsRange = ranges[1];

    Type srcType = lhs.getType();
    for (unsigned targetBitwidth : targetBitwidths) {
      CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth);
      CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth);
      CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind);
      // Note: this includes target width > src width.
      if (castKind == CastKind::None)
        continue;

      Type targetType = getTargetType(srcType, targetBitwidth);
      if (targetType == srcType)
        continue;

      Location loc = op->getLoc();
      IRMapping mapping;
      Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind);
      Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind);
      mapping.map(lhs, lhsCast);
      mapping.map(rhs, rhsCast);

      Operation *newOp = rewriter.clone(*op, mapping);
      copyIntegerRange(solver, op.getResult(), newOp->getResult(0));
      rewriter.replaceOp(op, newOp->getResults());
      return success();
    }
    return failure();
  }

private:
  DataFlowSolver &solver;
  SmallVector<unsigned, 4> targetBitwidths;
};

/// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg
/// This pattern assumes all passed `targetBitwidths` are no wider than the
/// index type.
template <typename CastOp>
struct FoldIndexCastChain final : OpRewritePattern<CastOp> {
  FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target)
      : OpRewritePattern<CastOp>(context), targetBitwidths(target) {}

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    auto srcOp = op.getIn().template getDefiningOp<CastOp>();
    if (!srcOp)
      return rewriter.notifyMatchFailure(op, "doesn't come from an index cast");

    Value src = srcOp.getIn();
    if (src.getType() != op.getType())
      return rewriter.notifyMatchFailure(op, "outer types don't match");

    if (!srcOp.getType().isIndex())
      return rewriter.notifyMatchFailure(op, "intermediate type isn't index");

    auto intType = dyn_cast<IntegerType>(op.getType());
    if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth()))
      return failure();

    rewriter.replaceOp(op, src);
    return success();
  }

private:
  SmallVector<unsigned, 4> targetBitwidths;
};

struct IntRangeOptimizationsPass final
    : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeOptimizationsPatterns(patterns, solver);

    if (failed(applyPatternsGreedily(
            op, std::move(patterns),
            GreedyRewriteConfig().setListener(&listener))))
      signalPassFailure();
  }
};

struct IntRangeNarrowingPass final
    : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> {
  using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase;

  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(op)))
      return signalPassFailure();

    DataFlowListener listener(solver);

    RewritePatternSet patterns(ctx);
    populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
    // We specifically need bottom-up traversal, as the cmpi pattern needs the
    // range data attached to its original argument values.
    if (failed(applyPatternsGreedily(
            op, std::move(patterns),
            GreedyRewriteConfig().setUseTopDownTraversal(false).setListener(
                &listener))))
      signalPassFailure();
  }
};
} // namespace

void mlir::arith::populateIntRangeOptimizationsPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver) {
  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
               DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}

void mlir::arith::populateIntRangeNarrowingPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver,
    ArrayRef<unsigned> bitwidthsSupported) {
  patterns.add<NarrowElementwise, NarrowCmpI>(patterns.getContext(), solver,
                                              bitwidthsSupported);
  patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>,
               FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(),
                                                       bitwidthsSupported);
}

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
  return std::make_unique<IntRangeOptimizationsPass>();
}
