1 | //===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the classes used to represent and build scalar expressions. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H |
14 | #define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H |
15 | |
16 | #include "llvm/ADT/DenseMap.h" |
17 | #include "llvm/ADT/SmallPtrSet.h" |
18 | #include "llvm/ADT/SmallVector.h" |
19 | #include "llvm/Analysis/ScalarEvolution.h" |
20 | #include "llvm/IR/Constants.h" |
21 | #include "llvm/IR/ValueHandle.h" |
22 | #include "llvm/Support/Casting.h" |
23 | #include "llvm/Support/ErrorHandling.h" |
24 | #include <cassert> |
25 | #include <cstddef> |
26 | |
27 | namespace llvm { |
28 | |
29 | class APInt; |
30 | class Constant; |
31 | class ConstantInt; |
32 | class ConstantRange; |
33 | class Loop; |
34 | class Type; |
35 | class Value; |
36 | |
/// The kind tag passed to each SCEV node's constructor and returned by
/// getSCEVType(); the classof() implementations and visitors in this file
/// dispatch on it. Each enumerator is annotated with the node class whose
/// classof() matches it.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  scConstant,           // SCEVConstant
  scVScale,             // SCEVVScale
  scTruncate,           // SCEVTruncateExpr
  scZeroExtend,         // SCEVZeroExtendExpr
  scSignExtend,         // SCEVSignExtendExpr
  scAddExpr,            // SCEVAddExpr
  scMulExpr,            // SCEVMulExpr
  scUDivExpr,           // SCEVUDivExpr
  scAddRecExpr,         // SCEVAddRecExpr
  scUMaxExpr,           // SCEVUMaxExpr
  scSMaxExpr,           // SCEVSMaxExpr
  scUMinExpr,           // SCEVUMinExpr
  scSMinExpr,           // SCEVSMinExpr
  scSequentialUMinExpr, // SCEVSequentialUMinExpr
  scPtrToInt,           // SCEVPtrToIntExpr
  scUnknown,            // SCEVUnknown
  // SCEVCouldNotCompute; SCEVVisitor's default handler and SCEVTraversal
  // both treat reaching it as unreachable.
  scCouldNotCompute
};
58 | |
59 | /// This class represents a constant integer value. |
60 | class SCEVConstant : public SCEV { |
61 | friend class ScalarEvolution; |
62 | |
63 | ConstantInt *V; |
64 | |
65 | SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v) |
66 | : SCEV(ID, scConstant, 1), V(v) {} |
67 | |
68 | public: |
69 | ConstantInt *getValue() const { return V; } |
70 | const APInt &getAPInt() const { return getValue()->getValue(); } |
71 | |
72 | Type *getType() const { return V->getType(); } |
73 | |
74 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
75 | static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } |
76 | }; |
77 | |
78 | /// This class represents the value of vscale, as used when defining the length |
79 | /// of a scalable vector or returned by the llvm.vscale() intrinsic. |
80 | class SCEVVScale : public SCEV { |
81 | friend class ScalarEvolution; |
82 | |
83 | SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty) |
84 | : SCEV(ID, scVScale, 0), Ty(ty) {} |
85 | |
86 | Type *Ty; |
87 | |
88 | public: |
89 | Type *getType() const { return Ty; } |
90 | |
91 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
92 | static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } |
93 | }; |
94 | |
95 | inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) { |
96 | APInt Size(16, 1); |
97 | for (const auto *Arg : Args) |
98 | Size = Size.uadd_sat(RHS: APInt(16, Arg->getExpressionSize())); |
99 | return (unsigned short)Size.getZExtValue(); |
100 | } |
101 | |
102 | /// This is the base class for unary cast operator classes. |
103 | class SCEVCastExpr : public SCEV { |
104 | protected: |
105 | const SCEV *Op; |
106 | Type *Ty; |
107 | |
108 | SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, |
109 | Type *ty); |
110 | |
111 | public: |
112 | const SCEV *getOperand() const { return Op; } |
113 | const SCEV *getOperand(unsigned i) const { |
114 | assert(i == 0 && "Operand index out of range!" ); |
115 | return Op; |
116 | } |
117 | ArrayRef<const SCEV *> operands() const { return Op; } |
118 | size_t getNumOperands() const { return 1; } |
119 | Type *getType() const { return Ty; } |
120 | |
121 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
122 | static bool classof(const SCEV *S) { |
123 | return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate || |
124 | S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend; |
125 | } |
126 | }; |
127 | |
128 | /// This class represents a cast from a pointer to a pointer-sized integer |
129 | /// value. |
130 | class SCEVPtrToIntExpr : public SCEVCastExpr { |
131 | friend class ScalarEvolution; |
132 | |
133 | SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy); |
134 | |
135 | public: |
136 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
137 | static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; } |
138 | }; |
139 | |
140 | /// This is the base class for unary integral cast operator classes. |
141 | class SCEVIntegralCastExpr : public SCEVCastExpr { |
142 | protected: |
143 | SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, |
144 | const SCEV *op, Type *ty); |
145 | |
146 | public: |
147 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
148 | static bool classof(const SCEV *S) { |
149 | return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend || |
150 | S->getSCEVType() == scSignExtend; |
151 | } |
152 | }; |
153 | |
154 | /// This class represents a truncation of an integer value to a |
155 | /// smaller integer value. |
156 | class SCEVTruncateExpr : public SCEVIntegralCastExpr { |
157 | friend class ScalarEvolution; |
158 | |
159 | SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); |
160 | |
161 | public: |
162 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
163 | static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; } |
164 | }; |
165 | |
166 | /// This class represents a zero extension of a small integer value |
167 | /// to a larger integer value. |
168 | class SCEVZeroExtendExpr : public SCEVIntegralCastExpr { |
169 | friend class ScalarEvolution; |
170 | |
171 | SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); |
172 | |
173 | public: |
174 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
175 | static bool classof(const SCEV *S) { |
176 | return S->getSCEVType() == scZeroExtend; |
177 | } |
178 | }; |
179 | |
180 | /// This class represents a sign extension of a small integer value |
181 | /// to a larger integer value. |
182 | class SCEVSignExtendExpr : public SCEVIntegralCastExpr { |
183 | friend class ScalarEvolution; |
184 | |
185 | SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty); |
186 | |
187 | public: |
188 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
189 | static bool classof(const SCEV *S) { |
190 | return S->getSCEVType() == scSignExtend; |
191 | } |
192 | }; |
193 | |
194 | /// This node is a base class providing common functionality for |
195 | /// n'ary operators. |
196 | class SCEVNAryExpr : public SCEV { |
197 | protected: |
198 | // Since SCEVs are immutable, ScalarEvolution allocates operand |
199 | // arrays with its SCEVAllocator, so this class just needs a simple |
200 | // pointer rather than a more elaborate vector-like data structure. |
201 | // This also avoids the need for a non-trivial destructor. |
202 | const SCEV *const *Operands; |
203 | size_t NumOperands; |
204 | |
205 | SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, |
206 | const SCEV *const *O, size_t N) |
207 | : SCEV(ID, T, computeExpressionSize(Args: ArrayRef(O, N))), Operands(O), |
208 | NumOperands(N) {} |
209 | |
210 | public: |
211 | size_t getNumOperands() const { return NumOperands; } |
212 | |
213 | const SCEV *getOperand(unsigned i) const { |
214 | assert(i < NumOperands && "Operand index out of range!" ); |
215 | return Operands[i]; |
216 | } |
217 | |
218 | ArrayRef<const SCEV *> operands() const { |
219 | return ArrayRef(Operands, NumOperands); |
220 | } |
221 | |
222 | NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const { |
223 | return (NoWrapFlags)(SubclassData & Mask); |
224 | } |
225 | |
226 | bool hasNoUnsignedWrap() const { |
227 | return getNoWrapFlags(Mask: FlagNUW) != FlagAnyWrap; |
228 | } |
229 | |
230 | bool hasNoSignedWrap() const { |
231 | return getNoWrapFlags(Mask: FlagNSW) != FlagAnyWrap; |
232 | } |
233 | |
234 | bool hasNoSelfWrap() const { return getNoWrapFlags(Mask: FlagNW) != FlagAnyWrap; } |
235 | |
236 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
237 | static bool classof(const SCEV *S) { |
238 | return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || |
239 | S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || |
240 | S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr || |
241 | S->getSCEVType() == scSequentialUMinExpr || |
242 | S->getSCEVType() == scAddRecExpr; |
243 | } |
244 | }; |
245 | |
246 | /// This node is the base class for n'ary commutative operators. |
247 | class SCEVCommutativeExpr : public SCEVNAryExpr { |
248 | protected: |
249 | SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, |
250 | const SCEV *const *O, size_t N) |
251 | : SCEVNAryExpr(ID, T, O, N) {} |
252 | |
253 | public: |
254 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
255 | static bool classof(const SCEV *S) { |
256 | return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr || |
257 | S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr || |
258 | S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr; |
259 | } |
260 | |
261 | /// Set flags for a non-recurrence without clearing previously set flags. |
262 | void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } |
263 | }; |
264 | |
265 | /// This node represents an addition of some number of SCEVs. |
266 | class SCEVAddExpr : public SCEVCommutativeExpr { |
267 | friend class ScalarEvolution; |
268 | |
269 | Type *Ty; |
270 | |
271 | SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) |
272 | : SCEVCommutativeExpr(ID, scAddExpr, O, N) { |
273 | auto *FirstPointerTypedOp = find_if(Range: operands(), P: [](const SCEV *Op) { |
274 | return Op->getType()->isPointerTy(); |
275 | }); |
276 | if (FirstPointerTypedOp != operands().end()) |
277 | Ty = (*FirstPointerTypedOp)->getType(); |
278 | else |
279 | Ty = getOperand(i: 0)->getType(); |
280 | } |
281 | |
282 | public: |
283 | Type *getType() const { return Ty; } |
284 | |
285 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
286 | static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; } |
287 | }; |
288 | |
289 | /// This node represents multiplication of some number of SCEVs. |
290 | class SCEVMulExpr : public SCEVCommutativeExpr { |
291 | friend class ScalarEvolution; |
292 | |
293 | SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) |
294 | : SCEVCommutativeExpr(ID, scMulExpr, O, N) {} |
295 | |
296 | public: |
297 | Type *getType() const { return getOperand(i: 0)->getType(); } |
298 | |
299 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
300 | static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; } |
301 | }; |
302 | |
303 | /// This class represents a binary unsigned division operation. |
304 | class SCEVUDivExpr : public SCEV { |
305 | friend class ScalarEvolution; |
306 | |
307 | std::array<const SCEV *, 2> Operands; |
308 | |
309 | SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs) |
310 | : SCEV(ID, scUDivExpr, computeExpressionSize(Args: {lhs, rhs})) { |
311 | Operands[0] = lhs; |
312 | Operands[1] = rhs; |
313 | } |
314 | |
315 | public: |
316 | const SCEV *getLHS() const { return Operands[0]; } |
317 | const SCEV *getRHS() const { return Operands[1]; } |
318 | size_t getNumOperands() const { return 2; } |
319 | const SCEV *getOperand(unsigned i) const { |
320 | assert((i == 0 || i == 1) && "Operand index out of range!" ); |
321 | return i == 0 ? getLHS() : getRHS(); |
322 | } |
323 | |
324 | ArrayRef<const SCEV *> operands() const { return Operands; } |
325 | |
326 | Type *getType() const { |
327 | // In most cases the types of LHS and RHS will be the same, but in some |
328 | // crazy cases one or the other may be a pointer. ScalarEvolution doesn't |
329 | // depend on the type for correctness, but handling types carefully can |
330 | // avoid extra casts in the SCEVExpander. The LHS is more likely to be |
331 | // a pointer type than the RHS, so use the RHS' type here. |
332 | return getRHS()->getType(); |
333 | } |
334 | |
335 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
336 | static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; } |
337 | }; |
338 | |
339 | /// This node represents a polynomial recurrence on the trip count |
340 | /// of the specified loop. This is the primary focus of the |
341 | /// ScalarEvolution framework; all the other SCEV subclasses are |
342 | /// mostly just supporting infrastructure to allow SCEVAddRecExpr |
343 | /// expressions to be created and analyzed. |
344 | /// |
345 | /// All operands of an AddRec are required to be loop invariant. |
346 | /// |
347 | class SCEVAddRecExpr : public SCEVNAryExpr { |
348 | friend class ScalarEvolution; |
349 | |
350 | const Loop *L; |
351 | |
352 | SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N, |
353 | const Loop *l) |
354 | : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {} |
355 | |
356 | public: |
357 | Type *getType() const { return getStart()->getType(); } |
358 | const SCEV *getStart() const { return Operands[0]; } |
359 | const Loop *getLoop() const { return L; } |
360 | |
361 | /// Constructs and returns the recurrence indicating how much this |
362 | /// expression steps by. If this is a polynomial of degree N, it |
363 | /// returns a chrec of degree N-1. We cannot determine whether |
364 | /// the step recurrence has self-wraparound. |
365 | const SCEV *getStepRecurrence(ScalarEvolution &SE) const { |
366 | if (isAffine()) |
367 | return getOperand(i: 1); |
368 | return SE.getAddRecExpr( |
369 | Operands: SmallVector<const SCEV *, 3>(operands().drop_front()), L: getLoop(), |
370 | Flags: FlagAnyWrap); |
371 | } |
372 | |
373 | /// Return true if this represents an expression A + B*x where A |
374 | /// and B are loop invariant values. |
375 | bool isAffine() const { |
376 | // We know that the start value is invariant. This expression is thus |
377 | // affine iff the step is also invariant. |
378 | return getNumOperands() == 2; |
379 | } |
380 | |
381 | /// Return true if this represents an expression A + B*x + C*x^2 |
382 | /// where A, B and C are loop invariant values. This corresponds |
383 | /// to an addrec of the form {L,+,M,+,N} |
384 | bool isQuadratic() const { return getNumOperands() == 3; } |
385 | |
386 | /// Set flags for a recurrence without clearing any previously set flags. |
387 | /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here |
388 | /// to make it easier to propagate flags. |
389 | void setNoWrapFlags(NoWrapFlags Flags) { |
390 | if (Flags & (FlagNUW | FlagNSW)) |
391 | Flags = ScalarEvolution::setFlags(Flags, OnFlags: FlagNW); |
392 | SubclassData |= Flags; |
393 | } |
394 | |
395 | /// Return the value of this chain of recurrences at the specified |
396 | /// iteration number. |
397 | const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const; |
398 | |
399 | /// Return the value of this chain of recurrences at the specified iteration |
400 | /// number. Takes an explicit list of operands to represent an AddRec. |
401 | static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands, |
402 | const SCEV *It, ScalarEvolution &SE); |
403 | |
404 | /// Return the number of iterations of this loop that produce |
405 | /// values in the specified constant range. Another way of |
406 | /// looking at this is that it returns the first iteration number |
407 | /// where the value is not in the condition, thus computing the |
408 | /// exit count. If the iteration count can't be computed, an |
409 | /// instance of SCEVCouldNotCompute is returned. |
410 | const SCEV *getNumIterationsInRange(const ConstantRange &Range, |
411 | ScalarEvolution &SE) const; |
412 | |
413 | /// Return an expression representing the value of this expression |
414 | /// one iteration of the loop ahead. |
415 | const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const; |
416 | |
417 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
418 | static bool classof(const SCEV *S) { |
419 | return S->getSCEVType() == scAddRecExpr; |
420 | } |
421 | }; |
422 | |
423 | /// This node is the base class min/max selections. |
424 | class SCEVMinMaxExpr : public SCEVCommutativeExpr { |
425 | friend class ScalarEvolution; |
426 | |
427 | static bool isMinMaxType(enum SCEVTypes T) { |
428 | return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr || |
429 | T == scUMinExpr; |
430 | } |
431 | |
432 | protected: |
433 | /// Note: Constructing subclasses via this constructor is allowed |
434 | SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, |
435 | const SCEV *const *O, size_t N) |
436 | : SCEVCommutativeExpr(ID, T, O, N) { |
437 | assert(isMinMaxType(T)); |
438 | // Min and max never overflow |
439 | setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); |
440 | } |
441 | |
442 | public: |
443 | Type *getType() const { return getOperand(i: 0)->getType(); } |
444 | |
445 | static bool classof(const SCEV *S) { return isMinMaxType(T: S->getSCEVType()); } |
446 | |
447 | static enum SCEVTypes negate(enum SCEVTypes T) { |
448 | switch (T) { |
449 | case scSMaxExpr: |
450 | return scSMinExpr; |
451 | case scSMinExpr: |
452 | return scSMaxExpr; |
453 | case scUMaxExpr: |
454 | return scUMinExpr; |
455 | case scUMinExpr: |
456 | return scUMaxExpr; |
457 | default: |
458 | llvm_unreachable("Not a min or max SCEV type!" ); |
459 | } |
460 | } |
461 | }; |
462 | |
463 | /// This class represents a signed maximum selection. |
464 | class SCEVSMaxExpr : public SCEVMinMaxExpr { |
465 | friend class ScalarEvolution; |
466 | |
467 | SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) |
468 | : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {} |
469 | |
470 | public: |
471 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
472 | static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; } |
473 | }; |
474 | |
475 | /// This class represents an unsigned maximum selection. |
476 | class SCEVUMaxExpr : public SCEVMinMaxExpr { |
477 | friend class ScalarEvolution; |
478 | |
479 | SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) |
480 | : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {} |
481 | |
482 | public: |
483 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
484 | static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; } |
485 | }; |
486 | |
487 | /// This class represents a signed minimum selection. |
488 | class SCEVSMinExpr : public SCEVMinMaxExpr { |
489 | friend class ScalarEvolution; |
490 | |
491 | SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) |
492 | : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {} |
493 | |
494 | public: |
495 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
496 | static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; } |
497 | }; |
498 | |
499 | /// This class represents an unsigned minimum selection. |
500 | class SCEVUMinExpr : public SCEVMinMaxExpr { |
501 | friend class ScalarEvolution; |
502 | |
503 | SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N) |
504 | : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {} |
505 | |
506 | public: |
507 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
508 | static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; } |
509 | }; |
510 | |
511 | /// This node is the base class for sequential/in-order min/max selections. |
512 | /// Note that their fundamental difference from SCEVMinMaxExpr's is that they |
513 | /// are early-returning upon reaching saturation point. |
514 | /// I.e. given `0 umin_seq poison`, the result will be `0`, |
515 | /// while the result of `0 umin poison` is `poison`. |
516 | class SCEVSequentialMinMaxExpr : public SCEVNAryExpr { |
517 | friend class ScalarEvolution; |
518 | |
519 | static bool isSequentialMinMaxType(enum SCEVTypes T) { |
520 | return T == scSequentialUMinExpr; |
521 | } |
522 | |
523 | /// Set flags for a non-recurrence without clearing previously set flags. |
524 | void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; } |
525 | |
526 | protected: |
527 | /// Note: Constructing subclasses via this constructor is allowed |
528 | SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T, |
529 | const SCEV *const *O, size_t N) |
530 | : SCEVNAryExpr(ID, T, O, N) { |
531 | assert(isSequentialMinMaxType(T)); |
532 | // Min and max never overflow |
533 | setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)); |
534 | } |
535 | |
536 | public: |
537 | Type *getType() const { return getOperand(i: 0)->getType(); } |
538 | |
539 | static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) { |
540 | assert(isSequentialMinMaxType(Ty)); |
541 | switch (Ty) { |
542 | case scSequentialUMinExpr: |
543 | return scUMinExpr; |
544 | default: |
545 | llvm_unreachable("Not a sequential min/max type." ); |
546 | } |
547 | } |
548 | |
549 | SCEVTypes getEquivalentNonSequentialSCEVType() const { |
550 | return getEquivalentNonSequentialSCEVType(Ty: getSCEVType()); |
551 | } |
552 | |
553 | static bool classof(const SCEV *S) { |
554 | return isSequentialMinMaxType(T: S->getSCEVType()); |
555 | } |
556 | }; |
557 | |
558 | /// This class represents a sequential/in-order unsigned minimum selection. |
559 | class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr { |
560 | friend class ScalarEvolution; |
561 | |
562 | SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, |
563 | size_t N) |
564 | : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {} |
565 | |
566 | public: |
567 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
568 | static bool classof(const SCEV *S) { |
569 | return S->getSCEVType() == scSequentialUMinExpr; |
570 | } |
571 | }; |
572 | |
573 | /// This means that we are dealing with an entirely unknown SCEV |
574 | /// value, and only represent it as its LLVM Value. This is the |
575 | /// "bottom" value for the analysis. |
576 | class SCEVUnknown final : public SCEV, private CallbackVH { |
577 | friend class ScalarEvolution; |
578 | |
579 | /// The parent ScalarEvolution value. This is used to update the |
580 | /// parent's maps when the value associated with a SCEVUnknown is |
581 | /// deleted or RAUW'd. |
582 | ScalarEvolution *SE; |
583 | |
584 | /// The next pointer in the linked list of all SCEVUnknown |
585 | /// instances owned by a ScalarEvolution. |
586 | SCEVUnknown *Next; |
587 | |
588 | SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se, |
589 | SCEVUnknown *next) |
590 | : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {} |
591 | |
592 | // Implement CallbackVH. |
593 | void deleted() override; |
594 | void allUsesReplacedWith(Value *New) override; |
595 | |
596 | public: |
597 | Value *getValue() const { return getValPtr(); } |
598 | |
599 | Type *getType() const { return getValPtr()->getType(); } |
600 | |
601 | /// Methods for support type inquiry through isa, cast, and dyn_cast: |
602 | static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; } |
603 | }; |
604 | |
605 | /// This class defines a simple visitor class that may be used for |
606 | /// various SCEV analysis purposes. |
607 | template <typename SC, typename RetVal = void> struct SCEVVisitor { |
608 | RetVal visit(const SCEV *S) { |
609 | switch (S->getSCEVType()) { |
610 | case scConstant: |
611 | return ((SC *)this)->visitConstant((const SCEVConstant *)S); |
612 | case scVScale: |
613 | return ((SC *)this)->visitVScale((const SCEVVScale *)S); |
614 | case scPtrToInt: |
615 | return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); |
616 | case scTruncate: |
617 | return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S); |
618 | case scZeroExtend: |
619 | return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S); |
620 | case scSignExtend: |
621 | return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S); |
622 | case scAddExpr: |
623 | return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S); |
624 | case scMulExpr: |
625 | return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S); |
626 | case scUDivExpr: |
627 | return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S); |
628 | case scAddRecExpr: |
629 | return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S); |
630 | case scSMaxExpr: |
631 | return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S); |
632 | case scUMaxExpr: |
633 | return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S); |
634 | case scSMinExpr: |
635 | return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S); |
636 | case scUMinExpr: |
637 | return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S); |
638 | case scSequentialUMinExpr: |
639 | return ((SC *)this) |
640 | ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S); |
641 | case scUnknown: |
642 | return ((SC *)this)->visitUnknown((const SCEVUnknown *)S); |
643 | case scCouldNotCompute: |
644 | return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S); |
645 | } |
646 | llvm_unreachable("Unknown SCEV kind!" ); |
647 | } |
648 | |
649 | RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) { |
650 | llvm_unreachable("Invalid use of SCEVCouldNotCompute!" ); |
651 | } |
652 | }; |
653 | |
654 | /// Visit all nodes in the expression tree using worklist traversal. |
655 | /// |
656 | /// Visitor implements: |
657 | /// // return true to follow this node. |
658 | /// bool follow(const SCEV *S); |
659 | /// // return true to terminate the search. |
660 | /// bool isDone(); |
661 | template <typename SV> class SCEVTraversal { |
662 | SV &Visitor; |
663 | SmallVector<const SCEV *, 8> Worklist; |
664 | SmallPtrSet<const SCEV *, 8> Visited; |
665 | |
666 | void push(const SCEV *S) { |
667 | if (Visited.insert(Ptr: S).second && Visitor.follow(S)) |
668 | Worklist.push_back(Elt: S); |
669 | } |
670 | |
671 | public: |
672 | SCEVTraversal(SV &V) : Visitor(V) {} |
673 | |
674 | void visitAll(const SCEV *Root) { |
675 | push(S: Root); |
676 | while (!Worklist.empty() && !Visitor.isDone()) { |
677 | const SCEV *S = Worklist.pop_back_val(); |
678 | |
679 | switch (S->getSCEVType()) { |
680 | case scConstant: |
681 | case scVScale: |
682 | case scUnknown: |
683 | continue; |
684 | case scPtrToInt: |
685 | case scTruncate: |
686 | case scZeroExtend: |
687 | case scSignExtend: |
688 | case scAddExpr: |
689 | case scMulExpr: |
690 | case scUDivExpr: |
691 | case scSMaxExpr: |
692 | case scUMaxExpr: |
693 | case scSMinExpr: |
694 | case scUMinExpr: |
695 | case scSequentialUMinExpr: |
696 | case scAddRecExpr: |
697 | for (const auto *Op : S->operands()) { |
698 | push(S: Op); |
699 | if (Visitor.isDone()) |
700 | break; |
701 | } |
702 | continue; |
703 | case scCouldNotCompute: |
704 | llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!" ); |
705 | } |
706 | llvm_unreachable("Unknown SCEV kind!" ); |
707 | } |
708 | } |
709 | }; |
710 | |
711 | /// Use SCEVTraversal to visit all nodes in the given expression tree. |
712 | template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) { |
713 | SCEVTraversal<SV> T(Visitor); |
714 | T.visitAll(Root); |
715 | } |
716 | |
717 | /// Return true if any node in \p Root satisfies the predicate \p Pred. |
718 | template <typename PredTy> |
719 | bool SCEVExprContains(const SCEV *Root, PredTy Pred) { |
720 | struct FindClosure { |
721 | bool Found = false; |
722 | PredTy Pred; |
723 | |
724 | FindClosure(PredTy Pred) : Pred(Pred) {} |
725 | |
726 | bool follow(const SCEV *S) { |
727 | if (!Pred(S)) |
728 | return true; |
729 | |
730 | Found = true; |
731 | return false; |
732 | } |
733 | |
734 | bool isDone() const { return Found; } |
735 | }; |
736 | |
737 | FindClosure FC(Pred); |
738 | visitAll(Root, FC); |
739 | return FC.Found; |
740 | } |
741 | |
742 | /// This visitor recursively visits a SCEV expression and re-writes it. |
743 | /// The result from each visit is cached, so it will return the same |
744 | /// SCEV for the same input. |
745 | template <typename SC> |
746 | class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> { |
747 | protected: |
748 | ScalarEvolution &SE; |
749 | // Memoize the result of each visit so that we only compute once for |
750 | // the same input SCEV. This is to avoid redundant computations when |
751 | // a SCEV is referenced by multiple SCEVs. Without memoization, this |
752 | // visit algorithm would have exponential time complexity in the worst |
753 | // case, causing the compiler to hang on certain tests. |
754 | SmallDenseMap<const SCEV *, const SCEV *> RewriteResults; |
755 | |
756 | public: |
757 | SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {} |
758 | |
759 | const SCEV *visit(const SCEV *S) { |
760 | auto It = RewriteResults.find(Val: S); |
761 | if (It != RewriteResults.end()) |
762 | return It->second; |
763 | auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S); |
764 | auto Result = RewriteResults.try_emplace(S, Visited); |
765 | assert(Result.second && "Should insert a new entry" ); |
766 | return Result.first->second; |
767 | } |
768 | |
769 | const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } |
770 | |
771 | const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } |
772 | |
773 | const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { |
774 | const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); |
775 | return Operand == Expr->getOperand() |
776 | ? Expr |
777 | : SE.getPtrToIntExpr(Op: Operand, Ty: Expr->getType()); |
778 | } |
779 | |
780 | const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) { |
781 | const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); |
782 | return Operand == Expr->getOperand() |
783 | ? Expr |
784 | : SE.getTruncateExpr(Op: Operand, Ty: Expr->getType()); |
785 | } |
786 | |
787 | const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { |
788 | const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); |
789 | return Operand == Expr->getOperand() |
790 | ? Expr |
791 | : SE.getZeroExtendExpr(Op: Operand, Ty: Expr->getType()); |
792 | } |
793 | |
794 | const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { |
795 | const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); |
796 | return Operand == Expr->getOperand() |
797 | ? Expr |
798 | : SE.getSignExtendExpr(Op: Operand, Ty: Expr->getType()); |
799 | } |
800 | |
801 | const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { |
802 | SmallVector<const SCEV *, 2> Operands; |
803 | bool Changed = false; |
804 | for (const auto *Op : Expr->operands()) { |
805 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
806 | Changed |= Op != Operands.back(); |
807 | } |
808 | return !Changed ? Expr : SE.getAddExpr(Ops&: Operands); |
809 | } |
810 | |
811 | const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { |
812 | SmallVector<const SCEV *, 2> Operands; |
813 | bool Changed = false; |
814 | for (const auto *Op : Expr->operands()) { |
815 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
816 | Changed |= Op != Operands.back(); |
817 | } |
818 | return !Changed ? Expr : SE.getMulExpr(Ops&: Operands); |
819 | } |
820 | |
821 | const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) { |
822 | auto *LHS = ((SC *)this)->visit(Expr->getLHS()); |
823 | auto *RHS = ((SC *)this)->visit(Expr->getRHS()); |
824 | bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS(); |
825 | return !Changed ? Expr : SE.getUDivExpr(LHS, RHS); |
826 | } |
827 | |
828 | const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { |
829 | SmallVector<const SCEV *, 2> Operands; |
830 | bool Changed = false; |
831 | for (const auto *Op : Expr->operands()) { |
832 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
833 | Changed |= Op != Operands.back(); |
834 | } |
835 | return !Changed ? Expr |
836 | : SE.getAddRecExpr(Operands, L: Expr->getLoop(), |
837 | Flags: Expr->getNoWrapFlags()); |
838 | } |
839 | |
840 | const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) { |
841 | SmallVector<const SCEV *, 2> Operands; |
842 | bool Changed = false; |
843 | for (const auto *Op : Expr->operands()) { |
844 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
845 | Changed |= Op != Operands.back(); |
846 | } |
847 | return !Changed ? Expr : SE.getSMaxExpr(Operands); |
848 | } |
849 | |
850 | const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) { |
851 | SmallVector<const SCEV *, 2> Operands; |
852 | bool Changed = false; |
853 | for (const auto *Op : Expr->operands()) { |
854 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
855 | Changed |= Op != Operands.back(); |
856 | } |
857 | return !Changed ? Expr : SE.getUMaxExpr(Operands); |
858 | } |
859 | |
860 | const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) { |
861 | SmallVector<const SCEV *, 2> Operands; |
862 | bool Changed = false; |
863 | for (const auto *Op : Expr->operands()) { |
864 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
865 | Changed |= Op != Operands.back(); |
866 | } |
867 | return !Changed ? Expr : SE.getSMinExpr(Operands); |
868 | } |
869 | |
870 | const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) { |
871 | SmallVector<const SCEV *, 2> Operands; |
872 | bool Changed = false; |
873 | for (const auto *Op : Expr->operands()) { |
874 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
875 | Changed |= Op != Operands.back(); |
876 | } |
877 | return !Changed ? Expr : SE.getUMinExpr(Operands); |
878 | } |
879 | |
880 | const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) { |
881 | SmallVector<const SCEV *, 2> Operands; |
882 | bool Changed = false; |
883 | for (const auto *Op : Expr->operands()) { |
884 | Operands.push_back(Elt: ((SC *)this)->visit(Op)); |
885 | Changed |= Op != Operands.back(); |
886 | } |
887 | return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/Sequential: true); |
888 | } |
889 | |
890 | const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; } |
891 | |
892 | const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { |
893 | return Expr; |
894 | } |
895 | }; |
896 | |
/// DenseMap from (const) IR values to (mutable) IR values.
using ValueToValueMap = DenseMap<const Value *, Value *>;
/// DenseMap from IR values to SCEV expressions; SCEVParameterRewriter uses
/// such a map to substitute SCEVUnknown leaves.
using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
899 | |
900 | /// The SCEVParameterRewriter takes a scalar evolution expression and updates |
901 | /// the SCEVUnknown components following the Map (Value -> SCEV). |
902 | class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> { |
903 | public: |
904 | static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, |
905 | ValueToSCEVMapTy &Map) { |
906 | SCEVParameterRewriter Rewriter(SE, Map); |
907 | return Rewriter.visit(S: Scev); |
908 | } |
909 | |
910 | SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) |
911 | : SCEVRewriteVisitor(SE), Map(M) {} |
912 | |
913 | const SCEV *visitUnknown(const SCEVUnknown *Expr) { |
914 | auto I = Map.find(Val: Expr->getValue()); |
915 | if (I == Map.end()) |
916 | return Expr; |
917 | return I->second; |
918 | } |
919 | |
920 | private: |
921 | ValueToSCEVMapTy ⤅ |
922 | }; |
923 | |
/// DenseMap from loops to SCEV expressions; SCEVLoopAddRecRewriter evaluates
/// add recurrences of a mapped loop at the mapped SCEV.
using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
925 | |
926 | /// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies |
927 | /// the Map (Loop -> SCEV) to all AddRecExprs. |
928 | class SCEVLoopAddRecRewriter |
929 | : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> { |
930 | public: |
931 | SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M) |
932 | : SCEVRewriteVisitor(SE), Map(M) {} |
933 | |
934 | static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map, |
935 | ScalarEvolution &SE) { |
936 | SCEVLoopAddRecRewriter Rewriter(SE, Map); |
937 | return Rewriter.visit(S: Scev); |
938 | } |
939 | |
940 | const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { |
941 | SmallVector<const SCEV *, 2> Operands; |
942 | for (const SCEV *Op : Expr->operands()) |
943 | Operands.push_back(Elt: visit(S: Op)); |
944 | |
945 | const Loop *L = Expr->getLoop(); |
946 | if (0 == Map.count(Val: L)) |
947 | return SE.getAddRecExpr(Operands, L, Flags: Expr->getNoWrapFlags()); |
948 | |
949 | return SCEVAddRecExpr::evaluateAtIteration(Operands, It: Map[L], SE); |
950 | } |
951 | |
952 | private: |
953 | LoopToScevMapT ⤅ |
954 | }; |
955 | |
956 | } // end namespace llvm |
957 | |
958 | #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H |
959 | |