1 | //===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" |
10 | |
11 | #include "mlir/Dialect/Arith/IR/Arith.h" |
12 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace mlir { |
17 | namespace arith { |
18 | namespace { |
19 | |
20 | struct AddIOpInterface |
21 | : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> { |
22 | void populateBoundsForIndexValue(Operation *op, Value value, |
23 | ValueBoundsConstraintSet &cstr) const { |
24 | auto addIOp = cast<AddIOp>(op); |
25 | assert(value == addIOp.getResult() && "invalid value" ); |
26 | |
27 | // Note: `getExpr` has a side effect: it may add a new column to the |
28 | // constraint system. The evaluation order of addition operands is |
29 | // unspecified in C++. To make sure that all compilers produce the exact |
30 | // same results (that can be FileCheck'd), it is important that `getExpr` |
31 | // is called first and assigned to temporary variables, and the addition |
32 | // is performed afterwards. |
33 | AffineExpr lhs = cstr.getExpr(addIOp.getLhs()); |
34 | AffineExpr rhs = cstr.getExpr(addIOp.getRhs()); |
35 | cstr.bound(value) == lhs + rhs; |
36 | } |
37 | }; |
38 | |
39 | struct ConstantOpInterface |
40 | : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface, |
41 | ConstantOp> { |
42 | void populateBoundsForIndexValue(Operation *op, Value value, |
43 | ValueBoundsConstraintSet &cstr) const { |
44 | auto constantOp = cast<ConstantOp>(op); |
45 | assert(value == constantOp.getResult() && "invalid value" ); |
46 | |
47 | if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue())) |
48 | cstr.bound(value) == attr.getInt(); |
49 | } |
50 | }; |
51 | |
52 | struct SubIOpInterface |
53 | : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> { |
54 | void populateBoundsForIndexValue(Operation *op, Value value, |
55 | ValueBoundsConstraintSet &cstr) const { |
56 | auto subIOp = cast<SubIOp>(op); |
57 | assert(value == subIOp.getResult() && "invalid value" ); |
58 | |
59 | AffineExpr lhs = cstr.getExpr(subIOp.getLhs()); |
60 | AffineExpr rhs = cstr.getExpr(subIOp.getRhs()); |
61 | cstr.bound(value) == lhs - rhs; |
62 | } |
63 | }; |
64 | |
65 | struct MulIOpInterface |
66 | : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> { |
67 | void populateBoundsForIndexValue(Operation *op, Value value, |
68 | ValueBoundsConstraintSet &cstr) const { |
69 | auto mulIOp = cast<MulIOp>(op); |
70 | assert(value == mulIOp.getResult() && "invalid value" ); |
71 | |
72 | AffineExpr lhs = cstr.getExpr(mulIOp.getLhs()); |
73 | AffineExpr rhs = cstr.getExpr(mulIOp.getRhs()); |
74 | cstr.bound(value) == lhs *rhs; |
75 | } |
76 | }; |
77 | |
78 | struct SelectOpInterface |
79 | : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface, |
80 | SelectOp> { |
81 | |
82 | static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim, |
83 | ValueBoundsConstraintSet &cstr) { |
84 | Value value = selectOp.getResult(); |
85 | Value condition = selectOp.getCondition(); |
86 | Value trueValue = selectOp.getTrueValue(); |
87 | Value falseValue = selectOp.getFalseValue(); |
88 | |
89 | if (isa<ShapedType>(condition.getType())) { |
90 | // If the condition is a shaped type, the condition is applied |
91 | // element-wise. All three operands must have the same shape. |
92 | cstr.bound(value)[*dim] == cstr.getExpr(value: trueValue, dim); |
93 | cstr.bound(value)[*dim] == cstr.getExpr(value: falseValue, dim); |
94 | cstr.bound(value)[*dim] == cstr.getExpr(value: condition, dim); |
95 | return; |
96 | } |
97 | |
98 | // Populate constraints for the true/false values (and all values on the |
99 | // backward slice, as long as the current stop condition is not satisfied). |
100 | cstr.populateConstraints(value: trueValue, dim); |
101 | cstr.populateConstraints(value: falseValue, dim); |
102 | auto boundsBuilder = cstr.bound(value); |
103 | if (dim) |
104 | boundsBuilder[*dim]; |
105 | |
106 | // Compare yielded values. |
107 | // If trueValue <= falseValue: |
108 | // * result <= falseValue |
109 | // * result >= trueValue |
110 | if (cstr.compare(/*lhs=*/{trueValue, dim}, |
111 | cmp: ValueBoundsConstraintSet::ComparisonOperator::LE, |
112 | /*rhs=*/{falseValue, dim})) { |
113 | if (dim) { |
114 | cstr.bound(value)[*dim] >= cstr.getExpr(value: trueValue, dim); |
115 | cstr.bound(value)[*dim] <= cstr.getExpr(value: falseValue, dim); |
116 | } else { |
117 | cstr.bound(value) >= trueValue; |
118 | cstr.bound(value) <= falseValue; |
119 | } |
120 | } |
121 | // If falseValue <= trueValue: |
122 | // * result <= trueValue |
123 | // * result >= falseValue |
124 | if (cstr.compare(/*lhs=*/{falseValue, dim}, |
125 | cmp: ValueBoundsConstraintSet::ComparisonOperator::LE, |
126 | /*rhs=*/{trueValue, dim})) { |
127 | if (dim) { |
128 | cstr.bound(value)[*dim] >= cstr.getExpr(value: falseValue, dim); |
129 | cstr.bound(value)[*dim] <= cstr.getExpr(value: trueValue, dim); |
130 | } else { |
131 | cstr.bound(value) >= falseValue; |
132 | cstr.bound(value) <= trueValue; |
133 | } |
134 | } |
135 | } |
136 | |
137 | void populateBoundsForIndexValue(Operation *op, Value value, |
138 | ValueBoundsConstraintSet &cstr) const { |
139 | populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr); |
140 | } |
141 | |
142 | void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, |
143 | ValueBoundsConstraintSet &cstr) const { |
144 | populateBounds(cast<SelectOp>(op), dim, cstr); |
145 | } |
146 | }; |
147 | } // namespace |
148 | } // namespace arith |
149 | } // namespace mlir |
150 | |
151 | void mlir::arith::registerValueBoundsOpInterfaceExternalModels( |
152 | DialectRegistry ®istry) { |
153 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) { |
154 | arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx); |
155 | arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx); |
156 | arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx); |
157 | arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx); |
158 | arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx); |
159 | }); |
160 | } |
161 | |