1//===- TestReifyValueBounds.cpp - Test value bounds reification -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestOps.h"
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
13#include "mlir/Dialect/Affine/Transforms/Transforms.h"
14#include "mlir/Dialect/Arith/Transforms/Transforms.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/MemRef/IR/MemRef.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/ValueBoundsOpInterface.h"
21#include "mlir/Pass/Pass.h"
22
23#define PASS_NAME "test-affine-reify-value-bounds"
24
25using namespace mlir;
26using namespace mlir::affine;
27using mlir::presburger::BoundType;
28
29namespace {
30
31/// This pass applies the permutation on the first maximal perfect nest.
32struct TestReifyValueBounds
33 : public PassWrapper<TestReifyValueBounds, OperationPass<func::FuncOp>> {
34 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReifyValueBounds)
35
36 StringRef getArgument() const final { return PASS_NAME; }
37 StringRef getDescription() const final {
38 return "Tests ValueBoundsOpInterface with affine dialect reification";
39 }
40 TestReifyValueBounds() = default;
41 TestReifyValueBounds(const TestReifyValueBounds &pass) : PassWrapper(pass){};
42
43 void getDependentDialects(DialectRegistry &registry) const override {
44 registry.insert<affine::AffineDialect, tensor::TensorDialect,
45 memref::MemRefDialect>();
46 }
47
48 void runOnOperation() override;
49
50private:
51 Option<bool> reifyToFuncArgs{
52 *this, "reify-to-func-args",
53 llvm::cl::desc("Reify in terms of function args"), llvm::cl::init(false)};
54
55 Option<bool> useArithOps{*this, "use-arith-ops",
56 llvm::cl::desc("Reify with arith dialect ops"),
57 llvm::cl::init(false)};
58};
59
60} // namespace
61
62static ValueBoundsConstraintSet::ComparisonOperator
63invertComparisonOperator(ValueBoundsConstraintSet::ComparisonOperator cmp) {
64 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LT)
65 return ValueBoundsConstraintSet::ComparisonOperator::GE;
66 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::LE)
67 return ValueBoundsConstraintSet::ComparisonOperator::GT;
68 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GT)
69 return ValueBoundsConstraintSet::ComparisonOperator::LE;
70 if (cmp == ValueBoundsConstraintSet::ComparisonOperator::GE)
71 return ValueBoundsConstraintSet::ComparisonOperator::LT;
72 llvm_unreachable("unsupported comparison operator");
73}
74
75/// Look for "test.reify_bound" ops in the input and replace their results with
76/// the reified values.
77static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
78 bool reifyToFuncArgs,
79 bool useArithOps) {
80 IRRewriter rewriter(funcOp.getContext());
81 WalkResult result = funcOp.walk([&](test::ReifyBoundOp op) {
82 auto boundType = op.getBoundType();
83 Value value = op.getVar();
84 std::optional<int64_t> dim = op.getDim();
85 bool constant = op.getConstant();
86 bool scalable = op.getScalable();
87
88 // Prepare stop condition. By default, reify in terms of the op's
89 // operands. No stop condition is used when a constant was requested.
90 std::function<bool(Value, std::optional<int64_t>,
91 ValueBoundsConstraintSet & cstr)>
92 stopCondition = [&](Value v, std::optional<int64_t> d,
93 ValueBoundsConstraintSet &cstr) {
94 // Reify in terms of SSA values that are different from `value`.
95 return v != value;
96 };
97 if (reifyToFuncArgs) {
98 // Reify in terms of function block arguments.
99 stopCondition = [](Value v, std::optional<int64_t> d,
100 ValueBoundsConstraintSet &cstr) {
101 auto bbArg = dyn_cast<BlockArgument>(Val&: v);
102 if (!bbArg)
103 return false;
104 return isa<FunctionOpInterface>(Val: bbArg.getParentBlock()->getParentOp());
105 };
106 }
107
108 // Reify value bound
109 rewriter.setInsertionPointAfter(op);
110 FailureOr<OpFoldResult> reified = failure();
111 if (constant) {
112 auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
113 type: boundType, var: {value, dim}, /*stopCondition=*/nullptr);
114 if (succeeded(reifiedConst))
115 reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(value: *reifiedConst));
116 } else if (scalable) {
117 auto loc = op->getLoc();
118 auto reifiedScalable =
119 vector::ScalableValueBoundsConstraintSet::computeScalableBound(
120 value, dim, vscaleMin: *op.getVscaleMin(), vscaleMax: *op.getVscaleMax(), boundType: boundType);
121 if (succeeded(reifiedScalable)) {
122 SmallVector<std::pair<Value, std::optional<int64_t>>, 1> vscaleOperand;
123 if (reifiedScalable->map.getNumInputs() == 1) {
124 // The only possible input to the bound is vscale.
125 vscaleOperand.push_back(std::make_pair(
126 rewriter.create<vector::VectorScaleOp>(loc), std::nullopt));
127 }
128 reified = affine::materializeComputedBound(
129 b&: rewriter, loc: loc, boundMap: reifiedScalable->map, mapOperands: vscaleOperand);
130 }
131 } else {
132 if (useArithOps) {
133 reified = arith::reifyValueBound(b&: rewriter, loc: op->getLoc(), type: boundType,
134 var: op.getVariable(), stopCondition);
135 } else {
136 reified = reifyValueBound(rewriter, op->getLoc(), boundType,
137 op.getVariable(), stopCondition);
138 }
139 }
140 if (failed(result: reified)) {
141 op->emitOpError("could not reify bound");
142 return WalkResult::interrupt();
143 }
144
145 // Replace the op with the reified bound.
146 if (auto val = llvm::dyn_cast_if_present<Value>(Val&: *reified)) {
147 rewriter.replaceOp(op, val);
148 return WalkResult::skip();
149 }
150 Value constOp = rewriter.create<arith::ConstantIndexOp>(
151 op->getLoc(), cast<IntegerAttr>(reified->get<Attribute>()).getInt());
152 rewriter.replaceOp(op, constOp);
153 return WalkResult::skip();
154 });
155 return failure(isFailure: result.wasInterrupted());
156}
157
158/// Look for "test.compare" ops and emit errors/remarks.
159static LogicalResult testEquality(func::FuncOp funcOp) {
160 IRRewriter rewriter(funcOp.getContext());
161 WalkResult result = funcOp.walk([&](test::CompareOp op) {
162 auto cmpType = op.getComparisonOperator();
163 if (op.getCompose()) {
164 if (cmpType != ValueBoundsConstraintSet::EQ) {
165 op->emitOpError(
166 "comparison operator must be EQ when 'composed' is specified");
167 return WalkResult::interrupt();
168 }
169 FailureOr<int64_t> delta = affine::fullyComposeAndComputeConstantDelta(
170 value1: op->getOperand(0), value2: op->getOperand(1));
171 if (failed(result: delta)) {
172 op->emitError("could not determine equality");
173 } else if (*delta == 0) {
174 op->emitRemark("equal");
175 } else {
176 op->emitRemark("different");
177 }
178 return WalkResult::advance();
179 }
180
181 auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
182 return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
183 };
184 if (compare(cmpType)) {
185 op->emitRemark("true");
186 } else if (cmpType != ValueBoundsConstraintSet::EQ &&
187 compare(invertComparisonOperator(cmpType))) {
188 op->emitRemark("false");
189 } else if (cmpType == ValueBoundsConstraintSet::EQ &&
190 (compare(ValueBoundsConstraintSet::ComparisonOperator::LT) ||
191 compare(ValueBoundsConstraintSet::ComparisonOperator::GT))) {
192 op->emitRemark("false");
193 } else {
194 op->emitError("unknown");
195 }
196 return WalkResult::advance();
197 });
198 return failure(isFailure: result.wasInterrupted());
199}
200
201void TestReifyValueBounds::runOnOperation() {
202 if (failed(
203 testReifyValueBounds(getOperation(), reifyToFuncArgs, useArithOps)))
204 signalPassFailure();
205 if (failed(testEquality(getOperation())))
206 signalPassFailure();
207}
208
209namespace mlir {
210void registerTestAffineReifyValueBoundsPass() {
211 PassRegistration<TestReifyValueBounds>();
212}
213} // namespace mlir
214

// Source: mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp