//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Interfaces/ValueBoundsOpInterface.h"
13
14using namespace mlir;
15
16namespace mlir {
17namespace arith {
18namespace {
19
20struct AddIOpInterface
21 : public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
22 void populateBoundsForIndexValue(Operation *op, Value value,
23 ValueBoundsConstraintSet &cstr) const {
24 auto addIOp = cast<AddIOp>(op);
25 assert(value == addIOp.getResult() && "invalid value");
26
27 // Note: `getExpr` has a side effect: it may add a new column to the
28 // constraint system. The evaluation order of addition operands is
29 // unspecified in C++. To make sure that all compilers produce the exact
30 // same results (that can be FileCheck'd), it is important that `getExpr`
31 // is called first and assigned to temporary variables, and the addition
32 // is performed afterwards.
33 AffineExpr lhs = cstr.getExpr(addIOp.getLhs());
34 AffineExpr rhs = cstr.getExpr(addIOp.getRhs());
35 cstr.bound(value) == lhs + rhs;
36 }
37};
38
39struct ConstantOpInterface
40 : public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
41 ConstantOp> {
42 void populateBoundsForIndexValue(Operation *op, Value value,
43 ValueBoundsConstraintSet &cstr) const {
44 auto constantOp = cast<ConstantOp>(op);
45 assert(value == constantOp.getResult() && "invalid value");
46
47 if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue()))
48 cstr.bound(value) == attr.getInt();
49 }
50};
51
52struct SubIOpInterface
53 : public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
54 void populateBoundsForIndexValue(Operation *op, Value value,
55 ValueBoundsConstraintSet &cstr) const {
56 auto subIOp = cast<SubIOp>(op);
57 assert(value == subIOp.getResult() && "invalid value");
58
59 AffineExpr lhs = cstr.getExpr(subIOp.getLhs());
60 AffineExpr rhs = cstr.getExpr(subIOp.getRhs());
61 cstr.bound(value) == lhs - rhs;
62 }
63};
64
65struct MulIOpInterface
66 : public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
67 void populateBoundsForIndexValue(Operation *op, Value value,
68 ValueBoundsConstraintSet &cstr) const {
69 auto mulIOp = cast<MulIOp>(op);
70 assert(value == mulIOp.getResult() && "invalid value");
71
72 AffineExpr lhs = cstr.getExpr(mulIOp.getLhs());
73 AffineExpr rhs = cstr.getExpr(mulIOp.getRhs());
74 cstr.bound(value) == lhs *rhs;
75 }
76};
77
78struct FloorDivSIOpInterface
79 : public ValueBoundsOpInterface::ExternalModel<FloorDivSIOpInterface,
80 FloorDivSIOp> {
81 void populateBoundsForIndexValue(Operation *op, Value value,
82 ValueBoundsConstraintSet &cstr) const {
83 auto divSIOp = cast<FloorDivSIOp>(op);
84 assert(value == divSIOp.getResult() && "invalid value");
85
86 AffineExpr lhs = cstr.getExpr(divSIOp.getLhs());
87 AffineExpr rhs = cstr.getExpr(divSIOp.getRhs());
88 cstr.bound(value) == lhs.floorDiv(other: rhs);
89 }
90};
91
92struct SelectOpInterface
93 : public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
94 SelectOp> {
95
96 static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
97 ValueBoundsConstraintSet &cstr) {
98 Value value = selectOp.getResult();
99 Value condition = selectOp.getCondition();
100 Value trueValue = selectOp.getTrueValue();
101 Value falseValue = selectOp.getFalseValue();
102
103 if (isa<ShapedType>(condition.getType())) {
104 // If the condition is a shaped type, the condition is applied
105 // element-wise. All three operands must have the same shape.
106 cstr.bound(value)[*dim] == cstr.getExpr(value: trueValue, dim);
107 cstr.bound(value)[*dim] == cstr.getExpr(value: falseValue, dim);
108 cstr.bound(value)[*dim] == cstr.getExpr(value: condition, dim);
109 return;
110 }
111
112 // Populate constraints for the true/false values (and all values on the
113 // backward slice, as long as the current stop condition is not satisfied).
114 cstr.populateConstraints(value: trueValue, dim);
115 cstr.populateConstraints(value: falseValue, dim);
116 auto boundsBuilder = cstr.bound(value);
117 if (dim)
118 boundsBuilder[*dim];
119
120 // Compare yielded values.
121 // If trueValue <= falseValue:
122 // * result <= falseValue
123 // * result >= trueValue
124 if (cstr.populateAndCompare(
125 /*lhs=*/{trueValue, dim},
126 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
127 /*rhs=*/{falseValue, dim})) {
128 if (dim) {
129 cstr.bound(value)[*dim] >= cstr.getExpr(value: trueValue, dim);
130 cstr.bound(value)[*dim] <= cstr.getExpr(value: falseValue, dim);
131 } else {
132 cstr.bound(value) >= trueValue;
133 cstr.bound(value) <= falseValue;
134 }
135 }
136 // If falseValue <= trueValue:
137 // * result <= trueValue
138 // * result >= falseValue
139 if (cstr.populateAndCompare(
140 /*lhs=*/{falseValue, dim},
141 cmp: ValueBoundsConstraintSet::ComparisonOperator::LE,
142 /*rhs=*/{trueValue, dim})) {
143 if (dim) {
144 cstr.bound(value)[*dim] >= cstr.getExpr(value: falseValue, dim);
145 cstr.bound(value)[*dim] <= cstr.getExpr(value: trueValue, dim);
146 } else {
147 cstr.bound(value) >= falseValue;
148 cstr.bound(value) <= trueValue;
149 }
150 }
151 }
152
153 void populateBoundsForIndexValue(Operation *op, Value value,
154 ValueBoundsConstraintSet &cstr) const {
155 populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
156 }
157
158 void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
159 ValueBoundsConstraintSet &cstr) const {
160 populateBounds(cast<SelectOp>(op), dim, cstr);
161 }
162};
163} // namespace
164} // namespace arith
165} // namespace mlir
166
167void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
168 DialectRegistry &registry) {
169 registry.addExtension(extensionFn: +[](MLIRContext *ctx, arith::ArithDialect *dialect) {
170 arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
171 arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
172 arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
173 arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
174 arith::FloorDivSIOp::attachInterface<arith::FloorDivSIOpInterface>(*ctx);
175 arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
176 });
177}
178

// Listing extracted from the KDAB code browser rendering of
// mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp.