1//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
10#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
11#include "mlir/Dialect/Affine/Analysis/Utils.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
14#include "mlir/Dialect/Affine/LoopUtils.h"
15#include "mlir/Dialect/Affine/Transforms/Transforms.h"
16#include "mlir/Dialect/Transform/IR/TransformDialect.h"
17#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20using namespace mlir;
21using namespace mlir::affine;
22using namespace mlir::transform;
23
24//===----------------------------------------------------------------------===//
25// SimplifyBoundedAffineOpsOp
26//===----------------------------------------------------------------------===//
27
28LogicalResult SimplifyBoundedAffineOpsOp::verify() {
29 if (getLowerBounds().size() != getBoundedValues().size())
30 return emitOpError() << "incorrect number of lower bounds, expected "
31 << getBoundedValues().size() << " but found "
32 << getLowerBounds().size();
33 if (getUpperBounds().size() != getBoundedValues().size())
34 return emitOpError() << "incorrect number of upper bounds, expected "
35 << getBoundedValues().size() << " but found "
36 << getUpperBounds().size();
37 return success();
38}
39
40namespace {
41/// Simplify affine.min / affine.max ops with the given constraints. They are
42/// either rewritten to affine.apply or left unchanged.
43template <typename OpTy>
44struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
45 using OpRewritePattern<OpTy>::OpRewritePattern;
46 SimplifyAffineMinMaxOp(MLIRContext *ctx,
47 const FlatAffineValueConstraints &constraints,
48 PatternBenefit benefit = 1)
49 : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
50
51 LogicalResult matchAndRewrite(OpTy op,
52 PatternRewriter &rewriter) const override {
53 FailureOr<AffineValueMap> simplified =
54 simplifyConstrainedMinMaxOp(op, constraints);
55 if (failed(Result: simplified))
56 return failure();
57 rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
58 simplified->getOperands());
59 return success();
60 }
61
62 const FlatAffineValueConstraints &constraints;
63};
64} // namespace
65
66DiagnosedSilenceableFailure
67SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
68 TransformResults &results,
69 TransformState &state) {
70 // Get constraints for bounded values.
71 SmallVector<int64_t> lbs;
72 SmallVector<int64_t> ubs;
73 SmallVector<Value> boundedValues;
74 DenseSet<Operation *> boundedOps;
75 for (const auto &it : llvm::zip_equal(t: getBoundedValues(), u: getLowerBounds(),
76 args: getUpperBounds())) {
77 Value handle = std::get<0>(t: it);
78 for (Operation *op : state.getPayloadOps(value: handle)) {
79 if (op->getNumResults() != 1 || !op->getResult(idx: 0).getType().isIndex()) {
80 auto diag =
81 emitDefiniteFailure()
82 << "expected bounded value handle to point to one or multiple "
83 "single-result index-typed ops";
84 diag.attachNote(loc: op->getLoc()) << "multiple/non-index result";
85 return diag;
86 }
87 boundedValues.push_back(Elt: op->getResult(idx: 0));
88 boundedOps.insert(V: op);
89 lbs.push_back(Elt: std::get<1>(t: it));
90 ubs.push_back(Elt: std::get<2>(t: it));
91 }
92 }
93
94 // Build constraint set.
95 FlatAffineValueConstraints cstr;
96 for (const auto &it : llvm::zip(t&: boundedValues, u&: lbs, args&: ubs)) {
97 unsigned pos;
98 if (!cstr.findVar(val: std::get<0>(t: it), pos: &pos))
99 pos = cstr.appendSymbolVar(vals: std::get<0>(t: it));
100 cstr.addBound(type: presburger::BoundType::LB, pos, value: std::get<1>(t: it));
101 // Note: addBound bounds are inclusive, but specified UB is exclusive.
102 cstr.addBound(type: presburger::BoundType::UB, pos, value: std::get<2>(t: it) - 1);
103 }
104
105 // Transform all targets.
106 SmallVector<Operation *> targets;
107 for (Operation *target : state.getPayloadOps(value: getTarget())) {
108 if (!isa<AffineMinOp, AffineMaxOp>(Val: target)) {
109 auto diag = emitDefiniteFailure()
110 << "target must be affine.min or affine.max";
111 diag.attachNote(loc: target->getLoc()) << "target op";
112 return diag;
113 }
114 if (boundedOps.contains(V: target)) {
115 auto diag = emitDefiniteFailure()
116 << "target op result must not be constrained";
117 diag.attachNote(loc: target->getLoc()) << "target/constrained op";
118 return diag;
119 }
120 targets.push_back(Elt: target);
121 }
122 RewritePatternSet patterns(getContext());
123 // Canonicalization patterns are needed so that affine.apply ops are composed
124 // with the remaining affine.min/max ops.
125 AffineMaxOp::getCanonicalizationPatterns(results&: patterns, context: getContext());
126 AffineMinOp::getCanonicalizationPatterns(results&: patterns, context: getContext());
127 patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128 SimplifyAffineMinMaxOp<AffineMaxOp>>(arg: getContext(), args&: cstr);
129 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130 // Apply the simplification pattern to a fixpoint.
131 if (failed(Result: applyOpPatternsGreedily(
132 ops: targets, patterns: frozenPatterns,
133 config: GreedyRewriteConfig()
134 .setListener(
135 static_cast<RewriterBase::Listener *>(rewriter.getListener()))
136 .setStrictness(GreedyRewriteStrictness::ExistingAndNewOps)))) {
137 auto diag = emitDefiniteFailure()
138 << "affine.min/max simplification did not converge";
139 return diag;
140 }
141 return DiagnosedSilenceableFailure::success();
142}
143
144void SimplifyBoundedAffineOpsOp::getEffects(
145 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
146 consumesHandle(handles: getTargetMutable(), effects);
147 for (OpOperand &operand : getBoundedValuesMutable())
148 onlyReadsHandle(handles: operand, effects);
149 modifiesPayload(effects);
150}
151
152//===----------------------------------------------------------------------===//
153// SimplifyMinMaxAffineOpsOp
154//===----------------------------------------------------------------------===//
155DiagnosedSilenceableFailure
156SimplifyMinMaxAffineOpsOp::apply(transform::TransformRewriter &rewriter,
157 TransformResults &results,
158 TransformState &state) {
159 SmallVector<Operation *> targets;
160 for (Operation *target : state.getPayloadOps(value: getTarget())) {
161 if (!isa<AffineMinOp, AffineMaxOp>(Val: target)) {
162 auto diag = emitDefiniteFailure()
163 << "target must be affine.min or affine.max";
164 diag.attachNote(loc: target->getLoc()) << "target op";
165 return diag;
166 }
167 targets.push_back(Elt: target);
168 }
169 bool modified = false;
170 if (failed(Result: mlir::affine::simplifyAffineMinMaxOps(rewriter, ops: targets,
171 modified: &modified))) {
172 return emitDefiniteFailure()
173 << "affine.min/max simplification did not converge";
174 }
175 if (!modified) {
176 return emitSilenceableError()
177 << "the transform failed to simplify any of the target operations";
178 }
179 return DiagnosedSilenceableFailure::success();
180}
181
182void SimplifyMinMaxAffineOpsOp::getEffects(
183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184 consumesHandle(handles: getTargetMutable(), effects);
185 modifiesPayload(effects);
186}
187
188//===----------------------------------------------------------------------===//
189// Transform op registration
190//===----------------------------------------------------------------------===//
191
192namespace {
193class AffineTransformDialectExtension
194 : public transform::TransformDialectExtension<
195 AffineTransformDialectExtension> {
196public:
197 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AffineTransformDialectExtension)
198
199 using Base::Base;
200
201 void init() {
202 declareGeneratedDialect<AffineDialect>();
203
204 registerTransformOps<
205#define GET_OP_LIST
206#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
207 >();
208 }
209};
210} // namespace
211
212#define GET_OP_CLASSES
213#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
214
215void mlir::affine::registerTransformDialectExtension(
216 DialectRegistry &registry) {
217 registry.addExtensions<AffineTransformDialectExtension>();
218}
219

// Source file: mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp