1 | //=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" |
10 | #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" |
11 | #include "mlir/Dialect/Affine/Analysis/Utils.h" |
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | #include "mlir/Dialect/Affine/IR/AffineValueMap.h" |
14 | #include "mlir/Dialect/Affine/LoopUtils.h" |
15 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" |
16 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::affine; |
21 | using namespace mlir::transform; |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // SimplifyBoundedAffineOpsOp |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | LogicalResult SimplifyBoundedAffineOpsOp::verify() { |
28 | if (getLowerBounds().size() != getBoundedValues().size()) |
29 | return emitOpError() << "incorrect number of lower bounds, expected " |
30 | << getBoundedValues().size() << " but found " |
31 | << getLowerBounds().size(); |
32 | if (getUpperBounds().size() != getBoundedValues().size()) |
33 | return emitOpError() << "incorrect number of upper bounds, expected " |
34 | << getBoundedValues().size() << " but found " |
35 | << getUpperBounds().size(); |
36 | return success(); |
37 | } |
38 | |
39 | namespace { |
40 | /// Simplify affine.min / affine.max ops with the given constraints. They are |
41 | /// either rewritten to affine.apply or left unchanged. |
42 | template <typename OpTy> |
43 | struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> { |
44 | using OpRewritePattern<OpTy>::OpRewritePattern; |
45 | SimplifyAffineMinMaxOp(MLIRContext *ctx, |
46 | const FlatAffineValueConstraints &constraints, |
47 | PatternBenefit benefit = 1) |
48 | : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {} |
49 | |
50 | LogicalResult matchAndRewrite(OpTy op, |
51 | PatternRewriter &rewriter) const override { |
52 | FailureOr<AffineValueMap> simplified = |
53 | simplifyConstrainedMinMaxOp(op, constraints); |
54 | if (failed(result: simplified)) |
55 | return failure(); |
56 | rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(), |
57 | simplified->getOperands()); |
58 | return success(); |
59 | } |
60 | |
61 | const FlatAffineValueConstraints &constraints; |
62 | }; |
63 | } // namespace |
64 | |
65 | DiagnosedSilenceableFailure |
66 | SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, |
67 | TransformResults &results, |
68 | TransformState &state) { |
69 | // Get constraints for bounded values. |
70 | SmallVector<int64_t> lbs; |
71 | SmallVector<int64_t> ubs; |
72 | SmallVector<Value> boundedValues; |
73 | DenseSet<Operation *> boundedOps; |
74 | for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(), |
75 | getUpperBounds())) { |
76 | Value handle = std::get<0>(it); |
77 | for (Operation *op : state.getPayloadOps(handle)) { |
78 | if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { |
79 | auto diag = |
80 | emitDefiniteFailure() |
81 | << "expected bounded value handle to point to one or multiple " |
82 | "single-result index-typed ops" ; |
83 | diag.attachNote(op->getLoc()) << "multiple/non-index result" ; |
84 | return diag; |
85 | } |
86 | boundedValues.push_back(op->getResult(0)); |
87 | boundedOps.insert(op); |
88 | lbs.push_back(std::get<1>(it)); |
89 | ubs.push_back(std::get<2>(it)); |
90 | } |
91 | } |
92 | |
93 | // Build constraint set. |
94 | FlatAffineValueConstraints cstr; |
95 | for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) { |
96 | unsigned pos; |
97 | if (!cstr.findVar(std::get<0>(it), &pos)) |
98 | pos = cstr.appendSymbolVar(std::get<0>(it)); |
99 | cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it)); |
100 | // Note: addBound bounds are inclusive, but specified UB is exclusive. |
101 | cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1); |
102 | } |
103 | |
104 | // Transform all targets. |
105 | SmallVector<Operation *> targets; |
106 | for (Operation *target : state.getPayloadOps(getTarget())) { |
107 | if (!isa<AffineMinOp, AffineMaxOp>(target)) { |
108 | auto diag = emitDefiniteFailure() |
109 | << "target must be affine.min or affine.max" ; |
110 | diag.attachNote(target->getLoc()) << "target op" ; |
111 | return diag; |
112 | } |
113 | if (boundedOps.contains(target)) { |
114 | auto diag = emitDefiniteFailure() |
115 | << "target op result must not be constrainted" ; |
116 | diag.attachNote(target->getLoc()) << "target/constrained op" ; |
117 | return diag; |
118 | } |
119 | targets.push_back(target); |
120 | } |
121 | SmallVector<Operation *> transformed; |
122 | RewritePatternSet patterns(getContext()); |
123 | // Canonicalization patterns are needed so that affine.apply ops are composed |
124 | // with the remaining affine.min/max ops. |
125 | AffineMaxOp::getCanonicalizationPatterns(patterns, getContext()); |
126 | AffineMinOp::getCanonicalizationPatterns(patterns, getContext()); |
127 | patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>, |
128 | SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr); |
129 | FrozenRewritePatternSet frozenPatterns(std::move(patterns)); |
130 | GreedyRewriteConfig config; |
131 | config.listener = |
132 | static_cast<RewriterBase::Listener *>(rewriter.getListener()); |
133 | config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; |
134 | // Apply the simplification pattern to a fixpoint. |
135 | if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { |
136 | auto diag = emitDefiniteFailure() |
137 | << "affine.min/max simplification did not converge" ; |
138 | return diag; |
139 | } |
140 | return DiagnosedSilenceableFailure::success(); |
141 | } |
142 | |
143 | void SimplifyBoundedAffineOpsOp::getEffects( |
144 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
145 | consumesHandle(getTarget(), effects); |
146 | for (Value v : getBoundedValues()) |
147 | onlyReadsHandle(v, effects); |
148 | modifiesPayload(effects); |
149 | } |
150 | |
151 | //===----------------------------------------------------------------------===// |
152 | // Transform op registration |
153 | //===----------------------------------------------------------------------===// |
154 | |
namespace {
/// Transform dialect extension that registers the transform ops defined for
/// the Affine dialect.
class AffineTransformDialectExtension
    : public transform::TransformDialectExtension<
          AffineTransformDialectExtension> {
public:
  using Base::Base;

  /// Declares AffineDialect as a generated dialect and registers the transform
  /// ops of this extension.
  void init() {
    declareGeneratedDialect<AffineDialect>();

    // With GET_OP_LIST defined, the generated .cpp.inc is included in the
    // middle of the template argument list, so it presumably expands to the
    // comma-separated list of op classes to register.
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
        >();
  }
};
} // namespace
172 | |
173 | #define GET_OP_CLASSES |
174 | #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc" |
175 | |
176 | void mlir::affine::registerTransformDialectExtension( |
177 | DialectRegistry ®istry) { |
178 | registry.addExtensions<AffineTransformDialectExtension>(); |
179 | } |
180 | |