1//===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
13#include "mlir/Dialect/Affine/Transforms/Transforms.h"
14#include "mlir/Dialect/Arith/Transforms/Transforms.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/FunctionInterfaces.h"
21#include "mlir/Interfaces/ValueBoundsOpInterface.h"
22#include "mlir/Pass/Pass.h"
23
24#define PASS_NAME "test-affine-reify-value-bounds"
25
26using namespace mlir;
27using namespace mlir::affine;
28
29namespace {
30
31/// This pass applies the permutation on the first maximal perfect nest.
32struct TestReifyValueBounds
33 : public PassWrapper<TestReifyValueBounds,
34 InterfacePass<FunctionOpInterface>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
36
37 StringRef getArgument() const final { return PASS_NAME; }
38 StringRef getDescription() const final {
39 return "Tests ValueBoundsOpInterface with affine dialect reification";
40 }
41 TestReifyValueBounds() = default;
42 TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
43
44 void getDependentDialects(DialectRegistry &registry) const override {
45 registry.insert<affine::AffineDialect, tensor::TensorDialect,
46 memref::MemRefDialect>();
47 }
48
49 void runOnOperation() override;
50
51private:
52 Option<bool> reifyToFuncArgs{
53 *this, "reify-to-func-args",
54 llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(Val: false)};
55
56 Option<bool> useArithOps{*this, "use-arith-ops",
57 llvm::cl::desc("Reify with arith dialect ops"),
58 llvm::cl::init(Val: false)};
59};
60
61} // namespace
62
63static ValueBoundsConstraintSet::ComparisonOperator
64invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
65 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
66 return ValueBoundsConstraintSet::ComparisonOperator::GE;
67 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
68 return ValueBoundsConstraintSet::ComparisonOperator::GT;
69 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
70 return ValueBoundsConstraintSet::ComparisonOperator::LE;
71 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
72 return ValueBoundsConstraintSet::ComparisonOperator::LT;
73 llvm_unreachable("unsupported comparison operator");
74}
75
76/// Look for "test.reify_bound" ops in the input and replace their results with
77/// the reified values.
78static LogicalResult testReifyValueBounds(FunctionOpInterface funcOp,
79 bool reifyToFuncArgs,
80 bool useArithOps) {
81 IRRewriter rewriter(funcOp.getContext());
82 WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
83 auto boundType = op.getBoundType();
84 Value value = op.getVar();
85 std::optional<int64_t> dim = op.getDim();
86 auto shapedType = dyn_cast<ShapedType>(value.getType());
87 if (!shapedType && dim.has_value()) {
88 op->emitOpError("dim specified for non-shaped type");
89 return WalkResult::interrupt();
90 }
91 if (shapedType && !dim.has_value()) {
92 op->emitOpError("dim not specified for shaped type");
93 return WalkResult::interrupt();
94 }
95 if (shapedType && shapedType.hasRank() && dim.has_value()) {
96 if (dim.value() < 0) {
97 op->emitOpError("dim must be non-negative");
98 return WalkResult::interrupt();
99 }
100
101 if (dim.value() >= shapedType.getRank()) {
102 op->emitOpError("invalid dim for shaped type rank");
103 return WalkResult::interrupt();
104 }
105 }
106
107 bool constant = op.getConstant();
108 bool scalable = op.getScalable();
109
110 // Prepare stop condition. By default, reify in terms of the op's
111 // operands. No stop condition is used when a constant was requested.
112 std::function<bool(Value, std::optional<int64_t>,
113 ValueBoundsConstraintSet & cstr)>
114 stopCondition = [&](Value v, std::optional<int64_t> d,
115 ValueBoundsConstraintSet &cstr) {
116 // Reify in terms of SSA values that are different from `value`.
117 return v != value;
118 };
119 if (reifyToFuncArgs) {
120 // Reify in terms of function block arguments.
121 stopCondition = [](Value v, std::optional<int64_t> d,
122 ValueBoundsConstraintSet &cstr) {
123 auto bbArg = dyn_cast<BlockArgument>(Val&: v);
124 if (!bbArg)
125 return false;
126 return isa<FunctionOpInterface>(Val: bbArg.getParentBlock()->getParentOp());
127 };
128 }
129
130 // Reify value bound
131 rewriter.setInsertionPointAfter(op);
132 FailureOr<OpFoldResult> reified = failure();
133 if (constant) {
134 auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
135 type: boundType, var: {value, dim}, /*stopCondition=*/nullptr);
136 if (succeeded(reifiedConst))
137 reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(value: *reifiedConst));
138 } else if (scalable) {
139 auto loc = op->getLoc();
140 auto reifiedScalable =
141 vector::ScalableValueBoundsConstraintSet::computeScalableBound(
142 value, dim, vscaleMin: *op.getVscaleMin(), vscaleMax: *op.getVscaleMax(), boundType: boundType);
143 if (succeeded(reifiedScalable)) {
144 SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
145 if (reifiedScalable->map.getNumInputs() == 1) {
146 // The only possible input to the bound is vscale.
147 vscaleOperand.push_back(std::make_pair(
148 rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
149 }
150 reified = affine::materializeComputedBound(
151 b&: rewriter, loc: loc, boundMap: reifiedScalable->map, mapOperands: vscaleOperand);
152 }
153 } else {
154 if (useArithOps) {
155 reified = arith::reifyValueBound(b&: rewriter, loc: op->getLoc(), type: boundType,
156 var: op.getVariable(), stopCondition);
157 } else {
158 reified = reifyValueBound(rewriter, op->getLoc(), boundType,
159 op.getVariable(), stopCondition);
160 }
161 }
162 if (failed(Result: reified)) {
163 op->emitOpError("could not reify bound");
164 return WalkResult::interrupt();
165 }
166
167 // Replace the op with the reified bound.
168 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: *reified)) {
169 rewriter.replaceOp(op, val);
170 return WalkResult::skip();
171 }
172 Value constOp = rewriter.create<arith::ConstantIndexOp>(
173 op->getLoc(), cast<IntegerAttr>(cast<Attribute>(Val&: *reified)).getInt());
174 rewriter.replaceOp(op, constOp);
175 return WalkResult::skip();
176 });
177 return failure(IsFailure: result.wasInterrupted());
178}
179
180/// Look for "test.compare" ops and emit errors/remarks.
181static LogicalResult testEquality(FunctionOpInterface funcOp) {
182 IRRewriter rewriter(funcOp.getContext());
183 WalkResult result = funcOp.walk([&](test::CompareOp op) {
184 auto cmpType = op.getComparisonOperator();
185 if (op.getCompose()) {
186 if (cmpType != ValueBoundsConstraintSet::EQ) {
187 op->emitOpError(
188 "comparison operator must be EQ when 'composed' is specified");
189 return WalkResult::interrupt();
190 }
191 FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
192 value1: op->getOperand(0), value2: op->getOperand(1));
193 if (failed(Result: delta)) {
194 op->emitError("could not determine equality");
195 } else if (*delta == 0) {
196 op->emitRemark("equal");
197 } else {
198 op->emitRemark("different");
199 }
200 return WalkResult::advance();
201 }
202
203 auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
204 return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
205 };
206 if (compare(cmpType)) {
207 op->emitRemark("true");
208 } else if (cmpType != ValueBoundsConstraintSet::EQ &&
209 compare(invertComparisonOperator(cmpType))) {
210 op->emitRemark("false");
211 } else if (cmpType == ValueBoundsConstraintSet::EQ &&
212 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
213 compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
214 op->emitRemark("false");
215 } else {
216 op->emitError("unknown");
217 }
218 return WalkResult::advance();
219 });
220 return failure(IsFailure: result.wasInterrupted());
221}
222
223void TestReifyValueBounds::runOnOperation() {
224 if (failed(
225 testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
226 signalPassFailure();
227 if (failed(testEquality(getOperation())))
228 signalPassFailure();
229}
230
231namespace mlir {
232void registerTestAffineReifyValueBoundsPass() {
233 PassRegistration<TestReifyValueBounds>();
234}
235} // namespace mlir
236

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp