1 | //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <utility> |
10 | |
11 | #include "mlir/Analysis/DataFlowFramework.h" |
12 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
13 | |
14 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
15 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
18 | #include "mlir/IR/IRMapping.h" |
19 | #include "mlir/IR/Matchers.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | #include "mlir/IR/TypeUtilities.h" |
22 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
23 | #include "mlir/Transforms/FoldUtils.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | |
26 | namespace mlir::arith { |
27 | #define GEN_PASS_DEF_ARITHINTRANGEOPTS |
28 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
29 | |
30 | #define GEN_PASS_DEF_ARITHINTRANGENARROWING |
31 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
32 | } // namespace mlir::arith |
33 | |
34 | using namespace mlir; |
35 | using namespace mlir::arith; |
36 | using namespace mlir::dataflow; |
37 | |
38 | static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver, |
39 | Value value) { |
40 | auto *maybeInferredRange = |
41 | solver.lookupState<IntegerValueRangeLattice>(anchor: value); |
42 | if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) |
43 | return std::nullopt; |
44 | const ConstantIntRanges &inferredRange = |
45 | maybeInferredRange->getValue().getValue(); |
46 | return inferredRange.getConstantValue(); |
47 | } |
48 | |
49 | static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, |
50 | Value newVal) { |
51 | assert(oldVal.getType() == newVal.getType() && |
52 | "Can't copy integer ranges between different types"); |
53 | auto *oldState = solver.lookupState<IntegerValueRangeLattice>(anchor: oldVal); |
54 | if (!oldState) |
55 | return; |
56 | (void)solver.getOrCreateState<IntegerValueRangeLattice>(anchor: newVal)->join( |
57 | rhs: *oldState); |
58 | } |
59 | |
60 | namespace mlir::dataflow { |
61 | /// Patterned after SCCP |
62 | LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, |
63 | RewriterBase &rewriter, Value value) { |
64 | if (value.use_empty()) |
65 | return failure(); |
66 | std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value); |
67 | if (!maybeConstValue.has_value()) |
68 | return failure(); |
69 | |
70 | Type type = value.getType(); |
71 | Location loc = value.getLoc(); |
72 | Operation *maybeDefiningOp = value.getDefiningOp(); |
73 | Dialect *valueDialect = |
74 | maybeDefiningOp ? maybeDefiningOp->getDialect() |
75 | : value.getParentRegion()->getParentOp()->getDialect(); |
76 | |
77 | Attribute constAttr; |
78 | if (auto shaped = dyn_cast<ShapedType>(type)) { |
79 | constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue); |
80 | } else { |
81 | constAttr = rewriter.getIntegerAttr(type, *maybeConstValue); |
82 | } |
83 | Operation *constOp = |
84 | valueDialect->materializeConstant(builder&: rewriter, value: constAttr, type, loc); |
85 | // Fall back to arith.constant if the dialect materializer doesn't know what |
86 | // to do with an integer constant. |
87 | if (!constOp) |
88 | constOp = rewriter.getContext() |
89 | ->getLoadedDialect<ArithDialect>() |
90 | ->materializeConstant(rewriter, constAttr, type, loc); |
91 | if (!constOp) |
92 | return failure(); |
93 | |
94 | OpResult res = constOp->getResult(idx: 0); |
95 | if (solver.lookupState<dataflow::IntegerValueRangeLattice>(anchor: res)) |
96 | solver.eraseState(anchor: res); |
97 | copyIntegerRange(solver, oldVal: value, newVal: res); |
98 | rewriter.replaceAllUsesWith(from: value, to: res); |
99 | return success(); |
100 | } |
101 | } // namespace mlir::dataflow |
102 | |
103 | namespace { |
104 | class DataFlowListener : public RewriterBase::Listener { |
105 | public: |
106 | DataFlowListener(DataFlowSolver &s) : s(s) {} |
107 | |
108 | protected: |
109 | void notifyOperationErased(Operation *op) override { |
110 | s.eraseState(anchor: s.getProgramPointAfter(op)); |
111 | for (Value res : op->getResults()) |
112 | s.eraseState(anchor: res); |
113 | } |
114 | |
115 | DataFlowSolver &s; |
116 | }; |
117 | |
118 | /// Rewrite any results of `op` that were inferred to be constant integers to |
119 | /// and replace their uses with that constant. Return success() if all results |
120 | /// where thus replaced and the operation is erased. Also replace any block |
121 | /// arguments with their constant values. |
122 | struct MaterializeKnownConstantValues : public RewritePattern { |
123 | MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) |
124 | : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(), |
125 | /*benefit=*/1, context), |
126 | solver(s) {} |
127 | |
128 | LogicalResult matchAndRewrite(Operation *op, |
129 | PatternRewriter &rewriter) const override { |
130 | if (matchPattern(op, pattern: m_Constant())) |
131 | return failure(); |
132 | |
133 | auto needsReplacing = [&](Value v) { |
134 | return getMaybeConstantValue(solver, value: v).has_value() && !v.use_empty(); |
135 | }; |
136 | bool hasConstantResults = llvm::any_of(Range: op->getResults(), P: needsReplacing); |
137 | if (op->getNumRegions() == 0) |
138 | if (!hasConstantResults) |
139 | return failure(); |
140 | bool hasConstantRegionArgs = false; |
141 | for (Region ®ion : op->getRegions()) { |
142 | for (Block &block : region.getBlocks()) { |
143 | hasConstantRegionArgs |= |
144 | llvm::any_of(Range: block.getArguments(), P: needsReplacing); |
145 | } |
146 | } |
147 | if (!hasConstantResults && !hasConstantRegionArgs) |
148 | return failure(); |
149 | |
150 | bool replacedAll = (op->getNumResults() != 0); |
151 | for (Value v : op->getResults()) |
152 | replacedAll &= |
153 | (succeeded(Result: maybeReplaceWithConstant(solver, rewriter, value: v)) || |
154 | v.use_empty()); |
155 | if (replacedAll && isOpTriviallyDead(op)) { |
156 | rewriter.eraseOp(op); |
157 | return success(); |
158 | } |
159 | |
160 | PatternRewriter::InsertionGuard guard(rewriter); |
161 | for (Region ®ion : op->getRegions()) { |
162 | for (Block &block : region.getBlocks()) { |
163 | rewriter.setInsertionPointToStart(&block); |
164 | for (BlockArgument &arg : block.getArguments()) { |
165 | (void)maybeReplaceWithConstant(solver, rewriter, value: arg); |
166 | } |
167 | } |
168 | } |
169 | |
170 | return success(); |
171 | } |
172 | |
173 | private: |
174 | DataFlowSolver &solver; |
175 | }; |
176 | |
177 | template <typename RemOp> |
178 | struct DeleteTrivialRem : public OpRewritePattern<RemOp> { |
179 | DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s) |
180 | : OpRewritePattern<RemOp>(context), solver(s) {} |
181 | |
182 | LogicalResult matchAndRewrite(RemOp op, |
183 | PatternRewriter &rewriter) const override { |
184 | Value lhs = op.getOperand(0); |
185 | Value rhs = op.getOperand(1); |
186 | auto maybeModulus = getConstantIntValue(ofr: rhs); |
187 | if (!maybeModulus.has_value()) |
188 | return failure(); |
189 | int64_t modulus = *maybeModulus; |
190 | if (modulus <= 0) |
191 | return failure(); |
192 | auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(anchor: lhs); |
193 | if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) |
194 | return failure(); |
195 | const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); |
196 | const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin(); |
197 | const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax(); |
198 | // The minima and maxima here are given as closed ranges, we must be |
199 | // strictly less than the modulus. |
200 | if (min.isNegative() || min.uge(RHS: modulus)) |
201 | return failure(); |
202 | if (max.isNegative() || max.uge(RHS: modulus)) |
203 | return failure(); |
204 | if (!min.ule(RHS: max)) |
205 | return failure(); |
206 | |
207 | // With all those conditions out of the way, we know thas this invocation of |
208 | // a remainder is a noop because the input is strictly within the range |
209 | // [0, modulus), so get rid of it. |
210 | rewriter.replaceOp(op, ValueRange{lhs}); |
211 | return success(); |
212 | } |
213 | |
214 | private: |
215 | DataFlowSolver &solver; |
216 | }; |
217 | |
218 | /// Gather ranges for all the values in `values`. Appends to the existing |
219 | /// vector. |
220 | static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, |
221 | SmallVectorImpl<ConstantIntRanges> &ranges) { |
222 | for (Value val : values) { |
223 | auto *maybeInferredRange = |
224 | solver.lookupState<IntegerValueRangeLattice>(anchor: val); |
225 | if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) |
226 | return failure(); |
227 | |
228 | const ConstantIntRanges &inferredRange = |
229 | maybeInferredRange->getValue().getValue(); |
230 | ranges.push_back(Elt: inferredRange); |
231 | } |
232 | return success(); |
233 | } |
234 | |
235 | /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped, |
236 | /// return shaped type as well. |
237 | static Type getTargetType(Type srcType, unsigned targetBitwidth) { |
238 | auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth); |
239 | if (auto shaped = dyn_cast<ShapedType>(srcType)) |
240 | return shaped.clone(dstType); |
241 | |
242 | assert(srcType.isIntOrIndex() && "Invalid src type"); |
243 | return dstType; |
244 | } |
245 | |
namespace {
/// Which kind of truncation, if any, may be used to narrow an operation.
enum class CastKind : uint8_t {
  /// Narrowing is not possible.
  None,
  /// Only the signed interpretation fits the narrower width.
  Signed,
  /// Only the unsigned interpretation fits the narrower width.
  Unsigned,
  /// Both the signed and the unsigned interpretation fit.
  Both
};
} // namespace
251 | |
252 | /// If the values within `range` can be represented using only `width` bits, |
253 | /// return the kind of truncation needed to preserve that property. |
254 | /// |
255 | /// This check relies on the fact that the signed and unsigned ranges are both |
256 | /// always correct, but that one might be an approximation of the other, |
257 | /// so we want to use the correct truncation operation. |
258 | static CastKind checkTruncatability(const ConstantIntRanges &range, |
259 | unsigned targetWidth) { |
260 | unsigned srcWidth = range.smin().getBitWidth(); |
261 | if (srcWidth <= targetWidth) |
262 | return CastKind::None; |
263 | unsigned removedWidth = srcWidth - targetWidth; |
264 | // The sign bits need to extend into the sign bit of the target width. For |
265 | // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign |
266 | // bits. |
267 | bool canTruncateSigned = |
268 | range.smin().getNumSignBits() >= (removedWidth + 1) && |
269 | range.smax().getNumSignBits() >= (removedWidth + 1); |
270 | bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth && |
271 | range.umax().countLeadingZeros() >= removedWidth; |
272 | if (canTruncateSigned && canTruncateUnsigned) |
273 | return CastKind::Both; |
274 | if (canTruncateSigned) |
275 | return CastKind::Signed; |
276 | if (canTruncateUnsigned) |
277 | return CastKind::Unsigned; |
278 | return CastKind::None; |
279 | } |
280 | |
281 | static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) { |
282 | if (lhs == CastKind::None || rhs == CastKind::None) |
283 | return CastKind::None; |
284 | if (lhs == CastKind::Both) |
285 | return rhs; |
286 | if (rhs == CastKind::Both) |
287 | return lhs; |
288 | if (lhs == rhs) |
289 | return lhs; |
290 | return CastKind::None; |
291 | } |
292 | |
293 | static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, |
294 | CastKind castKind) { |
295 | Type srcType = src.getType(); |
296 | assert(isa<VectorType>(srcType) == isa<VectorType>(dstType) && |
297 | "Mixing vector and non-vector types"); |
298 | assert(castKind != CastKind::None && "Can't cast when casting isn't allowed"); |
299 | Type srcElemType = getElementTypeOrSelf(type: srcType); |
300 | Type dstElemType = getElementTypeOrSelf(type: dstType); |
301 | assert(srcElemType.isIntOrIndex() && "Invalid src type"); |
302 | assert(dstElemType.isIntOrIndex() && "Invalid dst type"); |
303 | if (srcType == dstType) |
304 | return src; |
305 | |
306 | if (isa<IndexType>(Val: srcElemType) || isa<IndexType>(Val: dstElemType)) { |
307 | if (castKind == CastKind::Signed) |
308 | return builder.create<arith::IndexCastOp>(loc, dstType, src); |
309 | return builder.create<arith::IndexCastUIOp>(loc, dstType, src); |
310 | } |
311 | |
312 | auto srcInt = cast<IntegerType>(srcElemType); |
313 | auto dstInt = cast<IntegerType>(dstElemType); |
314 | if (dstInt.getWidth() < srcInt.getWidth()) |
315 | return builder.create<arith::TruncIOp>(loc, dstType, src); |
316 | |
317 | if (castKind == CastKind::Signed) |
318 | return builder.create<arith::ExtSIOp>(loc, dstType, src); |
319 | return builder.create<arith::ExtUIOp>(loc, dstType, src); |
320 | } |
321 | |
322 | struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> { |
323 | NarrowElementwise(MLIRContext *context, DataFlowSolver &s, |
324 | ArrayRef<unsigned> target) |
325 | : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {} |
326 | |
327 | using OpTraitRewritePattern::OpTraitRewritePattern; |
328 | LogicalResult matchAndRewrite(Operation *op, |
329 | PatternRewriter &rewriter) const override { |
330 | if (op->getNumResults() == 0) |
331 | return rewriter.notifyMatchFailure(arg&: op, msg: "can't narrow resultless op"); |
332 | |
333 | SmallVector<ConstantIntRanges> ranges; |
334 | if (failed(Result: collectRanges(solver, values: op->getOperands(), ranges))) |
335 | return rewriter.notifyMatchFailure(arg&: op, msg: "input without specified range"); |
336 | if (failed(Result: collectRanges(solver, values: op->getResults(), ranges))) |
337 | return rewriter.notifyMatchFailure(arg&: op, msg: "output without specified range"); |
338 | |
339 | Type srcType = op->getResult(idx: 0).getType(); |
340 | if (!llvm::all_equal(Range: op->getResultTypes())) |
341 | return rewriter.notifyMatchFailure(arg&: op, msg: "mismatched result types"); |
342 | if (op->getNumOperands() == 0 || |
343 | !llvm::all_of(Range: op->getOperandTypes(), |
344 | P: [=](Type t) { return t == srcType; })) |
345 | return rewriter.notifyMatchFailure( |
346 | arg&: op, msg: "no operands or operand types don't match result type"); |
347 | |
348 | for (unsigned targetBitwidth : targetBitwidths) { |
349 | CastKind castKind = CastKind::Both; |
350 | for (const ConstantIntRanges &range : ranges) { |
351 | castKind = mergeCastKinds(lhs: castKind, |
352 | rhs: checkTruncatability(range, targetWidth: targetBitwidth)); |
353 | if (castKind == CastKind::None) |
354 | break; |
355 | } |
356 | if (castKind == CastKind::None) |
357 | continue; |
358 | Type targetType = getTargetType(srcType, targetBitwidth); |
359 | if (targetType == srcType) |
360 | continue; |
361 | |
362 | Location loc = op->getLoc(); |
363 | IRMapping mapping; |
364 | for (auto [arg, argRange] : llvm::zip_first(t: op->getOperands(), u&: ranges)) { |
365 | CastKind argCastKind = castKind; |
366 | // When dealing with `index` values, preserve non-negativity in the |
367 | // index_casts since we can't recover this in unsigned when equivalent. |
368 | if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative()) |
369 | argCastKind = CastKind::Both; |
370 | Value newArg = doCast(builder&: rewriter, loc, src: arg, dstType: targetType, castKind: argCastKind); |
371 | mapping.map(from: arg, to: newArg); |
372 | } |
373 | |
374 | Operation *newOp = rewriter.clone(op&: *op, mapper&: mapping); |
375 | rewriter.modifyOpInPlace(root: newOp, callable: [&]() { |
376 | for (OpResult res : newOp->getResults()) { |
377 | res.setType(targetType); |
378 | } |
379 | }); |
380 | SmallVector<Value> newResults; |
381 | for (auto [newRes, oldRes] : |
382 | llvm::zip_equal(t: newOp->getResults(), u: op->getResults())) { |
383 | Value castBack = doCast(builder&: rewriter, loc, src: newRes, dstType: srcType, castKind); |
384 | copyIntegerRange(solver, oldVal: oldRes, newVal: castBack); |
385 | newResults.push_back(Elt: castBack); |
386 | } |
387 | |
388 | rewriter.replaceOp(op, newValues: newResults); |
389 | return success(); |
390 | } |
391 | return failure(); |
392 | } |
393 | |
394 | private: |
395 | DataFlowSolver &solver; |
396 | SmallVector<unsigned, 4> targetBitwidths; |
397 | }; |
398 | |
399 | struct NarrowCmpI final : OpRewritePattern<arith::CmpIOp> { |
400 | NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef<unsigned> target) |
401 | : OpRewritePattern(context), solver(s), targetBitwidths(target) {} |
402 | |
403 | LogicalResult matchAndRewrite(arith::CmpIOp op, |
404 | PatternRewriter &rewriter) const override { |
405 | Value lhs = op.getLhs(); |
406 | Value rhs = op.getRhs(); |
407 | |
408 | SmallVector<ConstantIntRanges> ranges; |
409 | if (failed(collectRanges(solver, op.getOperands(), ranges))) |
410 | return failure(); |
411 | const ConstantIntRanges &lhsRange = ranges[0]; |
412 | const ConstantIntRanges &rhsRange = ranges[1]; |
413 | |
414 | Type srcType = lhs.getType(); |
415 | for (unsigned targetBitwidth : targetBitwidths) { |
416 | CastKind lhsCastKind = checkTruncatability(range: lhsRange, targetWidth: targetBitwidth); |
417 | CastKind rhsCastKind = checkTruncatability(range: rhsRange, targetWidth: targetBitwidth); |
418 | CastKind castKind = mergeCastKinds(lhs: lhsCastKind, rhs: rhsCastKind); |
419 | // Note: this includes target width > src width. |
420 | if (castKind == CastKind::None) |
421 | continue; |
422 | |
423 | Type targetType = getTargetType(srcType, targetBitwidth); |
424 | if (targetType == srcType) |
425 | continue; |
426 | |
427 | Location loc = op->getLoc(); |
428 | IRMapping mapping; |
429 | Value lhsCast = doCast(builder&: rewriter, loc, src: lhs, dstType: targetType, castKind: lhsCastKind); |
430 | Value rhsCast = doCast(builder&: rewriter, loc, src: rhs, dstType: targetType, castKind: rhsCastKind); |
431 | mapping.map(from: lhs, to: lhsCast); |
432 | mapping.map(from: rhs, to: rhsCast); |
433 | |
434 | Operation *newOp = rewriter.clone(*op, mapping); |
435 | copyIntegerRange(solver, op.getResult(), newOp->getResult(idx: 0)); |
436 | rewriter.replaceOp(op, newOp->getResults()); |
437 | return success(); |
438 | } |
439 | return failure(); |
440 | } |
441 | |
442 | private: |
443 | DataFlowSolver &solver; |
444 | SmallVector<unsigned, 4> targetBitwidths; |
445 | }; |
446 | |
447 | /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg |
448 | /// This pattern assumes all passed `targetBitwidths` are not wider than index |
449 | /// type. |
450 | template <typename CastOp> |
451 | struct FoldIndexCastChain final : OpRewritePattern<CastOp> { |
452 | FoldIndexCastChain(MLIRContext *context, ArrayRef<unsigned> target) |
453 | : OpRewritePattern<CastOp>(context), targetBitwidths(target) {} |
454 | |
455 | LogicalResult matchAndRewrite(CastOp op, |
456 | PatternRewriter &rewriter) const override { |
457 | auto srcOp = op.getIn().template getDefiningOp<CastOp>(); |
458 | if (!srcOp) |
459 | return rewriter.notifyMatchFailure(op, "doesn't come from an index cast"); |
460 | |
461 | Value src = srcOp.getIn(); |
462 | if (src.getType() != op.getType()) |
463 | return rewriter.notifyMatchFailure(op, "outer types don't match"); |
464 | |
465 | if (!srcOp.getType().isIndex()) |
466 | return rewriter.notifyMatchFailure(op, "intermediate type isn't index"); |
467 | |
468 | auto intType = dyn_cast<IntegerType>(op.getType()); |
469 | if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth())) |
470 | return failure(); |
471 | |
472 | rewriter.replaceOp(op, src); |
473 | return success(); |
474 | } |
475 | |
476 | private: |
477 | SmallVector<unsigned, 4> targetBitwidths; |
478 | }; |
479 | |
480 | struct IntRangeOptimizationsPass final |
481 | : arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> { |
482 | |
483 | void runOnOperation() override { |
484 | Operation *op = getOperation(); |
485 | MLIRContext *ctx = op->getContext(); |
486 | DataFlowSolver solver; |
487 | solver.load<DeadCodeAnalysis>(); |
488 | solver.load<IntegerRangeAnalysis>(); |
489 | if (failed(Result: solver.initializeAndRun(top: op))) |
490 | return signalPassFailure(); |
491 | |
492 | DataFlowListener listener(solver); |
493 | |
494 | RewritePatternSet patterns(ctx); |
495 | populateIntRangeOptimizationsPatterns(patterns, solver); |
496 | |
497 | if (failed(applyPatternsGreedily( |
498 | op, std::move(patterns), |
499 | GreedyRewriteConfig().setListener(&listener)))) |
500 | signalPassFailure(); |
501 | } |
502 | }; |
503 | |
504 | struct IntRangeNarrowingPass final |
505 | : arith::impl::ArithIntRangeNarrowingBase<IntRangeNarrowingPass> { |
506 | using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase; |
507 | |
508 | void runOnOperation() override { |
509 | Operation *op = getOperation(); |
510 | MLIRContext *ctx = op->getContext(); |
511 | DataFlowSolver solver; |
512 | solver.load<DeadCodeAnalysis>(); |
513 | solver.load<IntegerRangeAnalysis>(); |
514 | if (failed(Result: solver.initializeAndRun(top: op))) |
515 | return signalPassFailure(); |
516 | |
517 | DataFlowListener listener(solver); |
518 | |
519 | RewritePatternSet patterns(ctx); |
520 | populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported); |
521 | |
522 | // We specifically need bottom-up traversal as cmpi pattern needs range |
523 | // data, attached to its original argument values. |
524 | if (failed(applyPatternsGreedily( |
525 | op, std::move(patterns), |
526 | GreedyRewriteConfig().setUseTopDownTraversal(false).setListener( |
527 | &listener)))) |
528 | signalPassFailure(); |
529 | } |
530 | }; |
531 | } // namespace |
532 | |
533 | void mlir::arith::populateIntRangeOptimizationsPatterns( |
534 | RewritePatternSet &patterns, DataFlowSolver &solver) { |
535 | patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>, |
536 | DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver); |
537 | } |
538 | |
539 | void mlir::arith::populateIntRangeNarrowingPatterns( |
540 | RewritePatternSet &patterns, DataFlowSolver &solver, |
541 | ArrayRef<unsigned> bitwidthsSupported) { |
542 | patterns.add<NarrowElementwise, NarrowCmpI>(arg: patterns.getContext(), args&: solver, |
543 | args&: bitwidthsSupported); |
544 | patterns.add<FoldIndexCastChain<arith::IndexCastUIOp>, |
545 | FoldIndexCastChain<arith::IndexCastOp>>(patterns.getContext(), |
546 | bitwidthsSupported); |
547 | } |
548 | |
549 | std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() { |
550 | return std::make_unique<IntRangeOptimizationsPass>(); |
551 | } |
552 |
Definitions
- getMaybeConstantValue
- copyIntegerRange
- maybeReplaceWithConstant
- DataFlowListener
- DataFlowListener
- notifyOperationErased
- MaterializeKnownConstantValues
- MaterializeKnownConstantValues
- matchAndRewrite
- DeleteTrivialRem
- DeleteTrivialRem
- matchAndRewrite
- collectRanges
- getTargetType
- CastKind
- checkTruncatability
- mergeCastKinds
- doCast
- NarrowElementwise
- NarrowElementwise
- matchAndRewrite
- NarrowCmpI
- NarrowCmpI
- matchAndRewrite
- FoldIndexCastChain
- FoldIndexCastChain
- matchAndRewrite
- IntRangeOptimizationsPass
- runOnOperation
- IntRangeNarrowingPass
- runOnOperation
- populateIntRangeOptimizationsPatterns
- populateIntRangeNarrowingPatterns
Improve your Profiling and Debugging skills
Find out more