| 1 | //===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a transform to simplify min/max affine operations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/Affine/Passes.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
| 17 | #include "mlir/IR/PatternMatch.h" |
| 18 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 19 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| 20 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 21 | #include "llvm/ADT/IntEqClasses.h" |
| 22 | #include "llvm/Support/Debug.h" |
| 23 | #include "llvm/Support/InterleavedRange.h" |
| 24 | |
| 25 | #define DEBUG_TYPE "affine-min-max" |
| 26 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") |
| 27 | |
| 28 | using namespace mlir; |
| 29 | using namespace mlir::affine; |
| 30 | |
| 31 | /// Simplifies an affine min/max operation by proving there's a lower or upper |
| 32 | /// bound. |
| 33 | template <typename AffineOp> |
| 34 | static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { |
| 35 | using Variable = ValueBoundsConstraintSet::Variable; |
| 36 | using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator; |
| 37 | |
| 38 | AffineMap affineMap = affineOp.getMap(); |
| 39 | ValueRange operands = affineOp.getOperands(); |
| 40 | static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>; |
| 41 | |
| 42 | LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n" ; }); |
| 43 | |
| 44 | // Create a `Variable` list with values corresponding to each of the results |
| 45 | // in the affine affineMap. |
| 46 | SmallVector<Variable> variables = llvm::map_to_vector( |
| 47 | llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false), |
| 48 | [&](unsigned i) { |
| 49 | return Variable(affineMap.getSliceMap(start: i, length: 1), operands); |
| 50 | }); |
| 51 | LLVM_DEBUG({ |
| 52 | DBGS() << "- constructed variables are: " |
| 53 | << llvm::interleaved_array(llvm::map_range( |
| 54 | variables, [](const Variable &v) { return v.getMap(); })) |
| 55 | << "`\n" ; |
| 56 | }); |
| 57 | |
| 58 | // Get the comparison operation. |
| 59 | ComparisonOperator cmpOp = |
| 60 | isMin ? ComparisonOperator::LT : ComparisonOperator::GT; |
| 61 | |
| 62 | // Find disjoint sets bounded by a common value. |
| 63 | llvm::IntEqClasses boundedClasses(variables.size()); |
| 64 | DenseMap<unsigned, Variable *> bounds; |
| 65 | for (auto &&[i, v] : llvm::enumerate(First&: variables)) { |
| 66 | unsigned eqClass = boundedClasses.findLeader(a: i); |
| 67 | |
| 68 | // If the class already has a bound continue. |
| 69 | if (bounds.contains(Val: eqClass)) |
| 70 | continue; |
| 71 | |
| 72 | // Initialize the bound. |
| 73 | Variable *bound = &v; |
| 74 | |
| 75 | LLVM_DEBUG({ |
| 76 | DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() |
| 77 | << "`\n" ; |
| 78 | }); |
| 79 | |
| 80 | // Check against the other variables. |
| 81 | for (size_t j = i + 1; j < variables.size(); ++j) { |
| 82 | unsigned jEqClass = boundedClasses.findLeader(a: j); |
| 83 | // Skip if the class is the same. |
| 84 | if (jEqClass == eqClass) |
| 85 | continue; |
| 86 | |
| 87 | // Get the bound of the equivalence class or itself. |
| 88 | Variable *nv = bounds.lookup_or(Val: jEqClass, Default: &variables[j]); |
| 89 | |
| 90 | LLVM_DEBUG({ |
| 91 | DBGS() << "- comparing with variable: #" << jEqClass |
| 92 | << ", with map: " << nv->getMap() << "\n" ; |
| 93 | }); |
| 94 | |
| 95 | // Compare the variables. |
| 96 | FailureOr<bool> cmpResult = |
| 97 | ValueBoundsConstraintSet::strongCompare(lhs: *bound, cmp: cmpOp, rhs: *nv); |
| 98 | |
| 99 | // The variables cannot be compared. |
| 100 | if (failed(Result: cmpResult)) { |
| 101 | LLVM_DEBUG({ |
| 102 | DBGS() << "-- classes: #" << i << ", #" << jEqClass |
| 103 | << " cannot be merged\n" ; |
| 104 | }); |
| 105 | continue; |
| 106 | } |
| 107 | |
| 108 | // Join the equivalent classes and update the bound if necessary. |
| 109 | LLVM_DEBUG({ |
| 110 | DBGS() << "-- merging classes: #" << i << ", #" << jEqClass |
| 111 | << ", is cmp(lhs, rhs): " << *cmpResult << "`\n" ; |
| 112 | }); |
| 113 | if (*cmpResult) { |
| 114 | boundedClasses.join(a: eqClass, b: jEqClass); |
| 115 | } else { |
| 116 | // In this case we have lhs > rhs if isMin == true, or lhs < rhs if |
| 117 | // isMin == false. |
| 118 | bound = nv; |
| 119 | boundedClasses.join(a: eqClass, b: jEqClass); |
| 120 | } |
| 121 | } |
| 122 | bounds[boundedClasses.findLeader(a: i)] = bound; |
| 123 | } |
| 124 | |
| 125 | // Return if there's no simplification. |
| 126 | if (bounds.size() >= affineMap.getNumResults()) { |
| 127 | LLVM_DEBUG( |
| 128 | { DBGS() << "- the affine operation couldn't get simplified\n" ; }); |
| 129 | return false; |
| 130 | } |
| 131 | |
| 132 | // Construct the new affine affineMap. |
| 133 | SmallVector<AffineExpr> results; |
| 134 | results.reserve(N: bounds.size()); |
| 135 | for (auto [k, bound] : bounds) |
| 136 | results.push_back(Elt: bound->getMap().getResult(idx: 0)); |
| 137 | |
| 138 | LLVM_DEBUG({ |
| 139 | DBGS() << "- starting from map: " << affineMap << "\n" ; |
| 140 | DBGS() << "- creating new map with: \n" ; |
| 141 | DBGS() << "--- dims: " << affineMap.getNumDims() << "\n" ; |
| 142 | DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n" ; |
| 143 | DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n" ; |
| 144 | }); |
| 145 | |
| 146 | affineMap = |
| 147 | AffineMap::get(dimCount: 0, symbolCount: affineMap.getNumSymbols() + affineMap.getNumDims(), |
| 148 | results, context: rewriter.getContext()); |
| 149 | |
| 150 | // Update the affine op. |
| 151 | rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); }); |
| 152 | LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n" ; }); |
| 153 | return true; |
| 154 | } |
| 155 | |
| 156 | bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) { |
| 157 | return simplifyAffineMinMaxOp(rewriter, affineOp: op); |
| 158 | } |
| 159 | |
| 160 | bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) { |
| 161 | return simplifyAffineMinMaxOp(rewriter, affineOp: op); |
| 162 | } |
| 163 | |
| 164 | LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter, |
| 165 | ArrayRef<Operation *> ops, |
| 166 | bool *modified) { |
| 167 | bool changed = false; |
| 168 | for (Operation *op : ops) { |
| 169 | if (auto minOp = dyn_cast<AffineMinOp>(Val: op)) |
| 170 | changed = simplifyAffineMinOp(rewriter, op: minOp) || changed; |
| 171 | else if (auto maxOp = cast<AffineMaxOp>(Val: op)) |
| 172 | changed = simplifyAffineMaxOp(rewriter, op: maxOp) || changed; |
| 173 | } |
| 174 | RewritePatternSet patterns(rewriter.getContext()); |
| 175 | AffineMaxOp::getCanonicalizationPatterns(results&: patterns, context: rewriter.getContext()); |
| 176 | AffineMinOp::getCanonicalizationPatterns(results&: patterns, context: rewriter.getContext()); |
| 177 | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| 178 | if (modified) |
| 179 | *modified = changed; |
| 180 | // Canonicalize to a fixpoint. |
| 181 | if (failed(Result: applyOpPatternsGreedily( |
| 182 | ops, patterns: frozenPatterns, |
| 183 | config: GreedyRewriteConfig() |
| 184 | .setListener( |
| 185 | static_cast<RewriterBase::Listener *>(rewriter.getListener())) |
| 186 | .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps), |
| 187 | changed: &changed))) { |
| 188 | return failure(); |
| 189 | } |
| 190 | if (modified) |
| 191 | *modified = changed; |
| 192 | return success(); |
| 193 | } |
| 194 | |
| 195 | namespace { |
| 196 | |
| 197 | struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> { |
| 198 | using OpRewritePattern<AffineMaxOp>::OpRewritePattern; |
| 199 | |
| 200 | LogicalResult matchAndRewrite(AffineMaxOp affineOp, |
| 201 | PatternRewriter &rewriter) const override { |
| 202 | return success(IsSuccess: simplifyAffineMaxOp(rewriter, op: affineOp)); |
| 203 | } |
| 204 | }; |
| 205 | |
| 206 | struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> { |
| 207 | using OpRewritePattern<AffineMinOp>::OpRewritePattern; |
| 208 | |
| 209 | LogicalResult matchAndRewrite(AffineMinOp affineOp, |
| 210 | PatternRewriter &rewriter) const override { |
| 211 | return success(IsSuccess: simplifyAffineMinOp(rewriter, op: affineOp)); |
| 212 | } |
| 213 | }; |
| 214 | |
| 215 | struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> { |
| 216 | using OpRewritePattern<AffineApplyOp>::OpRewritePattern; |
| 217 | |
| 218 | LogicalResult matchAndRewrite(AffineApplyOp affineOp, |
| 219 | PatternRewriter &rewriter) const override { |
| 220 | AffineMap map = affineOp.getAffineMap(); |
| 221 | SmallVector<Value> operands{affineOp->getOperands().begin(), |
| 222 | affineOp->getOperands().end()}; |
| 223 | fullyComposeAffineMapAndOperands(map: &map, operands: &operands, |
| 224 | /*composeAffineMin=*/true); |
| 225 | |
| 226 | // No change => failure to apply. |
| 227 | if (map == affineOp.getAffineMap()) |
| 228 | return failure(); |
| 229 | |
| 230 | rewriter.modifyOpInPlace(root: affineOp, callable: [&]() { |
| 231 | affineOp.setMap(map); |
| 232 | affineOp->setOperands(operands); |
| 233 | }); |
| 234 | return success(); |
| 235 | } |
| 236 | }; |
| 237 | |
| 238 | } // namespace |
| 239 | |
| 240 | namespace mlir { |
| 241 | namespace affine { |
| 242 | #define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAXPASS |
| 243 | #include "mlir/Dialect/Affine/Passes.h.inc" |
| 244 | } // namespace affine |
| 245 | } // namespace mlir |
| 246 | |
| 247 | /// Creates a simplification pass for affine min/max/apply. |
| 248 | struct SimplifyAffineMinMaxPass |
| 249 | : public affine::impl::SimplifyAffineMinMaxPassBase< |
| 250 | SimplifyAffineMinMaxPass> { |
| 251 | void runOnOperation() override; |
| 252 | }; |
| 253 | |
| 254 | void SimplifyAffineMinMaxPass::runOnOperation() { |
| 255 | FunctionOpInterface func = getOperation(); |
| 256 | RewritePatternSet patterns(func.getContext()); |
| 257 | AffineMaxOp::getCanonicalizationPatterns(results&: patterns, context: func.getContext()); |
| 258 | AffineMinOp::getCanonicalizationPatterns(results&: patterns, context: func.getContext()); |
| 259 | patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>( |
| 260 | arg: func.getContext()); |
| 261 | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
| 262 | if (failed(Result: applyPatternsGreedily(op: func, patterns: std::move(frozenPatterns)))) |
| 263 | return signalPassFailure(); |
| 264 | } |
| 265 | |