1//===- SimplifyAffineMinMax.cpp - Simplify affine min/max ops -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a transform to simplify min/max affine operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/Passes.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Affine/Transforms/Transforms.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Interfaces/FunctionInterfaces.h"
19#include "mlir/Interfaces/ValueBoundsOpInterface.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21#include "llvm/ADT/IntEqClasses.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/InterleavedRange.h"
24
// Debug type used by the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "affine-min-max"
// Streams the "[affine-min-max]: " prefix to llvm::dbgs() so that every debug
// line emitted by this file carries the same tag.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::affine;
30
31/// Simplifies an affine min/max operation by proving there's a lower or upper
32/// bound.
33template <typename AffineOp>
34static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
35 using Variable = ValueBoundsConstraintSet::Variable;
36 using ComparisonOperator = ValueBoundsConstraintSet::ComparisonOperator;
37
38 AffineMap affineMap = affineOp.getMap();
39 ValueRange operands = affineOp.getOperands();
40 static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
41
42 LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
43
44 // Create a `Variable` list with values corresponding to each of the results
45 // in the affine affineMap.
46 SmallVector<Variable> variables = llvm::map_to_vector(
47 llvm::iota_range<unsigned>(0u, affineMap.getNumResults(), false),
48 [&](unsigned i) {
49 return Variable(affineMap.getSliceMap(start: i, length: 1), operands);
50 });
51 LLVM_DEBUG({
52 DBGS() << "- constructed variables are: "
53 << llvm::interleaved_array(llvm::map_range(
54 variables, [](const Variable &v) { return v.getMap(); }))
55 << "`\n";
56 });
57
58 // Get the comparison operation.
59 ComparisonOperator cmpOp =
60 isMin ? ComparisonOperator::LT : ComparisonOperator::GT;
61
62 // Find disjoint sets bounded by a common value.
63 llvm::IntEqClasses boundedClasses(variables.size());
64 DenseMap<unsigned, Variable *> bounds;
65 for (auto &&[i, v] : llvm::enumerate(First&: variables)) {
66 unsigned eqClass = boundedClasses.findLeader(a: i);
67
68 // If the class already has a bound continue.
69 if (bounds.contains(Val: eqClass))
70 continue;
71
72 // Initialize the bound.
73 Variable *bound = &v;
74
75 LLVM_DEBUG({
76 DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
77 << "`\n";
78 });
79
80 // Check against the other variables.
81 for (size_t j = i + 1; j < variables.size(); ++j) {
82 unsigned jEqClass = boundedClasses.findLeader(a: j);
83 // Skip if the class is the same.
84 if (jEqClass == eqClass)
85 continue;
86
87 // Get the bound of the equivalence class or itself.
88 Variable *nv = bounds.lookup_or(Val: jEqClass, Default: &variables[j]);
89
90 LLVM_DEBUG({
91 DBGS() << "- comparing with variable: #" << jEqClass
92 << ", with map: " << nv->getMap() << "\n";
93 });
94
95 // Compare the variables.
96 FailureOr<bool> cmpResult =
97 ValueBoundsConstraintSet::strongCompare(lhs: *bound, cmp: cmpOp, rhs: *nv);
98
99 // The variables cannot be compared.
100 if (failed(Result: cmpResult)) {
101 LLVM_DEBUG({
102 DBGS() << "-- classes: #" << i << ", #" << jEqClass
103 << " cannot be merged\n";
104 });
105 continue;
106 }
107
108 // Join the equivalent classes and update the bound if necessary.
109 LLVM_DEBUG({
110 DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
111 << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
112 });
113 if (*cmpResult) {
114 boundedClasses.join(a: eqClass, b: jEqClass);
115 } else {
116 // In this case we have lhs > rhs if isMin == true, or lhs < rhs if
117 // isMin == false.
118 bound = nv;
119 boundedClasses.join(a: eqClass, b: jEqClass);
120 }
121 }
122 bounds[boundedClasses.findLeader(a: i)] = bound;
123 }
124
125 // Return if there's no simplification.
126 if (bounds.size() >= affineMap.getNumResults()) {
127 LLVM_DEBUG(
128 { DBGS() << "- the affine operation couldn't get simplified\n"; });
129 return false;
130 }
131
132 // Construct the new affine affineMap.
133 SmallVector<AffineExpr> results;
134 results.reserve(N: bounds.size());
135 for (auto [k, bound] : bounds)
136 results.push_back(Elt: bound->getMap().getResult(idx: 0));
137
138 LLVM_DEBUG({
139 DBGS() << "- starting from map: " << affineMap << "\n";
140 DBGS() << "- creating new map with: \n";
141 DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
142 DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
143 DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
144 });
145
146 affineMap =
147 AffineMap::get(dimCount: 0, symbolCount: affineMap.getNumSymbols() + affineMap.getNumDims(),
148 results, context: rewriter.getContext());
149
150 // Update the affine op.
151 rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
152 LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
153 return true;
154}
155
156bool mlir::affine::simplifyAffineMinOp(RewriterBase &rewriter, AffineMinOp op) {
157 return simplifyAffineMinMaxOp(rewriter, affineOp: op);
158}
159
160bool mlir::affine::simplifyAffineMaxOp(RewriterBase &rewriter, AffineMaxOp op) {
161 return simplifyAffineMinMaxOp(rewriter, affineOp: op);
162}
163
164LogicalResult mlir::affine::simplifyAffineMinMaxOps(RewriterBase &rewriter,
165 ArrayRef<Operation *> ops,
166 bool *modified) {
167 bool changed = false;
168 for (Operation *op : ops) {
169 if (auto minOp = dyn_cast<AffineMinOp>(Val: op))
170 changed = simplifyAffineMinOp(rewriter, op: minOp) || changed;
171 else if (auto maxOp = cast<AffineMaxOp>(Val: op))
172 changed = simplifyAffineMaxOp(rewriter, op: maxOp) || changed;
173 }
174 RewritePatternSet patterns(rewriter.getContext());
175 AffineMaxOp::getCanonicalizationPatterns(results&: patterns, context: rewriter.getContext());
176 AffineMinOp::getCanonicalizationPatterns(results&: patterns, context: rewriter.getContext());
177 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
178 if (modified)
179 *modified = changed;
180 // Canonicalize to a fixpoint.
181 if (failed(Result: applyOpPatternsGreedily(
182 ops, patterns: frozenPatterns,
183 config: GreedyRewriteConfig()
184 .setListener(
185 static_cast<RewriterBase::Listener *>(rewriter.getListener()))
186 .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps),
187 changed: &changed))) {
188 return failure();
189 }
190 if (modified)
191 *modified = changed;
192 return success();
193}
194
195namespace {
196
197struct SimplifyAffineMaxOp : public OpRewritePattern<AffineMaxOp> {
198 using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
199
200 LogicalResult matchAndRewrite(AffineMaxOp affineOp,
201 PatternRewriter &rewriter) const override {
202 return success(IsSuccess: simplifyAffineMaxOp(rewriter, op: affineOp));
203 }
204};
205
206struct SimplifyAffineMinOp : public OpRewritePattern<AffineMinOp> {
207 using OpRewritePattern<AffineMinOp>::OpRewritePattern;
208
209 LogicalResult matchAndRewrite(AffineMinOp affineOp,
210 PatternRewriter &rewriter) const override {
211 return success(IsSuccess: simplifyAffineMinOp(rewriter, op: affineOp));
212 }
213};
214
215struct SimplifyAffineApplyOp : public OpRewritePattern<AffineApplyOp> {
216 using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
217
218 LogicalResult matchAndRewrite(AffineApplyOp affineOp,
219 PatternRewriter &rewriter) const override {
220 AffineMap map = affineOp.getAffineMap();
221 SmallVector<Value> operands{affineOp->getOperands().begin(),
222 affineOp->getOperands().end()};
223 fullyComposeAffineMapAndOperands(map: &map, operands: &operands,
224 /*composeAffineMin=*/true);
225
226 // No change => failure to apply.
227 if (map == affineOp.getAffineMap())
228 return failure();
229
230 rewriter.modifyOpInPlace(root: affineOp, callable: [&]() {
231 affineOp.setMap(map);
232 affineOp->setOperands(operands);
233 });
234 return success();
235 }
236};
237
238} // namespace
239
240namespace mlir {
241namespace affine {
242#define GEN_PASS_DEF_SIMPLIFYAFFINEMINMAXPASS
243#include "mlir/Dialect/Affine/Passes.h.inc"
244} // namespace affine
245} // namespace mlir
246
247/// Creates a simplification pass for affine min/max/apply.
248struct SimplifyAffineMinMaxPass
249 : public affine::impl::SimplifyAffineMinMaxPassBase<
250 SimplifyAffineMinMaxPass> {
251 void runOnOperation() override;
252};
253
254void SimplifyAffineMinMaxPass::runOnOperation() {
255 FunctionOpInterface func = getOperation();
256 RewritePatternSet patterns(func.getContext());
257 AffineMaxOp::getCanonicalizationPatterns(results&: patterns, context: func.getContext());
258 AffineMinOp::getCanonicalizationPatterns(results&: patterns, context: func.getContext());
259 patterns.add<SimplifyAffineMaxOp, SimplifyAffineMinOp, SimplifyAffineApplyOp>(
260 arg: func.getContext());
261 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
262 if (failed(Result: applyPatternsGreedily(op: func, patterns: std::move(frozenPatterns))))
263 return signalPassFailure();
264}
265

source code of mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp