1//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace arith {
18namespace {
19
20struct AddIOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
22 void populateBoundsForIndexValue(Operation *op, Value value,
23 ValueBoundsConstraintSet &cstr) const {
24 auto addIOp = cast<AddIOp>(op);
25 assert(value == addIOp.getResult() && "invalid value");
26
27 // Note: `getExpr` has a side effect: it may add a new column to the
28 // constraint system. The evaluation order of addition operands is
29 // unspecified in C++. To make sure that all compilers produce the exact
30 // same results (that can be FileCheck'd), it is important that `getExpr`
31 // is called first and assigned to temporary variables, and the addition
32 // is performed afterwards.
33 AffineExpr lhs = cstr.getExpr(addIOp.getLhs());
34 AffineExpr rhs = cstr.getExpr(addIOp.getRhs());
35 cstr.bound(value) == lhs + rhs;
36 }
37};
38
39struct ConstantOpInterface
40 : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
41 ConstantOp> {
42 void populateBoundsForIndexValue(Operation *op, Value value,
43 ValueBoundsConstraintSet &cstr) const {
44 auto constantOp = cast<ConstantOp>(op);
45 assert(value == constantOp.getResult() && "invalid value");
46
47 if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue()))
48 cstr.bound(value) == attr.getInt();
49 }
50};
51
52struct SubIOpInterface
53 : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
54 void populateBoundsForIndexValue(Operation *op, Value value,
55 ValueBoundsConstraintSet &cstr) const {
56 auto subIOp = cast<SubIOp>(op);
57 assert(value == subIOp.getResult() && "invalid value");
58
59 AffineExpr lhs = cstr.getExpr(subIOp.getLhs());
60 AffineExpr rhs = cstr.getExpr(subIOp.getRhs());
61 cstr.bound(value) == lhs - rhs;
62 }
63};
64
65struct MulIOpInterface
66 : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
67 void populateBoundsForIndexValue(Operation *op, Value value,
68 ValueBoundsConstraintSet &cstr) const {
69 auto mulIOp = cast<MulIOp>(op);
70 assert(value == mulIOp.getResult() && "invalid value");
71
72 AffineExpr lhs = cstr.getExpr(mulIOp.getLhs());
73 AffineExpr rhs = cstr.getExpr(mulIOp.getRhs());
74 cstr.bound(value) == lhs *rhs;
75 }
76};
77
78struct SelectOpInterface
79 : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
80 SelectOp> {
81
82 static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
83 ValueBoundsConstraintSet &cstr) {
84 Value value = selectOp.getResult();
85 Value condition = selectOp.getCondition();
86 Value trueValue = selectOp.getTrueValue();
87 Value falseValue = selectOp.getFalseValue();
88
89 if (isa<ShapedType>(condition.getType())) {
90 // If the condition is a shaped type, the condition is applied
91 // element-wise. All three operands must have the same shape.
92 cstr.bound(value)[*dim] == cstr.getExpr(value: trueValue, dim);
93 cstr.bound(value)[*dim] == cstr.getExpr(value: falseValue, dim);
94 cstr.bound(value)[*dim] == cstr.getExpr(value: condition, dim);
95 return;
96 }
97
98 // Populate constraints for the true/false values (and all values on the
99 // backward slice, as long as the current stop condition is not satisfied).
100 cstr.populateConstraints(value: trueValue, dim);
101 cstr.populateConstraints(value: falseValue, dim);
102 auto boundsBuilder = cstr.bound(value);
103 if (dim)
104 boundsBuilder[*dim];
105
106 // Compare yielded values.
107 // If trueValue <= falseValue:
108 // * result <= falseValue
109 // * result >= trueValue
110 if (cstr.compare(/*lhs=*/{trueValue, dim},
111 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
112 /*rhs=*/{falseValue, dim})) {
113 if (dim) {
114 cstr.bound(value)[*dim] >= cstr.getExpr(value: trueValue, dim);
115 cstr.bound(value)[*dim] <= cstr.getExpr(value: falseValue, dim);
116 } else {
117 cstr.bound(value) >= trueValue;
118 cstr.bound(value) <= falseValue;
119 }
120 }
121 // If falseValue <= trueValue:
122 // * result <= trueValue
123 // * result >= falseValue
124 if (cstr.compare(/*lhs=*/{falseValue, dim},
125 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
126 /*rhs=*/{trueValue, dim})) {
127 if (dim) {
128 cstr.bound(value)[*dim] >= cstr.getExpr(value: falseValue, dim);
129 cstr.bound(value)[*dim] <= cstr.getExpr(value: trueValue, dim);
130 } else {
131 cstr.bound(value) >= falseValue;
132 cstr.bound(value) <= trueValue;
133 }
134 }
135 }
136
137 void populateBoundsForIndexValue(Operation *op, Value value,
138 ValueBoundsConstraintSet &cstr) const {
139 populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
140 }
141
142 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
143 ValueBoundsConstraintSet &cstr) const {
144 populateBounds(cast<SelectOp>(op), dim, cstr);
145 }
146};
147} // namespace
148} // namespace arith
149} // namespace mlir
150
151void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
152 DialectRegistry &registry) {
153 registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) {
154 arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
155 arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
156 arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
157 arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
158 arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
159 });
160}
161

source code of mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp