//===- ReifyValueBounds.cpp --- Reify value bounds with arith ops -------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#include <functional>
#include <optional>

using namespace mlir;
using namespace mlir::arith;

19/// Build Arith IR for the given affine map and its operands.
20static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
21 ValueRange operands) {
22 assert(map.getNumResults() == 1 && "multiple results not supported yet");
23 std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value {
24 switch (e.getKind()) {
25 case AffineExprKind::Constant:
26 return b.create<ConstantIndexOp>(loc,
27 cast<AffineConstantExpr>(e).getValue());
28 case AffineExprKind::DimId:
29 return operands[cast<AffineDimExpr>(Val&: e).getPosition()];
30 case AffineExprKind::SymbolId:
31 return operands[cast<AffineSymbolExpr>(Val&: e).getPosition() +
32 map.getNumDims()];
33 case AffineExprKind::Add: {
34 auto binaryExpr = cast<AffineBinaryOpExpr>(Val&: e);
35 return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()),
36 buildExpr(binaryExpr.getRHS()));
37 }
38 case AffineExprKind::Mul: {
39 auto binaryExpr = cast<AffineBinaryOpExpr>(Val&: e);
40 return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()),
41 buildExpr(binaryExpr.getRHS()));
42 }
43 case AffineExprKind::FloorDiv: {
44 auto binaryExpr = cast<AffineBinaryOpExpr>(Val&: e);
45 return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
46 buildExpr(binaryExpr.getRHS()));
47 }
48 case AffineExprKind::CeilDiv: {
49 auto binaryExpr = cast<AffineBinaryOpExpr>(Val&: e);
50 return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()),
51 buildExpr(binaryExpr.getRHS()));
52 }
53 case AffineExprKind::Mod: {
54 auto binaryExpr = cast<AffineBinaryOpExpr>(Val&: e);
55 return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()),
56 buildExpr(binaryExpr.getRHS()));
57 }
58 }
59 llvm_unreachable("unsupported AffineExpr kind");
60 };
61 return buildExpr(map.getResult(idx: 0));
62}
63
64FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
65 OpBuilder &b, Location loc, presburger::BoundType type,
66 const ValueBoundsConstraintSet::Variable &var,
67 ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
68 // Compute bound.
69 AffineMap boundMap;
70 ValueDimList mapOperands;
71 if (failed(result: ValueBoundsConstraintSet::computeBound(
72 resultMap&: boundMap, mapOperands, type, var, stopCondition, closedUB)))
73 return failure();
74
75 // Materialize tensor.dim/memref.dim ops.
76 SmallVector<Value> operands;
77 for (auto valueDim : mapOperands) {
78 Value value = valueDim.first;
79 std::optional<int64_t> dim = valueDim.second;
80
81 if (!dim.has_value()) {
82 // This is an index-typed value.
83 assert(value.getType().isIndex() && "expected index type");
84 operands.push_back(Elt: value);
85 continue;
86 }
87
88 assert(cast<ShapedType>(value.getType()).isDynamicDim(*dim) &&
89 "expected dynamic dim");
90 if (isa<RankedTensorType>(Val: value.getType())) {
91 // A tensor dimension is used: generate a tensor.dim.
92 operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
93 } else if (isa<MemRefType>(Val: value.getType())) {
94 // A memref dimension is used: generate a memref.dim.
95 operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
96 } else {
97 llvm_unreachable("cannot generate DimOp for unsupported shaped type");
98 }
99 }
100
101 // Check for special cases where no arith ops are needed.
102 if (boundMap.isSingleConstant()) {
103 // Bound is a constant: return an IntegerAttr.
104 return static_cast<OpFoldResult>(
105 b.getIndexAttr(boundMap.getSingleConstantResult()));
106 }
107 // No arith ops are needed if the bound is a single SSA value.
108 if (auto expr = dyn_cast<AffineDimExpr>(Val: boundMap.getResult(idx: 0)))
109 return static_cast<OpFoldResult>(operands[expr.getPosition()]);
110 if (auto expr = dyn_cast<AffineSymbolExpr>(Val: boundMap.getResult(idx: 0)))
111 return static_cast<OpFoldResult>(
112 operands[expr.getPosition() + boundMap.getNumDims()]);
113 // General case: build Arith ops.
114 return static_cast<OpFoldResult>(buildArithValue(b, loc, map: boundMap, operands));
115}
116
117FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
118 OpBuilder &b, Location loc, presburger::BoundType type, Value value,
119 int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
120 bool closedUB) {
121 auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
122 ValueBoundsConstraintSet &cstr) {
123 // We are trying to reify a bound for `value` in terms of the owning op's
124 // operands. Construct a stop condition that evaluates to "true" for any SSA
125 // value expect for `value`. I.e., the bound will be computed in terms of
126 // any SSA values expect for `value`. The first such values are operands of
127 // the owner of `value`.
128 return v != value;
129 };
130 return reifyValueBound(b, loc, type, var: {value, dim},
131 stopCondition: stopCondition ? stopCondition : reifyToOperands,
132 closedUB);
133}
134
135FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
136 OpBuilder &b, Location loc, presburger::BoundType type, Value value,
137 ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
138 auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
139 ValueBoundsConstraintSet &cstr) {
140 return v != value;
141 };
142 return reifyValueBound(b, loc, type, var: value,
143 stopCondition: stopCondition ? stopCondition : reifyToOperands,
144 closedUB);
145}
146

// Source: mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp