1 | //===- InstructionCost.h ----------------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// \file |
/// This file defines an InstructionCost class that is used when calculating
/// the cost of an instruction, or a group of instructions. In addition to a
/// numeric value representing the cost, the class also contains a state that
/// can be used to encode particular properties, such as a cost being invalid.
/// Operations on InstructionCost implement saturation arithmetic, so that
/// accumulating costs on large cost-values doesn't overflow.
15 | /// |
16 | //===----------------------------------------------------------------------===// |
17 | |
18 | #ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H |
19 | #define LLVM_SUPPORT_INSTRUCTIONCOST_H |
20 | |
21 | #include "llvm/Support/MathExtras.h" |
22 | #include <limits> |
23 | #include <optional> |
24 | |
25 | namespace llvm { |
26 | |
27 | class raw_ostream; |
28 | |
29 | class InstructionCost { |
30 | public: |
31 | using CostType = int64_t; |
32 | |
33 | /// CostState describes the state of a cost. |
34 | enum CostState { |
35 | Valid, /// < The cost value represents a valid cost, even when the |
36 | /// cost-value is large. |
37 | Invalid /// < Invalid indicates there is no way to represent the cost as a |
38 | /// numeric value. This state exists to represent a possible issue, |
39 | /// e.g. if the cost-model knows the operation cannot be expanded |
40 | /// into a valid code-sequence by the code-generator. While some |
41 | /// passes may assert that the calculated cost must be valid, it is |
42 | /// up to individual passes how to interpret an Invalid cost. For |
43 | /// example, a transformation pass could choose not to perform a |
44 | /// transformation if the resulting cost would end up Invalid. |
45 | /// Because some passes may assert a cost is Valid, it is not |
46 | /// recommended to use Invalid costs to model 'Unknown'. |
47 | /// Note that Invalid is semantically different from a (very) high, |
48 | /// but valid cost, which intentionally indicates no issue, but |
49 | /// rather a strong preference not to select a certain operation. |
50 | }; |
51 | |
52 | private: |
53 | CostType Value = 0; |
54 | CostState State = Valid; |
55 | |
56 | void propagateState(const InstructionCost &RHS) { |
57 | if (RHS.State == Invalid) |
58 | State = Invalid; |
59 | } |
60 | |
61 | static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); } |
62 | static CostType getMinValue() { return std::numeric_limits<CostType>::min(); } |
63 | |
64 | public: |
65 | // A default constructed InstructionCost is a valid zero cost |
66 | InstructionCost() = default; |
67 | |
68 | InstructionCost(CostState) = delete; |
69 | InstructionCost(CostType Val) : Value(Val), State(Valid) {} |
70 | |
71 | static InstructionCost getMax() { return getMaxValue(); } |
72 | static InstructionCost getMin() { return getMinValue(); } |
73 | static InstructionCost getInvalid(CostType Val = 0) { |
74 | InstructionCost Tmp(Val); |
75 | Tmp.setInvalid(); |
76 | return Tmp; |
77 | } |
78 | |
79 | bool isValid() const { return State == Valid; } |
80 | void setValid() { State = Valid; } |
81 | void setInvalid() { State = Invalid; } |
82 | CostState getState() const { return State; } |
83 | |
84 | /// This function is intended to be used as sparingly as possible, since the |
85 | /// class provides the full range of operator support required for arithmetic |
86 | /// and comparisons. |
87 | std::optional<CostType> getValue() const { |
88 | if (isValid()) |
89 | return Value; |
90 | return std::nullopt; |
91 | } |
92 | |
93 | /// For all of the arithmetic operators provided here any invalid state is |
94 | /// perpetuated and cannot be removed. Once a cost becomes invalid it stays |
95 | /// invalid, and it also inherits any invalid state from the RHS. |
96 | /// Arithmetic work on the actual values is implemented with saturation, |
97 | /// to avoid overflow when using more extreme cost values. |
98 | |
99 | InstructionCost &operator+=(const InstructionCost &RHS) { |
100 | propagateState(RHS); |
101 | |
102 | // Saturating addition. |
103 | InstructionCost::CostType Result; |
104 | if (AddOverflow(X: Value, Y: RHS.Value, Result)) |
105 | Result = RHS.Value > 0 ? getMaxValue() : getMinValue(); |
106 | |
107 | Value = Result; |
108 | return *this; |
109 | } |
110 | |
111 | InstructionCost &operator+=(const CostType RHS) { |
112 | InstructionCost RHS2(RHS); |
113 | *this += RHS2; |
114 | return *this; |
115 | } |
116 | |
117 | InstructionCost &operator-=(const InstructionCost &RHS) { |
118 | propagateState(RHS); |
119 | |
120 | // Saturating subtract. |
121 | InstructionCost::CostType Result; |
122 | if (SubOverflow(X: Value, Y: RHS.Value, Result)) |
123 | Result = RHS.Value > 0 ? getMinValue() : getMaxValue(); |
124 | Value = Result; |
125 | return *this; |
126 | } |
127 | |
128 | InstructionCost &operator-=(const CostType RHS) { |
129 | InstructionCost RHS2(RHS); |
130 | *this -= RHS2; |
131 | return *this; |
132 | } |
133 | |
134 | InstructionCost &operator*=(const InstructionCost &RHS) { |
135 | propagateState(RHS); |
136 | |
137 | // Saturating multiply. |
138 | InstructionCost::CostType Result; |
139 | if (MulOverflow(X: Value, Y: RHS.Value, Result)) { |
140 | if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0)) |
141 | Result = getMaxValue(); |
142 | else |
143 | Result = getMinValue(); |
144 | } |
145 | |
146 | Value = Result; |
147 | return *this; |
148 | } |
149 | |
150 | InstructionCost &operator*=(const CostType RHS) { |
151 | InstructionCost RHS2(RHS); |
152 | *this *= RHS2; |
153 | return *this; |
154 | } |
155 | |
156 | InstructionCost &operator/=(const InstructionCost &RHS) { |
157 | propagateState(RHS); |
158 | Value /= RHS.Value; |
159 | return *this; |
160 | } |
161 | |
162 | InstructionCost &operator/=(const CostType RHS) { |
163 | InstructionCost RHS2(RHS); |
164 | *this /= RHS2; |
165 | return *this; |
166 | } |
167 | |
168 | InstructionCost &operator++() { |
169 | *this += 1; |
170 | return *this; |
171 | } |
172 | |
173 | InstructionCost operator++(int) { |
174 | InstructionCost Copy = *this; |
175 | ++*this; |
176 | return Copy; |
177 | } |
178 | |
179 | InstructionCost &operator--() { |
180 | *this -= 1; |
181 | return *this; |
182 | } |
183 | |
184 | InstructionCost operator--(int) { |
185 | InstructionCost Copy = *this; |
186 | --*this; |
187 | return Copy; |
188 | } |
189 | |
190 | /// For the comparison operators we have chosen to use lexicographical |
191 | /// ordering where valid costs are always considered to be less than invalid |
192 | /// costs. This avoids having to add asserts to the comparison operators that |
193 | /// the states are valid and users can test for validity of the cost |
194 | /// explicitly. |
195 | bool operator<(const InstructionCost &RHS) const { |
196 | if (State != RHS.State) |
197 | return State < RHS.State; |
198 | return Value < RHS.Value; |
199 | } |
200 | |
201 | // Implement in terms of operator< to ensure that the two comparisons stay in |
202 | // sync |
203 | bool operator==(const InstructionCost &RHS) const { |
204 | return !(*this < RHS) && !(RHS < *this); |
205 | } |
206 | |
207 | bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); } |
208 | |
209 | bool operator==(const CostType RHS) const { |
210 | InstructionCost RHS2(RHS); |
211 | return *this == RHS2; |
212 | } |
213 | |
214 | bool operator!=(const CostType RHS) const { return !(*this == RHS); } |
215 | |
216 | bool operator>(const InstructionCost &RHS) const { return RHS < *this; } |
217 | |
218 | bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); } |
219 | |
220 | bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); } |
221 | |
222 | bool operator<(const CostType RHS) const { |
223 | InstructionCost RHS2(RHS); |
224 | return *this < RHS2; |
225 | } |
226 | |
227 | bool operator>(const CostType RHS) const { |
228 | InstructionCost RHS2(RHS); |
229 | return *this > RHS2; |
230 | } |
231 | |
232 | bool operator<=(const CostType RHS) const { |
233 | InstructionCost RHS2(RHS); |
234 | return *this <= RHS2; |
235 | } |
236 | |
237 | bool operator>=(const CostType RHS) const { |
238 | InstructionCost RHS2(RHS); |
239 | return *this >= RHS2; |
240 | } |
241 | |
242 | void print(raw_ostream &OS) const; |
243 | |
244 | template <class Function> |
245 | auto map(const Function &F) const -> InstructionCost { |
246 | if (isValid()) |
247 | return F(Value); |
248 | return getInvalid(); |
249 | } |
250 | }; |
251 | |
252 | inline InstructionCost operator+(const InstructionCost &LHS, |
253 | const InstructionCost &RHS) { |
254 | InstructionCost LHS2(LHS); |
255 | LHS2 += RHS; |
256 | return LHS2; |
257 | } |
258 | |
259 | inline InstructionCost operator-(const InstructionCost &LHS, |
260 | const InstructionCost &RHS) { |
261 | InstructionCost LHS2(LHS); |
262 | LHS2 -= RHS; |
263 | return LHS2; |
264 | } |
265 | |
266 | inline InstructionCost operator*(const InstructionCost &LHS, |
267 | const InstructionCost &RHS) { |
268 | InstructionCost LHS2(LHS); |
269 | LHS2 *= RHS; |
270 | return LHS2; |
271 | } |
272 | |
273 | inline InstructionCost operator/(const InstructionCost &LHS, |
274 | const InstructionCost &RHS) { |
275 | InstructionCost LHS2(LHS); |
276 | LHS2 /= RHS; |
277 | return LHS2; |
278 | } |
279 | |
280 | inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) { |
281 | V.print(OS); |
282 | return OS; |
283 | } |
284 | |
285 | } // namespace llvm |
286 | |
287 | #endif |
288 | |