1 | //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/SCF/IR/SCF.h" |
12 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace mlir { |
17 | namespace scf { |
18 | namespace { |
19 | |
20 | struct ForOpInterface |
21 | : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> { |
22 | |
23 | /// Populate bounds of values/dimensions for iter_args/OpResults. If the |
24 | /// value/dimension size does not change in an iteration, we can deduce that |
25 | /// it the same as the initial value/dimension. |
26 | /// |
27 | /// Example 1: |
28 | /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { |
29 | /// ... |
30 | /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32> |
31 | /// scf.yield %1 : tensor<?xf32> |
32 | /// } |
33 | /// --> bound(%0)[0] == bound(%t)[0] |
34 | /// --> bound(%arg0)[0] == bound(%t)[0] |
35 | /// |
36 | /// Example 2: |
37 | /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { |
38 | /// %sz = tensor.dim %arg0 : tensor<?xf32> |
39 | /// %incr = arith.addi %sz, %c1 : index |
40 | /// %1 = tensor.empty(%incr) : tensor<?xf32> |
41 | /// scf.yield %1 : tensor<?xf32> |
42 | /// } |
43 | /// --> The yielded tensor dimension size changes with each iteration. Such |
44 | /// loops are not supported and no constraints are added. |
45 | static void populateIterArgBounds(scf::ForOp forOp, Value value, |
46 | std::optional<int64_t> dim, |
47 | ValueBoundsConstraintSet &cstr) { |
48 | // `value` is an iter_arg or an OpResult. |
49 | int64_t iterArgIdx; |
50 | if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) { |
51 | iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars(); |
52 | } else { |
53 | iterArgIdx = llvm::cast<OpResult>(value).getResultNumber(); |
54 | } |
55 | |
56 | Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator()) |
57 | .getOperand(iterArgIdx); |
58 | Value iterArg = forOp.getRegionIterArg(iterArgIdx); |
59 | Value initArg = forOp.getInitArgs()[iterArgIdx]; |
60 | |
61 | // An EQ constraint can be added if the yielded value (dimension size) |
62 | // equals the corresponding block argument (dimension size). |
63 | if (cstr.populateAndCompare( |
64 | /*lhs=*/{yieldedValue, dim}, |
65 | cmp: ValueBoundsConstraintSet::ComparisonOperator::EQ, |
66 | /*rhs=*/{iterArg, dim})) { |
67 | if (dim.has_value()) { |
68 | cstr.bound(value)[*dim] == cstr.getExpr(value: initArg, dim); |
69 | } else { |
70 | cstr.bound(value) == cstr.getExpr(value: initArg); |
71 | } |
72 | } |
73 | } |
74 | |
75 | void populateBoundsForIndexValue(Operation *op, Value value, |
76 | ValueBoundsConstraintSet &cstr) const { |
77 | auto forOp = cast<ForOp>(op); |
78 | |
79 | if (value == forOp.getInductionVar()) { |
80 | // TODO: Take into account step size. |
81 | cstr.bound(value) >= forOp.getLowerBound(); |
82 | cstr.bound(value) < forOp.getUpperBound(); |
83 | return; |
84 | } |
85 | |
86 | // Handle iter_args and OpResults. |
87 | populateIterArgBounds(forOp, value, std::nullopt, cstr); |
88 | } |
89 | |
90 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
91 | ValueBoundsConstraintSet &cstr) const { |
92 | auto forOp = cast<ForOp>(op); |
93 | // Handle iter_args and OpResults. |
94 | populateIterArgBounds(forOp, value, dim, cstr); |
95 | } |
96 | }; |
97 | |
98 | struct IfOpInterface |
99 | : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> { |
100 | |
101 | static void populateBounds(scf::IfOp ifOp, Value value, |
102 | std::optional<int64_t> dim, |
103 | ValueBoundsConstraintSet &cstr) { |
104 | unsigned int resultNum = cast<OpResult>(value).getResultNumber(); |
105 | Value thenValue = ifOp.thenYield().getResults()[resultNum]; |
106 | Value elseValue = ifOp.elseYield().getResults()[resultNum]; |
107 | |
108 | auto boundsBuilder = cstr.bound(value); |
109 | if (dim) |
110 | boundsBuilder[*dim]; |
111 | |
112 | // Compare yielded values. |
113 | // If thenValue <= elseValue: |
114 | // * result <= elseValue |
115 | // * result >= thenValue |
116 | if (cstr.populateAndCompare( |
117 | /*lhs=*/{thenValue, dim}, |
118 | cmp: ValueBoundsConstraintSet::ComparisonOperator::LE, |
119 | /*rhs=*/{elseValue, dim})) { |
120 | if (dim) { |
121 | cstr.bound(value)[*dim] >= cstr.getExpr(value: thenValue, dim); |
122 | cstr.bound(value)[*dim] <= cstr.getExpr(value: elseValue, dim); |
123 | } else { |
124 | cstr.bound(value) >= thenValue; |
125 | cstr.bound(value) <= elseValue; |
126 | } |
127 | } |
128 | // If elseValue <= thenValue: |
129 | // * result <= thenValue |
130 | // * result >= elseValue |
131 | if (cstr.populateAndCompare( |
132 | /*lhs=*/{elseValue, dim}, |
133 | cmp: ValueBoundsConstraintSet::ComparisonOperator::LE, |
134 | /*rhs=*/{thenValue, dim})) { |
135 | if (dim) { |
136 | cstr.bound(value)[*dim] >= cstr.getExpr(value: elseValue, dim); |
137 | cstr.bound(value)[*dim] <= cstr.getExpr(value: thenValue, dim); |
138 | } else { |
139 | cstr.bound(value) >= elseValue; |
140 | cstr.bound(value) <= thenValue; |
141 | } |
142 | } |
143 | } |
144 | |
145 | void populateBoundsForIndexValue(Operation *op, Value value, |
146 | ValueBoundsConstraintSet &cstr) const { |
147 | populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr); |
148 | } |
149 | |
150 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
151 | ValueBoundsConstraintSet &cstr) const { |
152 | populateBounds(cast<IfOp>(op), value, dim, cstr); |
153 | } |
154 | }; |
155 | |
156 | } // namespace |
157 | } // namespace scf |
158 | } // namespace mlir |
159 | |
160 | void mlir::scf::registerValueBoundsOpInterfaceExternalModels( |
161 | DialectRegistry ®istry) { |
162 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, scf::SCFDialect *dialect) { |
163 | scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx); |
164 | scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx); |
165 | }); |
166 | } |
167 | |