//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

namespace mlir {
namespace scf {
namespace {

struct ForOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {

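  /// Return an affine expression for the trip count of the loop, i.e.,
  /// ceildiv(ub - lb, step), built from the constraint set's expressions for
  /// the loop bounds and step.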
  static AffineExpr getTripCountExpr(scf::ForOp forOp,
                                     ValueBoundsConstraintSet &cstr) {
    AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound());
    AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound());
    AffineExpr stepExpr = cstr.getExpr(forOp.getStep());
    AffineExpr tripCountExpr =
        AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // ceildiv(ub - lb, step)
    return tripCountExpr;
  }

  /// Populate bounds of values/dimensions for iter_args/OpResults. If the
  /// value/dimension size does not change in an iteration, we can deduce that
  /// it is the same as the initial value/dimension.
  ///
  /// Example 1:
  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
  ///   ...
  ///   %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
  ///   scf.yield %1 : tensor<?xf32>
  /// }
  /// --> bound(%0)[0] == bound(%t)[0]
  /// --> bound(%arg0)[0] == bound(%t)[0]
  ///
  /// Example 2:
  /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
  ///   %sz = tensor.dim %arg0 : tensor<?xf32>
  ///   %incr = arith.addi %sz, %c1 : index
  ///   %1 = tensor.empty(%incr) : tensor<?xf32>
  ///   scf.yield %1 : tensor<?xf32>
  /// }
  /// --> The yielded tensor dimension size changes with each iteration. Such
  ///     loops are not supported and no constraints are added.
  static void populateIterArgBounds(scf::ForOp forOp, Value value,
                                    std::optional<int64_t> dim,
                                    ValueBoundsConstraintSet &cstr) {
    // `value` is an iter_arg or an OpResult.
    int64_t iterArgIdx;
    if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
      iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
    } else {
      iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
    }

    Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
                             .getOperand(iterArgIdx);
    Value iterArg = forOp.getRegionIterArg(iterArgIdx);
    Value initArg = forOp.getInitArgs()[iterArgIdx];

    // An EQ constraint can be added if the yielded value (dimension size)
    // equals the corresponding block argument (dimension size).
    if (cstr.populateAndCompare(
            /*lhs=*/{yieldedValue, dim},
            ValueBoundsConstraintSet::ComparisonOperator::EQ,
            /*rhs=*/{iterArg, dim})) {
      if (dim.has_value()) {
        cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
      } else {
        cstr.bound(value) == cstr.getExpr(initArg);
      }
    }

    if (dim.has_value() || isa<BlockArgument>(value))
      return;

    // `value` is a result of `forOp`; we can prove that:
    //   %result == %init_arg + trip_count * (%yielded_value - %iter_arg),
    // where trip_count is ceildiv(ub - lb, step).
    AffineExpr tripCountExpr = getTripCountExpr(forOp, cstr);
    AffineExpr oneIterAdvanceExpr =
        cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg);
    cstr.bound(value) ==
        cstr.getExpr(initArg) + AffineExpr(tripCountExpr * oneIterAdvanceExpr);
  }

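  /// Bound the induction variable by the loop bounds (lb <= iv < ub, and
  /// iv <= the last value it actually takes); index-typed iter_args and
  /// OpResults are handled by `populateIterArgBounds`.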
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);

    if (value == forOp.getInductionVar()) {
      cstr.bound(value) >= forOp.getLowerBound();
      cstr.bound(value) < forOp.getUpperBound();
      // iv <= lb + ((ub-lb)/step - 1) * step
      // This bound does not replace the `iv < ub` constraint mentioned above,
      // since constraints involving the multiplication of two constraint set
      // dimensions are not supported.
      AffineExpr tripCountMinusOne =
          getTripCountExpr(forOp, cstr) - cstr.getExpr(1);
      AffineExpr computedUpperBound =
          cstr.getExpr(forOp.getLowerBound()) +
          AffineExpr(tripCountMinusOne * cstr.getExpr(forOp.getStep()));
      cstr.bound(value) <= computedUpperBound;
      return;
    }

    // Handle iter_args and OpResults.
    populateIterArgBounds(forOp, value, std::nullopt, cstr);
  }

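  /// Bound dimension `dim` of a shaped iter_arg or OpResult via
  /// `populateIterArgBounds`.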
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto forOp = cast<ForOp>(op);
    // Handle iter_args and OpResults.
    populateIterArgBounds(forOp, value, dim, cstr);
  }
};

struct ForallOpInterface
    : public ValueBoundsOpInterface::ExternalModel<ForallOpInterface,
                                                   ForallOp> {

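  /// Bound a forall induction variable by the corresponding loop bounds:
  /// lb <= iv < ub. (The step is not taken into account yet; see the TODO
  /// below.)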
  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    auto forallOp = cast<ForallOp>(op);

    // Index values should be induction variables, since the semantics of
    // tensor::ParallelInsertSliceOp requires forall outputs to be ranked
    // tensors.
    auto blockArg = cast<BlockArgument>(value);
    assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() &&
           "expected index value to be an induction var");
    int64_t idx = blockArg.getArgNumber();
    // TODO: Take into account step size.
    AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]);
    AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]);
    cstr.bound(value) >= lb;
    cstr.bound(value) < ub;
  }

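  /// Bound dimension `dim` of a shared output block argument or OpResult:
  /// it has the same size as the corresponding output operand.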
  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    auto forallOp = cast<ForallOp>(op);

    // `value` is an iter_arg or an OpResult.
    int64_t iterArgIdx;
    if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
      iterArgIdx = iterArg.getArgNumber() - forallOp.getInductionVars().size();
    } else {
      iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
    }

    // The forall results and output block arguments have the same sizes as
    // the corresponding output operands.
    Value outputOperand = forallOp.getOutputs()[iterArgIdx];
    cstr.bound(value)[dim] == cstr.getExpr(outputOperand, dim);
  }
};

struct IfOpInterface
    : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {

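  /// Bound an scf.if result (or result dimension) by the values yielded in
  /// the two branches: if the yielded values can be ordered relative to each
  /// other, the result is bounded below by the smaller and above by the
  /// larger of the two.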
  static void populateBounds(scf::IfOp ifOp, Value value,
                             std::optional<int64_t> dim,
                             ValueBoundsConstraintSet &cstr) {
    unsigned int resultNum = cast<OpResult>(value).getResultNumber();
    Value thenValue = ifOp.thenYield().getResults()[resultNum];
    Value elseValue = ifOp.elseYield().getResults()[resultNum];

    auto boundsBuilder = cstr.bound(value);
    if (dim)
      boundsBuilder[*dim];

    // Compare yielded values.
    // If thenValue <= elseValue:
    // * result <= elseValue
    // * result >= thenValue
    if (cstr.populateAndCompare(
            /*lhs=*/{thenValue, dim},
            ValueBoundsConstraintSet::ComparisonOperator::LE,
            /*rhs=*/{elseValue, dim})) {
      if (dim) {
        cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
        cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
      } else {
        cstr.bound(value) >= thenValue;
        cstr.bound(value) <= elseValue;
      }
    }
    // If elseValue <= thenValue:
    // * result <= thenValue
    // * result >= elseValue
    if (cstr.populateAndCompare(
            /*lhs=*/{elseValue, dim},
            ValueBoundsConstraintSet::ComparisonOperator::LE,
            /*rhs=*/{thenValue, dim})) {
      if (dim) {
        cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
        cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
      } else {
        cstr.bound(value) >= elseValue;
        cstr.bound(value) <= thenValue;
      }
    }
  }

  void populateBoundsForIndexValue(Operation *op, Value value,
                                   ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
  }

  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
                                       ValueBoundsConstraintSet &cstr) const {
    populateBounds(cast<IfOp>(op), value, dim, cstr);
  }
};

} // namespace
} // namespace scf
} // namespace mlir

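// Attach the external models defined above to the corresponding SCF ops.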
void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
    scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
    scf::ForallOp::attachInterface<scf::ForallOpInterface>(*ctx);
    scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
  });
}