//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

namespace mlir {
namespace scf {
namespace {

struct ForOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {

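  /// Return an affine expression that computes the trip count of this loop,
  /// i.e., ceildiv(ub - lb, step), built from the loop's bound and step
  /// operands.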
  static AffineExpr getTripCountExpr(scf::ForOp forOp,
                                     ValueBoundsConstraintSet &cstr) {
    AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
    AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
    AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
    AffineExpr tripCountExpr =
        AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // ceildiv(ub - lb, step)
    return tripCountExpr;
  }

  /// Populate bounds of values/dimensions for iter_args/OpResults. If the
  /// value/dimension size does not change in an iteration, we can deduce that
  /// it is the same as the initial value/dimension.
  ///
  /// Example 1:
  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
  ///   ...
  ///   %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
  ///   scf.yield %1 : tensor<?xf32>
  /// }
  /// --> bound(%0)[0] == bound(%t)[0]
  /// --> bound(%arg0)[0] == bound(%t)[0]
  ///
  /// Example 2:
  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
  ///   %sz = tensor.dim %arg0, %c0 : tensor<?xf32>
  ///   %incr = arith.addi %sz, %c1 : index
  ///   %1 = tensor.empty(%incr) : tensor<?xf32>
  ///   scf.yield %1 : tensor<?xf32>
  /// }
  /// --> The yielded tensor dimension size changes with each iteration. Such
  ///     loops are not supported and no constraints are added.
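  ///
  /// Example 3 (hypothetical IR, sketching the trip-count rule used below for
  /// index-typed results): each iteration advances the iter_arg by a fixed
  /// amount, so the result equals the init value plus trip_count times that
  /// amount.
  ///   %0 = scf.for %iv = %lb to %ub step %s
  ///       iter_args(%arg0 = %init) -> (index) {
  ///     %1 = arith.addi %arg0, %c2 : index
  ///     scf.yield %1 : index
  ///   }
  ///   --> bound(%0) == %init + trip_count * 2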
  static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                    std::optional<int64_t> dim,
                                    ValueBoundsConstraintSet &cstr) {
    // `value` is an iter_arg or an OpResult.
    int64_t iterArgIdx;
    if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
      iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
    } else {
      iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
    }

    Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
                             .getOperand(iterArgIdx);
    Value iterArg = forOp.getRegionIterArg(iterArgIdx);
    Value initArg = forOp.getInitArgs()[iterArgIdx];

    // An EQ constraint can be added if the yielded value (dimension size)
    // equals the corresponding block argument (dimension size).
    if (cstr.populateAndCompare(
            /*lhs=*/{yieldedValue, dim},
            ValueBoundsConstraintSet::ComparisonOperator::EQ,
            /*rhs=*/{iterArg, dim})) {
      if (dim.has_value()) {
        cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
      } else {
        cstr.bound(value) == cstr.getExpr(initArg);
      }
    }

    if (dim.has_value() || isa<BlockArgument>(value))
      return;

    // `value` is a result of `forOp`; we can prove that:
    //   %result == %init_arg + trip_count * (%yielded_value - %iter_arg),
    // where trip_count is ceildiv(ub - lb, step).
    AffineExpr tripCountExpr = getTripCountExpr(forOp, cstr);
    AffineExpr oneIterAdvanceExpr =
        cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
    cstr.bound(value) ==
        cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
  }

  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);

    if (value == forOp.getInductionVar()) {
      cstr.bound(value) >= forOp.getLowerBound();
      cstr.bound(value) < forOp.getUpperBound();
      // iv <= lb + (ceildiv(ub - lb, step) - 1) * step
      // This bound does not replace the `iv < ub` constraint mentioned above,
      // since constraints involving the multiplication of two constraint set
      // dimensions are not supported.
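      // For example (illustrative): with lb = 0, ub = 10, step = 3, the trip
      // count is ceildiv(10 - 0, 3) = 4, so the last induction variable value
      // is 0 + (4 - 1) * 3 = 9, which is indeed strictly less than ub.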
      AffineExpr tripCountMinusOne =
          getTripCountExpr(forOp, cstr) - cstr.getExpr(1);
      AffineExpr computedUpperBound =
          cstr.getExpr(forOp.getLowerBound()) +
          AffineExpr(tripCountMinusOne * cstr.getExpr(forOp.getStep()));
      cstr.bound(value) <= computedUpperBound;
      return;
    }

    // Handle iter_args and OpResults.
    populateIterArgBounds(forOp, value, std::nullopt, cstr);
  }

  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);
    // Handle iter_args and OpResults.
    populateIterArgBounds(forOp, value, dim, cstr);
  }
};

struct ForallOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
                                                   ForallOp> {

  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forallOp = cast<ForallOp>(op);

    // Index values should be induction variables, since the semantics of
    // tensor::ParallelInsertSliceOp requires forall outputs to be ranked
    // tensors.
    auto blockArg = cast<BlockArgument>(value);
    assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
           "expected index value to be an induction var");
    int64_t idx = blockArg.getArgNumber();
    // TODO: Take into account step size.
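    // For example (illustrative): for `scf.forall (%i) in (%n)`, the implicit
    // lower bound is 0 and the step is 1, so the constraints added below
    // amount to 0 <= %i < %n.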
    AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
    AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
    cstr.bound(value) >= lb;
    cstr.bound(value) < ub;
  }

  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto forallOp = cast<ForallOp>(op);

    // `value` is an iter_arg or an OpResult.
    int64_t iterArgIdx;
    if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
      iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
    } else {
      iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
    }

    // The forall results and the tied output block arguments have the same
    // sizes as the corresponding output operands.
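    // For example (illustrative):
    //   %r = scf.forall ... shared_outs(%o = %t) -> (tensor<?xf32>) { ... }
    //   --> bound(%r)[0] == bound(%t)[0] and bound(%o)[0] == bound(%t)[0]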
    Value outputOperand = forallOp.getOutputs()[iterArgIdx];
    cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
  }
};

struct IfOpInterface
    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {

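  /// Populate bounds for a result of an scf.if. If one branch's yielded value
  /// (or dimension size) can be proven to be less than or equal to the
  /// other's, the result is bounded from below by the smaller and from above
  /// by the larger of the two.
  ///
  /// Example (illustrative):
  ///   %0 = scf.if %cond -> (index) {
  ///     scf.yield %a : index
  ///   } else {
  ///     scf.yield %b : index
  ///   }
  ///   If %a <= %b can be proven: %a <= %0 <= %b.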
  static void populateBounds(scf::IfOp ifOp, Value value,
                             std::optional<int64_t> dim,
                             ValueBoundsConstraintSet &cstr) {
    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
    Value thenValue = ifOp.thenYield().getResults()[resultNum];
    Value elseValue = ifOp.elseYield().getResults()[resultNum];

    auto boundsBuilder = cstr.bound(value);
    if (dim)
      boundsBuilder[*dim];

    // Compare yielded values.
    // If thenValue <= elseValue:
    // * result <= elseValue
    // * result >= thenValue
    if (cstr.populateAndCompare(
            /*lhs=*/{thenValue, dim},
            ValueBoundsConstraintSet::ComparisonOperator::LE,
            /*rhs=*/{elseValue, dim})) {
      if (dim) {
        cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
        cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
      } else {
        cstr.bound(value) >= thenValue;
        cstr.bound(value) <= elseValue;
      }
    }
    // If elseValue <= thenValue:
    // * result <= thenValue
    // * result >= elseValue
    if (cstr.populateAndCompare(
            /*lhs=*/{elseValue, dim},
            ValueBoundsConstraintSet::ComparisonOperator::LE,
            /*rhs=*/{thenValue, dim})) {
      if (dim) {
        cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
        cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
      } else {
        cstr.bound(value) >= elseValue;
        cstr.bound(value) <= thenValue;
      }
    }
  }

  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
  }

  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<IfOp>(op), value, dim, cstr);
  }
};

} // namespace
} // namespace scf
} // namespace mlir

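// Typical usage (illustrative): register the external models when assembling
// the dialect registry, before any SCF ops are created, e.g.:
//   DialectRegistry registry;
//   registry.insert<scf::SCFDialect>();
//   scf::registerValueBoundsOpInterfaceExternalModels(registry);
//   MLIRContext ctx(registry);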
void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
    scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
    scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
  });
}