//=== AffineTransformOps.cpp - Implementation of Affine transformation ops ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
10#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
11#include "mlir/Dialect/Affine/Analysis/Utils.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
14#include "mlir/Dialect/Affine/LoopUtils.h"
15#include "mlir/Dialect/Transform/IR/TransformDialect.h"
16#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19using namespace mlir;
20using namespace mlir::affine;
21using namespace mlir::transform;
22
//===----------------------------------------------------------------------===//
// SimplifyBoundedAffineOpsOp
//===----------------------------------------------------------------------===//
26
27LogicalResult SimplifyBoundedAffineOpsOp::verify() {
28 if (getLowerBounds().size() != getBoundedValues().size())
29 return emitOpError() << "incorrect number of lower bounds, expected "
30 << getBoundedValues().size() << " but found "
31 << getLowerBounds().size();
32 if (getUpperBounds().size() != getBoundedValues().size())
33 return emitOpError() << "incorrect number of upper bounds, expected "
34 << getBoundedValues().size() << " but found "
35 << getUpperBounds().size();
36 return success();
37}
38
39namespace {
40/// Simplify affine.min / affine.max ops with the given constraints. They are
41/// either rewritten to affine.apply or left unchanged.
42template <typename OpTy>
43struct SimplifyAffineMinMaxOp : public OpRewritePattern<OpTy> {
44 using OpRewritePattern<OpTy>::OpRewritePattern;
45 SimplifyAffineMinMaxOp(MLIRContext *ctx,
46 const FlatAffineValueConstraints &constraints,
47 PatternBenefit benefit = 1)
48 : OpRewritePattern<OpTy>(ctx, benefit), constraints(constraints) {}
49
50 LogicalResult matchAndRewrite(OpTy op,
51 PatternRewriter &rewriter) const override {
52 FailureOr<AffineValueMap> simplified =
53 simplifyConstrainedMinMaxOp(op, constraints);
54 if (failed(result: simplified))
55 return failure();
56 rewriter.replaceOpWithNewOp<AffineApplyOp>(op, simplified->getAffineMap(),
57 simplified->getOperands());
58 return success();
59 }
60
61 const FlatAffineValueConstraints &constraints;
62};
63} // namespace
64
65DiagnosedSilenceableFailure
66SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
67 TransformResults &results,
68 TransformState &state) {
69 // Get constraints for bounded values.
70 SmallVector<int64_t> lbs;
71 SmallVector<int64_t> ubs;
72 SmallVector<Value> boundedValues;
73 DenseSet<Operation *> boundedOps;
74 for (const auto &it : llvm::zip_equal(getBoundedValues(), getLowerBounds(),
75 getUpperBounds())) {
76 Value handle = std::get<0>(it);
77 for (Operation *op : state.getPayloadOps(handle)) {
78 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
79 auto diag =
80 emitDefiniteFailure()
81 << "expected bounded value handle to point to one or multiple "
82 "single-result index-typed ops";
83 diag.attachNote(op->getLoc()) << "multiple/non-index result";
84 return diag;
85 }
86 boundedValues.push_back(op->getResult(0));
87 boundedOps.insert(op);
88 lbs.push_back(std::get<1>(it));
89 ubs.push_back(std::get<2>(it));
90 }
91 }
92
93 // Build constraint set.
94 FlatAffineValueConstraints cstr;
95 for (const auto &it : llvm::zip(boundedValues, lbs, ubs)) {
96 unsigned pos;
97 if (!cstr.findVar(std::get<0>(it), &pos))
98 pos = cstr.appendSymbolVar(std::get<0>(it));
99 cstr.addBound(presburger::BoundType::LB, pos, std::get<1>(it));
100 // Note: addBound bounds are inclusive, but specified UB is exclusive.
101 cstr.addBound(presburger::BoundType::UB, pos, std::get<2>(it) - 1);
102 }
103
104 // Transform all targets.
105 SmallVector<Operation *> targets;
106 for (Operation *target : state.getPayloadOps(getTarget())) {
107 if (!isa<AffineMinOp, AffineMaxOp>(target)) {
108 auto diag = emitDefiniteFailure()
109 << "target must be affine.min or affine.max";
110 diag.attachNote(target->getLoc()) << "target op";
111 return diag;
112 }
113 if (boundedOps.contains(target)) {
114 auto diag = emitDefiniteFailure()
115 << "target op result must not be constrainted";
116 diag.attachNote(target->getLoc()) << "target/constrained op";
117 return diag;
118 }
119 targets.push_back(target);
120 }
121 SmallVector<Operation *> transformed;
122 RewritePatternSet patterns(getContext());
123 // Canonicalization patterns are needed so that affine.apply ops are composed
124 // with the remaining affine.min/max ops.
125 AffineMaxOp::getCanonicalizationPatterns(patterns, getContext());
126 AffineMinOp::getCanonicalizationPatterns(patterns, getContext());
127 patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
128 SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
129 FrozenRewritePatternSet frozenPatterns(std::move(patterns));
130 GreedyRewriteConfig config;
131 config.listener =
132 static_cast<RewriterBase::Listener *>(rewriter.getListener());
133 config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
134 // Apply the simplification pattern to a fixpoint.
135 if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
136 auto diag = emitDefiniteFailure()
137 << "affine.min/max simplification did not converge";
138 return diag;
139 }
140 return DiagnosedSilenceableFailure::success();
141}
142
143void SimplifyBoundedAffineOpsOp::getEffects(
144 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145 consumesHandle(getTarget(), effects);
146 for (Value v : getBoundedValues())
147 onlyReadsHandle(v, effects);
148 modifiesPayload(effects);
149}
150
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
154
155namespace {
156class AffineTransformDialectExtension
157 : public transform::TransformDialectExtension<
158 AffineTransformDialectExtension> {
159public:
160 using Base::Base;
161
162 void init() {
163 declareGeneratedDialect<AffineDialect>();
164
165 registerTransformOps<
166#define GET_OP_LIST
167#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
168 >();
169 }
170};
171} // namespace
172
173#define GET_OP_CLASSES
174#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.cpp.inc"
175
176void mlir::affine::registerTransformDialectExtension(
177 DialectRegistry &registry) {
178 registry.addExtensions<AffineTransformDialectExtension>();
179}
180

// Source: mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp