//===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the classes used to represent and build scalar expressions.
//
//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14#define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstddef>
26
27namespace llvm {
28
// Forward declarations of types referenced by the SCEV node classes below.
class APInt;
class ConstantRange;
class Loop;
class Constant;
class ConstantInt;
class Type;
class Value;

/// Kind tag of a SCEV node, as returned by SCEV::getSCEVType().
///
/// The enumerators are listed in order of increasing complexity; this ordering
/// keeps the expression folders simpler, so preserve it when adding kinds.
enum SCEVTypes : unsigned short {
  // Leaves.
  scConstant = 0,
  scVScale,
  // Integer casts.
  scTruncate,
  scZeroExtend,
  scSignExtend,
  // Arithmetic and recurrences.
  scAddExpr,
  scMulExpr,
  scUDivExpr,
  scAddRecExpr,
  // Min/max selections.
  scUMaxExpr,
  scSMaxExpr,
  scUMinExpr,
  scSMinExpr,
  scSequentialUMinExpr,
  // Pointer cast, opaque values, and the failure sentinel.
  scPtrToInt,
  scUnknown,
  scCouldNotCompute
};
58
59/// This class represents a constant integer value.
60class SCEVConstant : public SCEV {
61 friend class ScalarEvolution;
62
63 ConstantInt *V;
64
65 SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66 : SCEV(ID, scConstant, 1), V(v) {}
67
68public:
69 ConstantInt *getValue() const { return V; }
70 const APInt &getAPInt() const { return getValue()->getValue(); }
71
72 Type *getType() const { return V->getType(); }
73
74 /// Methods for support type inquiry through isa, cast, and dyn_cast:
75 static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76};
77
78/// This class represents the value of vscale, as used when defining the length
79/// of a scalable vector or returned by the llvm.vscale() intrinsic.
80class SCEVVScale : public SCEV {
81 friend class ScalarEvolution;
82
83 SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
84 : SCEV(ID, scVScale, 0), Ty(ty) {}
85
86 Type *Ty;
87
88public:
89 Type *getType() const { return Ty; }
90
91 /// Methods for support type inquiry through isa, cast, and dyn_cast:
92 static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
93};
94
95inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
96 APInt Size(16, 1);
97 for (const auto *Arg : Args)
98 Size = Size.uadd_sat(RHS: APInt(16, Arg->getExpressionSize()));
99 return (unsigned short)Size.getZExtValue();
100}
101
102/// This is the base class for unary cast operator classes.
103class SCEVCastExpr : public SCEV {
104protected:
105 const SCEV *Op;
106 Type *Ty;
107
108 SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
109 Type *ty);
110
111public:
112 const SCEV *getOperand() const { return Op; }
113 const SCEV *getOperand(unsigned i) const {
114 assert(i == 0 && "Operand index out of range!");
115 return Op;
116 }
117 ArrayRef<const SCEV *> operands() const { return Op; }
118 size_t getNumOperands() const { return 1; }
119 Type *getType() const { return Ty; }
120
121 /// Methods for support type inquiry through isa, cast, and dyn_cast:
122 static bool classof(const SCEV *S) {
123 return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
124 S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
125 }
126};
127
128/// This class represents a cast from a pointer to a pointer-sized integer
129/// value.
130class SCEVPtrToIntExpr : public SCEVCastExpr {
131 friend class ScalarEvolution;
132
133 SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
134
135public:
136 /// Methods for support type inquiry through isa, cast, and dyn_cast:
137 static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
138};
139
140/// This is the base class for unary integral cast operator classes.
141class SCEVIntegralCastExpr : public SCEVCastExpr {
142protected:
143 SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
144 const SCEV *op, Type *ty);
145
146public:
147 /// Methods for support type inquiry through isa, cast, and dyn_cast:
148 static bool classof(const SCEV *S) {
149 return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
150 S->getSCEVType() == scSignExtend;
151 }
152};
153
154/// This class represents a truncation of an integer value to a
155/// smaller integer value.
156class SCEVTruncateExpr : public SCEVIntegralCastExpr {
157 friend class ScalarEvolution;
158
159 SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160
161public:
162 /// Methods for support type inquiry through isa, cast, and dyn_cast:
163 static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
164};
165
166/// This class represents a zero extension of a small integer value
167/// to a larger integer value.
168class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
169 friend class ScalarEvolution;
170
171 SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
172
173public:
174 /// Methods for support type inquiry through isa, cast, and dyn_cast:
175 static bool classof(const SCEV *S) {
176 return S->getSCEVType() == scZeroExtend;
177 }
178};
179
180/// This class represents a sign extension of a small integer value
181/// to a larger integer value.
182class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
183 friend class ScalarEvolution;
184
185 SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
186
187public:
188 /// Methods for support type inquiry through isa, cast, and dyn_cast:
189 static bool classof(const SCEV *S) {
190 return S->getSCEVType() == scSignExtend;
191 }
192};
193
194/// This node is a base class providing common functionality for
195/// n'ary operators.
196class SCEVNAryExpr : public SCEV {
197protected:
198 // Since SCEVs are immutable, ScalarEvolution allocates operand
199 // arrays with its SCEVAllocator, so this class just needs a simple
200 // pointer rather than a more elaborate vector-like data structure.
201 // This also avoids the need for a non-trivial destructor.
202 const SCEV *const *Operands;
203 size_t NumOperands;
204
205 SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
206 const SCEV *const *O, size_t N)
207 : SCEV(ID, T, computeExpressionSize(Args: ArrayRef(O, N))), Operands(O),
208 NumOperands(N) {}
209
210public:
211 size_t getNumOperands() const { return NumOperands; }
212
213 const SCEV *getOperand(unsigned i) const {
214 assert(i < NumOperands && "Operand index out of range!");
215 return Operands[i];
216 }
217
218 ArrayRef<const SCEV *> operands() const {
219 return ArrayRef(Operands, NumOperands);
220 }
221
222 NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
223 return (NoWrapFlags)(SubclassData & Mask);
224 }
225
226 bool hasNoUnsignedWrap() const {
227 return getNoWrapFlags(Mask: FlagNUW) != FlagAnyWrap;
228 }
229
230 bool hasNoSignedWrap() const {
231 return getNoWrapFlags(Mask: FlagNSW) != FlagAnyWrap;
232 }
233
234 bool hasNoSelfWrap() const { return getNoWrapFlags(Mask: FlagNW) != FlagAnyWrap; }
235
236 /// Methods for support type inquiry through isa, cast, and dyn_cast:
237 static bool classof(const SCEV *S) {
238 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
239 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
240 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
241 S->getSCEVType() == scSequentialUMinExpr ||
242 S->getSCEVType() == scAddRecExpr;
243 }
244};
245
246/// This node is the base class for n'ary commutative operators.
247class SCEVCommutativeExpr : public SCEVNAryExpr {
248protected:
249 SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
250 const SCEV *const *O, size_t N)
251 : SCEVNAryExpr(ID, T, O, N) {}
252
253public:
254 /// Methods for support type inquiry through isa, cast, and dyn_cast:
255 static bool classof(const SCEV *S) {
256 return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
257 S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
258 S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
259 }
260
261 /// Set flags for a non-recurrence without clearing previously set flags.
262 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
263};
264
265/// This node represents an addition of some number of SCEVs.
266class SCEVAddExpr : public SCEVCommutativeExpr {
267 friend class ScalarEvolution;
268
269 Type *Ty;
270
271 SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
272 : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
273 auto *FirstPointerTypedOp = find_if(Range: operands(), P: [](const SCEV *Op) {
274 return Op->getType()->isPointerTy();
275 });
276 if (FirstPointerTypedOp != operands().end())
277 Ty = (*FirstPointerTypedOp)->getType();
278 else
279 Ty = getOperand(i: 0)->getType();
280 }
281
282public:
283 Type *getType() const { return Ty; }
284
285 /// Methods for support type inquiry through isa, cast, and dyn_cast:
286 static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
287};
288
289/// This node represents multiplication of some number of SCEVs.
290class SCEVMulExpr : public SCEVCommutativeExpr {
291 friend class ScalarEvolution;
292
293 SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
294 : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
295
296public:
297 Type *getType() const { return getOperand(i: 0)->getType(); }
298
299 /// Methods for support type inquiry through isa, cast, and dyn_cast:
300 static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
301};
302
303/// This class represents a binary unsigned division operation.
304class SCEVUDivExpr : public SCEV {
305 friend class ScalarEvolution;
306
307 std::array<const SCEV *, 2> Operands;
308
309 SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
310 : SCEV(ID, scUDivExpr, computeExpressionSize(Args: {lhs, rhs})) {
311 Operands[0] = lhs;
312 Operands[1] = rhs;
313 }
314
315public:
316 const SCEV *getLHS() const { return Operands[0]; }
317 const SCEV *getRHS() const { return Operands[1]; }
318 size_t getNumOperands() const { return 2; }
319 const SCEV *getOperand(unsigned i) const {
320 assert((i == 0 || i == 1) && "Operand index out of range!");
321 return i == 0 ? getLHS() : getRHS();
322 }
323
324 ArrayRef<const SCEV *> operands() const { return Operands; }
325
326 Type *getType() const {
327 // In most cases the types of LHS and RHS will be the same, but in some
328 // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
329 // depend on the type for correctness, but handling types carefully can
330 // avoid extra casts in the SCEVExpander. The LHS is more likely to be
331 // a pointer type than the RHS, so use the RHS' type here.
332 return getRHS()->getType();
333 }
334
335 /// Methods for support type inquiry through isa, cast, and dyn_cast:
336 static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
337};
338
339/// This node represents a polynomial recurrence on the trip count
340/// of the specified loop. This is the primary focus of the
341/// ScalarEvolution framework; all the other SCEV subclasses are
342/// mostly just supporting infrastructure to allow SCEVAddRecExpr
343/// expressions to be created and analyzed.
344///
345/// All operands of an AddRec are required to be loop invariant.
346///
347class SCEVAddRecExpr : public SCEVNAryExpr {
348 friend class ScalarEvolution;
349
350 const Loop *L;
351
352 SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
353 const Loop *l)
354 : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
355
356public:
357 Type *getType() const { return getStart()->getType(); }
358 const SCEV *getStart() const { return Operands[0]; }
359 const Loop *getLoop() const { return L; }
360
361 /// Constructs and returns the recurrence indicating how much this
362 /// expression steps by. If this is a polynomial of degree N, it
363 /// returns a chrec of degree N-1. We cannot determine whether
364 /// the step recurrence has self-wraparound.
365 const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
366 if (isAffine())
367 return getOperand(i: 1);
368 return SE.getAddRecExpr(
369 Operands: SmallVector<const SCEV *, 3>(operands().drop_front()), L: getLoop(),
370 Flags: FlagAnyWrap);
371 }
372
373 /// Return true if this represents an expression A + B*x where A
374 /// and B are loop invariant values.
375 bool isAffine() const {
376 // We know that the start value is invariant. This expression is thus
377 // affine iff the step is also invariant.
378 return getNumOperands() == 2;
379 }
380
381 /// Return true if this represents an expression A + B*x + C*x^2
382 /// where A, B and C are loop invariant values. This corresponds
383 /// to an addrec of the form {L,+,M,+,N}
384 bool isQuadratic() const { return getNumOperands() == 3; }
385
386 /// Set flags for a recurrence without clearing any previously set flags.
387 /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
388 /// to make it easier to propagate flags.
389 void setNoWrapFlags(NoWrapFlags Flags) {
390 if (Flags & (FlagNUW | FlagNSW))
391 Flags = ScalarEvolution::setFlags(Flags, OnFlags: FlagNW);
392 SubclassData |= Flags;
393 }
394
395 /// Return the value of this chain of recurrences at the specified
396 /// iteration number.
397 const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
398
399 /// Return the value of this chain of recurrences at the specified iteration
400 /// number. Takes an explicit list of operands to represent an AddRec.
401 static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
402 const SCEV *It, ScalarEvolution &SE);
403
404 /// Return the number of iterations of this loop that produce
405 /// values in the specified constant range. Another way of
406 /// looking at this is that it returns the first iteration number
407 /// where the value is not in the condition, thus computing the
408 /// exit count. If the iteration count can't be computed, an
409 /// instance of SCEVCouldNotCompute is returned.
410 const SCEV *getNumIterationsInRange(const ConstantRange &Range,
411 ScalarEvolution &SE) const;
412
413 /// Return an expression representing the value of this expression
414 /// one iteration of the loop ahead.
415 const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
416
417 /// Methods for support type inquiry through isa, cast, and dyn_cast:
418 static bool classof(const SCEV *S) {
419 return S->getSCEVType() == scAddRecExpr;
420 }
421};
422
423/// This node is the base class min/max selections.
424class SCEVMinMaxExpr : public SCEVCommutativeExpr {
425 friend class ScalarEvolution;
426
427 static bool isMinMaxType(enum SCEVTypes T) {
428 return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
429 T == scUMinExpr;
430 }
431
432protected:
433 /// Note: Constructing subclasses via this constructor is allowed
434 SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
435 const SCEV *const *O, size_t N)
436 : SCEVCommutativeExpr(ID, T, O, N) {
437 assert(isMinMaxType(T));
438 // Min and max never overflow
439 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
440 }
441
442public:
443 Type *getType() const { return getOperand(i: 0)->getType(); }
444
445 static bool classof(const SCEV *S) { return isMinMaxType(T: S->getSCEVType()); }
446
447 static enum SCEVTypes negate(enum SCEVTypes T) {
448 switch (T) {
449 case scSMaxExpr:
450 return scSMinExpr;
451 case scSMinExpr:
452 return scSMaxExpr;
453 case scUMaxExpr:
454 return scUMinExpr;
455 case scUMinExpr:
456 return scUMaxExpr;
457 default:
458 llvm_unreachable("Not a min or max SCEV type!");
459 }
460 }
461};
462
463/// This class represents a signed maximum selection.
464class SCEVSMaxExpr : public SCEVMinMaxExpr {
465 friend class ScalarEvolution;
466
467 SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
468 : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
469
470public:
471 /// Methods for support type inquiry through isa, cast, and dyn_cast:
472 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
473};
474
475/// This class represents an unsigned maximum selection.
476class SCEVUMaxExpr : public SCEVMinMaxExpr {
477 friend class ScalarEvolution;
478
479 SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
480 : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
481
482public:
483 /// Methods for support type inquiry through isa, cast, and dyn_cast:
484 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
485};
486
487/// This class represents a signed minimum selection.
488class SCEVSMinExpr : public SCEVMinMaxExpr {
489 friend class ScalarEvolution;
490
491 SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
492 : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
493
494public:
495 /// Methods for support type inquiry through isa, cast, and dyn_cast:
496 static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
497};
498
499/// This class represents an unsigned minimum selection.
500class SCEVUMinExpr : public SCEVMinMaxExpr {
501 friend class ScalarEvolution;
502
503 SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
504 : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
505
506public:
507 /// Methods for support type inquiry through isa, cast, and dyn_cast:
508 static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
509};
510
511/// This node is the base class for sequential/in-order min/max selections.
512/// Note that their fundamental difference from SCEVMinMaxExpr's is that they
513/// are early-returning upon reaching saturation point.
514/// I.e. given `0 umin_seq poison`, the result will be `0`,
515/// while the result of `0 umin poison` is `poison`.
516class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
517 friend class ScalarEvolution;
518
519 static bool isSequentialMinMaxType(enum SCEVTypes T) {
520 return T == scSequentialUMinExpr;
521 }
522
523 /// Set flags for a non-recurrence without clearing previously set flags.
524 void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
525
526protected:
527 /// Note: Constructing subclasses via this constructor is allowed
528 SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
529 const SCEV *const *O, size_t N)
530 : SCEVNAryExpr(ID, T, O, N) {
531 assert(isSequentialMinMaxType(T));
532 // Min and max never overflow
533 setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
534 }
535
536public:
537 Type *getType() const { return getOperand(i: 0)->getType(); }
538
539 static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
540 assert(isSequentialMinMaxType(Ty));
541 switch (Ty) {
542 case scSequentialUMinExpr:
543 return scUMinExpr;
544 default:
545 llvm_unreachable("Not a sequential min/max type.");
546 }
547 }
548
549 SCEVTypes getEquivalentNonSequentialSCEVType() const {
550 return getEquivalentNonSequentialSCEVType(Ty: getSCEVType());
551 }
552
553 static bool classof(const SCEV *S) {
554 return isSequentialMinMaxType(T: S->getSCEVType());
555 }
556};
557
558/// This class represents a sequential/in-order unsigned minimum selection.
559class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
560 friend class ScalarEvolution;
561
562 SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
563 size_t N)
564 : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
565
566public:
567 /// Methods for support type inquiry through isa, cast, and dyn_cast:
568 static bool classof(const SCEV *S) {
569 return S->getSCEVType() == scSequentialUMinExpr;
570 }
571};
572
573/// This means that we are dealing with an entirely unknown SCEV
574/// value, and only represent it as its LLVM Value. This is the
575/// "bottom" value for the analysis.
576class SCEVUnknown final : public SCEV, private CallbackVH {
577 friend class ScalarEvolution;
578
579 /// The parent ScalarEvolution value. This is used to update the
580 /// parent's maps when the value associated with a SCEVUnknown is
581 /// deleted or RAUW'd.
582 ScalarEvolution *SE;
583
584 /// The next pointer in the linked list of all SCEVUnknown
585 /// instances owned by a ScalarEvolution.
586 SCEVUnknown *Next;
587
588 SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
589 SCEVUnknown *next)
590 : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
591
592 // Implement CallbackVH.
593 void deleted() override;
594 void allUsesReplacedWith(Value *New) override;
595
596public:
597 Value *getValue() const { return getValPtr(); }
598
599 Type *getType() const { return getValPtr()->getType(); }
600
601 /// Methods for support type inquiry through isa, cast, and dyn_cast:
602 static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
603};
604
605/// This class defines a simple visitor class that may be used for
606/// various SCEV analysis purposes.
607template <typename SC, typename RetVal = void> struct SCEVVisitor {
608 RetVal visit(const SCEV *S) {
609 switch (S->getSCEVType()) {
610 case scConstant:
611 return ((SC *)this)->visitConstant((const SCEVConstant *)S);
612 case scVScale:
613 return ((SC *)this)->visitVScale((const SCEVVScale *)S);
614 case scPtrToInt:
615 return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
616 case scTruncate:
617 return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
618 case scZeroExtend:
619 return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
620 case scSignExtend:
621 return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
622 case scAddExpr:
623 return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
624 case scMulExpr:
625 return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
626 case scUDivExpr:
627 return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
628 case scAddRecExpr:
629 return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
630 case scSMaxExpr:
631 return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
632 case scUMaxExpr:
633 return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
634 case scSMinExpr:
635 return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
636 case scUMinExpr:
637 return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
638 case scSequentialUMinExpr:
639 return ((SC *)this)
640 ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
641 case scUnknown:
642 return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
643 case scCouldNotCompute:
644 return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
645 }
646 llvm_unreachable("Unknown SCEV kind!");
647 }
648
649 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
650 llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
651 }
652};
653
654/// Visit all nodes in the expression tree using worklist traversal.
655///
656/// Visitor implements:
657/// // return true to follow this node.
658/// bool follow(const SCEV *S);
659/// // return true to terminate the search.
660/// bool isDone();
661template <typename SV> class SCEVTraversal {
662 SV &Visitor;
663 SmallVector<const SCEV *, 8> Worklist;
664 SmallPtrSet<const SCEV *, 8> Visited;
665
666 void push(const SCEV *S) {
667 if (Visited.insert(Ptr: S).second && Visitor.follow(S))
668 Worklist.push_back(Elt: S);
669 }
670
671public:
672 SCEVTraversal(SV &V) : Visitor(V) {}
673
674 void visitAll(const SCEV *Root) {
675 push(S: Root);
676 while (!Worklist.empty() && !Visitor.isDone()) {
677 const SCEV *S = Worklist.pop_back_val();
678
679 switch (S->getSCEVType()) {
680 case scConstant:
681 case scVScale:
682 case scUnknown:
683 continue;
684 case scPtrToInt:
685 case scTruncate:
686 case scZeroExtend:
687 case scSignExtend:
688 case scAddExpr:
689 case scMulExpr:
690 case scUDivExpr:
691 case scSMaxExpr:
692 case scUMaxExpr:
693 case scSMinExpr:
694 case scUMinExpr:
695 case scSequentialUMinExpr:
696 case scAddRecExpr:
697 for (const auto *Op : S->operands()) {
698 push(S: Op);
699 if (Visitor.isDone())
700 break;
701 }
702 continue;
703 case scCouldNotCompute:
704 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
705 }
706 llvm_unreachable("Unknown SCEV kind!");
707 }
708 }
709};
710
711/// Use SCEVTraversal to visit all nodes in the given expression tree.
712template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
713 SCEVTraversal<SV> T(Visitor);
714 T.visitAll(Root);
715}
716
717/// Return true if any node in \p Root satisfies the predicate \p Pred.
718template <typename PredTy>
719bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
720 struct FindClosure {
721 bool Found = false;
722 PredTy Pred;
723
724 FindClosure(PredTy Pred) : Pred(Pred) {}
725
726 bool follow(const SCEV *S) {
727 if (!Pred(S))
728 return true;
729
730 Found = true;
731 return false;
732 }
733
734 bool isDone() const { return Found; }
735 };
736
737 FindClosure FC(Pred);
738 visitAll(Root, FC);
739 return FC.Found;
740}
741
742/// This visitor recursively visits a SCEV expression and re-writes it.
743/// The result from each visit is cached, so it will return the same
744/// SCEV for the same input.
745template <typename SC>
746class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
747protected:
748 ScalarEvolution &SE;
749 // Memoize the result of each visit so that we only compute once for
750 // the same input SCEV. This is to avoid redundant computations when
751 // a SCEV is referenced by multiple SCEVs. Without memoization, this
752 // visit algorithm would have exponential time complexity in the worst
753 // case, causing the compiler to hang on certain tests.
754 SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
755
756public:
757 SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
758
759 const SCEV *visit(const SCEV *S) {
760 auto It = RewriteResults.find(Val: S);
761 if (It != RewriteResults.end())
762 return It->second;
763 auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
764 auto Result = RewriteResults.try_emplace(S, Visited);
765 assert(Result.second && "Should insert a new entry");
766 return Result.first->second;
767 }
768
769 const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
770
771 const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
772
773 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
774 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
775 return Operand == Expr->getOperand()
776 ? Expr
777 : SE.getPtrToIntExpr(Op: Operand, Ty: Expr->getType());
778 }
779
780 const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
781 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
782 return Operand == Expr->getOperand()
783 ? Expr
784 : SE.getTruncateExpr(Op: Operand, Ty: Expr->getType());
785 }
786
787 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
788 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
789 return Operand == Expr->getOperand()
790 ? Expr
791 : SE.getZeroExtendExpr(Op: Operand, Ty: Expr->getType());
792 }
793
794 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
795 const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
796 return Operand == Expr->getOperand()
797 ? Expr
798 : SE.getSignExtendExpr(Op: Operand, Ty: Expr->getType());
799 }
800
801 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
802 SmallVector<const SCEV *, 2> Operands;
803 bool Changed = false;
804 for (const auto *Op : Expr->operands()) {
805 Operands.push_back(Elt: ((SC *)this)->visit(Op));
806 Changed |= Op != Operands.back();
807 }
808 return !Changed ? Expr : SE.getAddExpr(Ops&: Operands);
809 }
810
811 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
812 SmallVector<const SCEV *, 2> Operands;
813 bool Changed = false;
814 for (const auto *Op : Expr->operands()) {
815 Operands.push_back(Elt: ((SC *)this)->visit(Op));
816 Changed |= Op != Operands.back();
817 }
818 return !Changed ? Expr : SE.getMulExpr(Ops&: Operands);
819 }
820
821 const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
822 auto *LHS = ((SC *)this)->visit(Expr->getLHS());
823 auto *RHS = ((SC *)this)->visit(Expr->getRHS());
824 bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
825 return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
826 }
827
828 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
829 SmallVector<const SCEV *, 2> Operands;
830 bool Changed = false;
831 for (const auto *Op : Expr->operands()) {
832 Operands.push_back(Elt: ((SC *)this)->visit(Op));
833 Changed |= Op != Operands.back();
834 }
835 return !Changed ? Expr
836 : SE.getAddRecExpr(Operands, L: Expr->getLoop(),
837 Flags: Expr->getNoWrapFlags());
838 }
839
840 const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
841 SmallVector<const SCEV *, 2> Operands;
842 bool Changed = false;
843 for (const auto *Op : Expr->operands()) {
844 Operands.push_back(Elt: ((SC *)this)->visit(Op));
845 Changed |= Op != Operands.back();
846 }
847 return !Changed ? Expr : SE.getSMaxExpr(Operands);
848 }
849
850 const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
851 SmallVector<const SCEV *, 2> Operands;
852 bool Changed = false;
853 for (const auto *Op : Expr->operands()) {
854 Operands.push_back(Elt: ((SC *)this)->visit(Op));
855 Changed |= Op != Operands.back();
856 }
857 return !Changed ? Expr : SE.getUMaxExpr(Operands);
858 }
859
860 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
861 SmallVector<const SCEV *, 2> Operands;
862 bool Changed = false;
863 for (const auto *Op : Expr->operands()) {
864 Operands.push_back(Elt: ((SC *)this)->visit(Op));
865 Changed |= Op != Operands.back();
866 }
867 return !Changed ? Expr : SE.getSMinExpr(Operands);
868 }
869
870 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
871 SmallVector<const SCEV *, 2> Operands;
872 bool Changed = false;
873 for (const auto *Op : Expr->operands()) {
874 Operands.push_back(Elt: ((SC *)this)->visit(Op));
875 Changed |= Op != Operands.back();
876 }
877 return !Changed ? Expr : SE.getUMinExpr(Operands);
878 }
879
880 const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
881 SmallVector<const SCEV *, 2> Operands;
882 bool Changed = false;
883 for (const auto *Op : Expr->operands()) {
884 Operands.push_back(Elt: ((SC *)this)->visit(Op));
885 Changed |= Op != Operands.back();
886 }
887 return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/Sequential: true);
888 }
889
890 const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
891
892 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
893 return Expr;
894 }
895};
896
897using ValueToValueMap = DenseMap<const Value *, Value *>;
898using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
899
900/// The SCEVParameterRewriter takes a scalar evolution expression and updates
901/// the SCEVUnknown components following the Map (Value -> SCEV).
902class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
903public:
904 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
905 ValueToSCEVMapTy &Map) {
906 SCEVParameterRewriter Rewriter(SE, Map);
907 return Rewriter.visit(S: Scev);
908 }
909
910 SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
911 : SCEVRewriteVisitor(SE), Map(M) {}
912
913 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
914 auto I = Map.find(Val: Expr->getValue());
915 if (I == Map.end())
916 return Expr;
917 return I->second;
918 }
919
920private:
921 ValueToSCEVMapTy &Map;
922};
923
924using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
925
926/// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
927/// the Map (Loop -> SCEV) to all AddRecExprs.
928class SCEVLoopAddRecRewriter
929 : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
930public:
931 SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
932 : SCEVRewriteVisitor(SE), Map(M) {}
933
934 static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
935 ScalarEvolution &SE) {
936 SCEVLoopAddRecRewriter Rewriter(SE, Map);
937 return Rewriter.visit(S: Scev);
938 }
939
940 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
941 SmallVector<const SCEV *, 2> Operands;
942 for (const SCEV *Op : Expr->operands())
943 Operands.push_back(Elt: visit(S: Op));
944
945 const Loop *L = Expr->getLoop();
946 if (0 == Map.count(Val: L))
947 return SE.getAddRecExpr(Operands, L, Flags: Expr->getNoWrapFlags());
948
949 return SCEVAddRecExpr::evaluateAtIteration(Operands, It: Map[L], SE);
950 }
951
952private:
953 LoopToScevMapT &Map;
954};
955
956} // end namespace llvm
957
958#endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
959

// Source code of llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h