1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace scf {
18namespace {
19
20struct ForOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22
23 /// Populate bounds of values/dimensions for iter_args/OpResults. If the
24 /// value/dimension size does not change in an iteration, we can deduce that
25 /// it the same as the initial value/dimension.
26 ///
27 /// Example 1:
28 /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
29 /// ...
30 /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
31 /// scf.yield %1 : tensor<?xf32>
32 /// }
33 /// --> bound(%0)[0] == bound(%t)[0]
34 /// --> bound(%arg0)[0] == bound(%t)[0]
35 ///
36 /// Example 2:
37 /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
38 /// %sz = tensor.dim %arg0 : tensor<?xf32>
39 /// %incr = arith.addi %sz, %c1 : index
40 /// %1 = tensor.empty(%incr) : tensor<?xf32>
41 /// scf.yield %1 : tensor<?xf32>
42 /// }
43 /// --> The yielded tensor dimension size changes with each iteration. Such
44 /// loops are not supported and no constraints are added.
45 static void populateIterArgBounds(scf::ForOp forOp, Value value,
46 std::optional<int64_t> dim,
47 ValueBoundsConstraintSet &cstr) {
48 // `value` is an iter_arg or an OpResult.
49 int64_t iterArgIdx;
50 if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
51 iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
52 } else {
53 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
54 }
55
56 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
57 .getOperand(iterArgIdx);
58 Value iterArg = forOp.getRegionIterArg(iterArgIdx);
59 Value initArg = forOp.getInitArgs()[iterArgIdx];
60
61 // An EQ constraint can be added if the yielded value (dimension size)
62 // equals the corresponding block argument (dimension size).
63 if (cstr.populateAndCompare(
64 /*lhs=*/{yieldedValue, dim},
65 cmp: ValueBoundsConstraintSet::ComparisonOperator::EQ,
66 /*rhs=*/{iterArg, dim})) {
67 if (dim.has_value()) {
68 cstr.bound(value)[*dim] == cstr.getExpr(value: initArg, dim);
69 } else {
70 cstr.bound(value) == cstr.getExpr(value: initArg);
71 }
72 }
73 }
74
75 void populateBoundsForIndexValue(Operation *op, Value value,
76 ValueBoundsConstraintSet &cstr) const {
77 auto forOp = cast<ForOp>(op);
78
79 if (value == forOp.getInductionVar()) {
80 // TODO: Take into account step size.
81 cstr.bound(value) >= forOp.getLowerBound();
82 cstr.bound(value) < forOp.getUpperBound();
83 return;
84 }
85
86 // Handle iter_args and OpResults.
87 populateIterArgBounds(forOp, value, std::nullopt, cstr);
88 }
89
90 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
91 ValueBoundsConstraintSet &cstr) const {
92 auto forOp = cast<ForOp>(op);
93 // Handle iter_args and OpResults.
94 populateIterArgBounds(forOp, value, dim, cstr);
95 }
96};
97
98struct IfOpInterface
99 : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
100
101 static void populateBounds(scf::IfOp ifOp, Value value,
102 std::optional<int64_t> dim,
103 ValueBoundsConstraintSet &cstr) {
104 unsigned int resultNum = cast<OpResult>(value).getResultNumber();
105 Value thenValue = ifOp.thenYield().getResults()[resultNum];
106 Value elseValue = ifOp.elseYield().getResults()[resultNum];
107
108 auto boundsBuilder = cstr.bound(value);
109 if (dim)
110 boundsBuilder[*dim];
111
112 // Compare yielded values.
113 // If thenValue <= elseValue:
114 // * result <= elseValue
115 // * result >= thenValue
116 if (cstr.populateAndCompare(
117 /*lhs=*/{thenValue, dim},
118 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
119 /*rhs=*/{elseValue, dim})) {
120 if (dim) {
121 cstr.bound(value)[*dim] >= cstr.getExpr(value: thenValue, dim);
122 cstr.bound(value)[*dim] <= cstr.getExpr(value: elseValue, dim);
123 } else {
124 cstr.bound(value) >= thenValue;
125 cstr.bound(value) <= elseValue;
126 }
127 }
128 // If elseValue <= thenValue:
129 // * result <= thenValue
130 // * result >= elseValue
131 if (cstr.populateAndCompare(
132 /*lhs=*/{elseValue, dim},
133 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
134 /*rhs=*/{thenValue, dim})) {
135 if (dim) {
136 cstr.bound(value)[*dim] >= cstr.getExpr(value: elseValue, dim);
137 cstr.bound(value)[*dim] <= cstr.getExpr(value: thenValue, dim);
138 } else {
139 cstr.bound(value) >= elseValue;
140 cstr.bound(value) <= thenValue;
141 }
142 }
143 }
144
145 void populateBoundsForIndexValue(Operation *op, Value value,
146 ValueBoundsConstraintSet &cstr) const {
147 populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
148 }
149
150 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
151 ValueBoundsConstraintSet &cstr) const {
152 populateBounds(cast<IfOp>(op), value, dim, cstr);
153 }
154};
155
156} // namespace
157} // namespace scf
158} // namespace mlir
159
160void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
161 DialectRegistry &registry) {
162 registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) {
163 scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
164 scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
165 });
166}
167

source code of mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp