1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace scf {
18namespace {
19
20struct ForOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
22
23 static AffineExpr getTripCountExpr(scf::ForOp forOp,
24 ValueBoundsConstraintSet &cstr) {
25 AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
26 AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
27 AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
28 AffineExpr tripCountExpr =
29 AffineExpr(ubExpr - lbExpr).ceilDiv(other: stepExpr); // (ub - lb) / step
30 return tripCountExpr;
31 }
32
33 /// Populate bounds of values/dimensions for iter_args/OpResults. If the
34 /// value/dimension size does not change in an iteration, we can deduce that
35 /// it the same as the initial value/dimension.
36 ///
37 /// Example 1:
38 /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
39 /// ...
40 /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
41 /// scf.yield %1 : tensor<?xf32>
42 /// }
43 /// --> bound(%0)[0] == bound(%t)[0]
44 /// --> bound(%arg0)[0] == bound(%t)[0]
45 ///
46 /// Example 2:
47 /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
48 /// %sz = tensor.dim %arg0 : tensor<?xf32>
49 /// %incr = arith.addi %sz, %c1 : index
50 /// %1 = tensor.empty(%incr) : tensor<?xf32>
51 /// scf.yield %1 : tensor<?xf32>
52 /// }
53 /// --> The yielded tensor dimension size changes with each iteration. Such
54 /// loops are not supported and no constraints are added.
55 static void populateIterArgBounds(scf::ForOp forOp, Value value,
56 std::optional<int64_t> dim,
57 ValueBoundsConstraintSet &cstr) {
58 // `value` is an iter_arg or an OpResult.
59 int64_t iterArgIdx;
60 if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
61 iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
62 } else {
63 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
64 }
65
66 Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
67 .getOperand(iterArgIdx);
68 Value iterArg = forOp.getRegionIterArg(iterArgIdx);
69 Value initArg = forOp.getInitArgs()[iterArgIdx];
70
71 // An EQ constraint can be added if the yielded value (dimension size)
72 // equals the corresponding block argument (dimension size).
73 if (cstr.populateAndCompare(
74 /*lhs=*/{yieldedValue, dim},
75 cmp: ValueBoundsConstraintSet::ComparisonOperator::EQ,
76 /*rhs=*/{iterArg, dim})) {
77 if (dim.has_value()) {
78 cstr.bound(value)[*dim] == cstr.getExpr(value: initArg, dim);
79 } else {
80 cstr.bound(value) == cstr.getExpr(value: initArg);
81 }
82 }
83
84 if (dim.has_value() || isa<BlockArgument>(value))
85 return;
86
87 // `value` is result of `forOp`, we can prove that:
88 // %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
89 // Where trip_count is (ub - lb) / step.
90 AffineExpr tripCountExpr = getTripCountExpr(forOp, cstr);
91 AffineExpr oneIterAdvanceExpr =
92 cstr.getExpr(value: yieldedValue) - cstr.getExpr(value: iterArg);
93 cstr.bound(value) ==
94 cstr.getExpr(value: initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
95 }
96
97 void populateBoundsForIndexValue(Operation *op, Value value,
98 ValueBoundsConstraintSet &cstr) const {
99 auto forOp = cast<ForOp>(op);
100
101 if (value == forOp.getInductionVar()) {
102 cstr.bound(value) >= forOp.getLowerBound();
103 cstr.bound(value) < forOp.getUpperBound();
104 // iv <= lb + ((ub-lb)/step - 1) * step
105 // This bound does not replace the `iv < ub` constraint mentioned above,
106 // since constraints involving the multiplication of two constraint set
107 // dimensions are not supported.
108 AffineExpr tripCountMinusOne =
109 getTripCountExpr(forOp, cstr) - cstr.getExpr(constant: 1);
110 AffineExpr computedUpperBound =
111 cstr.getExpr(forOp.getLowerBound()) +
112 AffineExpr(tripCountMinusOne * cstr.getExpr(forOp.getStep()));
113 cstr.bound(value) <= computedUpperBound;
114 return;
115 }
116
117 // Handle iter_args and OpResults.
118 populateIterArgBounds(forOp, value, std::nullopt, cstr);
119 }
120
121 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
122 ValueBoundsConstraintSet &cstr) const {
123 auto forOp = cast<ForOp>(op);
124 // Handle iter_args and OpResults.
125 populateIterArgBounds(forOp, value, dim, cstr);
126 }
127};
128
129struct ForallOpInterface
130 : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
131 ForallOp> {
132
133 void populateBoundsForIndexValue(Operation *op, Value value,
134 ValueBoundsConstraintSet &cstr) const {
135 auto forallOp = cast<ForallOp>(op);
136
137 // Index values should be induction variables, since the semantics of
138 // tensor::ParallelInsertSliceOp requires forall outputs to be ranked
139 // tensors.
140 auto blockArg = cast<BlockArgument>(value);
141 assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
142 "expected index value to be an induction var");
143 int64_t idx = blockArg.getArgNumber();
144 // TODO: Take into account step size.
145 AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
146 AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
147 cstr.bound(value) >= lb;
148 cstr.bound(value) < ub;
149 }
150
151 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
152 ValueBoundsConstraintSet &cstr) const {
153 auto forallOp = cast<ForallOp>(op);
154
155 // `value` is an iter_arg or an OpResult.
156 int64_t iterArgIdx;
157 if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
158 iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
159 } else {
160 iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
161 }
162
163 // The forall results and output arguments have the same sizes as the output
164 // operands.
165 Value outputOperand = forallOp.getOutputs()[iterArgIdx];
166 cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
167 }
168};
169
170struct IfOpInterface
171 : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
172
173 static void populateBounds(scf::IfOp ifOp, Value value,
174 std::optional<int64_t> dim,
175 ValueBoundsConstraintSet &cstr) {
176 unsigned int resultNum = cast<OpResult>(value).getResultNumber();
177 Value thenValue = ifOp.thenYield().getResults()[resultNum];
178 Value elseValue = ifOp.elseYield().getResults()[resultNum];
179
180 auto boundsBuilder = cstr.bound(value);
181 if (dim)
182 boundsBuilder[*dim];
183
184 // Compare yielded values.
185 // If thenValue <= elseValue:
186 // * result <= elseValue
187 // * result >= thenValue
188 if (cstr.populateAndCompare(
189 /*lhs=*/{thenValue, dim},
190 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
191 /*rhs=*/{elseValue, dim})) {
192 if (dim) {
193 cstr.bound(value)[*dim] >= cstr.getExpr(value: thenValue, dim);
194 cstr.bound(value)[*dim] <= cstr.getExpr(value: elseValue, dim);
195 } else {
196 cstr.bound(value) >= thenValue;
197 cstr.bound(value) <= elseValue;
198 }
199 }
200 // If elseValue <= thenValue:
201 // * result <= thenValue
202 // * result >= elseValue
203 if (cstr.populateAndCompare(
204 /*lhs=*/{elseValue, dim},
205 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
206 /*rhs=*/{thenValue, dim})) {
207 if (dim) {
208 cstr.bound(value)[*dim] >= cstr.getExpr(value: elseValue, dim);
209 cstr.bound(value)[*dim] <= cstr.getExpr(value: thenValue, dim);
210 } else {
211 cstr.bound(value) >= elseValue;
212 cstr.bound(value) <= thenValue;
213 }
214 }
215 }
216
217 void populateBoundsForIndexValue(Operation *op, Value value,
218 ValueBoundsConstraintSet &cstr) const {
219 populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
220 }
221
222 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
223 ValueBoundsConstraintSet &cstr) const {
224 populateBounds(cast<IfOp>(op), value, dim, cstr);
225 }
226};
227
228} // namespace
229} // namespace scf
230} // namespace mlir
231
232void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
233 DialectRegistry &registry) {
234 registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) {
235 scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
236 scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
237 scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
238 });
239}
240

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp