//===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
///
/// The header file for the GVN pass that contains expression handling
/// classes.
//
//===----------------------------------------------------------------------===//
15
16#ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
17#define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18
19#include "llvm/ADT/Hashing.h"
20#include "llvm/ADT/iterator_range.h"
21#include "llvm/Analysis/MemorySSA.h"
22#include "llvm/IR/Constant.h"
23#include "llvm/IR/Instructions.h"
24#include "llvm/IR/Value.h"
25#include "llvm/Support/Allocator.h"
26#include "llvm/Support/ArrayRecycler.h"
27#include "llvm/Support/Casting.h"
28#include "llvm/Support/Compiler.h"
29#include "llvm/Support/raw_ostream.h"
30#include <algorithm>
31#include <cassert>
32#include <iterator>
33#include <utility>
34
35namespace llvm {
36
37class BasicBlock;
38class Type;
39
40namespace GVNExpression {
41
/// Discriminator stored in every Expression and tested by the classof()
/// functions of the expression classes.
///
/// The *Start / *End enumerators are range markers rather than real kinds:
/// BasicExpression::classof accepts values strictly between ET_BasicStart and
/// ET_BasicEnd, and MemoryExpression::classof accepts values strictly between
/// ET_MemoryStart and ET_MemoryEnd. The declaration order is therefore
/// significant; new kinds must be inserted inside the matching range.
enum ExpressionType {
  ET_Base,
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,
  ET_BasicStart, // Marker: kinds after this are BasicExpression subclasses.
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,
  ET_MemoryStart, // Marker: kinds after this are MemoryExpression subclasses.
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd, // Marker: end of the MemoryExpression range.
  ET_BasicEnd   // Marker: end of the BasicExpression range.
};
59
60class Expression {
61private:
62 ExpressionType EType;
63 unsigned Opcode;
64 mutable hash_code HashVal = 0;
65
66public:
67 Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
68 : EType(ET), Opcode(O) {}
69 Expression(const Expression &) = delete;
70 Expression &operator=(const Expression &) = delete;
71 virtual ~Expression();
72
73 static unsigned getEmptyKey() { return ~0U; }
74 static unsigned getTombstoneKey() { return ~1U; }
75
76 bool operator!=(const Expression &Other) const { return !(*this == Other); }
77 bool operator==(const Expression &Other) const {
78 if (getOpcode() != Other.getOpcode())
79 return false;
80 if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
81 return true;
82 // Compare the expression type for anything but load and store.
83 // For load and store we set the opcode to zero to make them equal.
84 if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
85 getExpressionType() != Other.getExpressionType())
86 return false;
87
88 return equals(Other);
89 }
90
91 hash_code getComputedHash() const {
92 // It's theoretically possible for a thing to hash to zero. In that case,
93 // we will just compute the hash a few extra times, which is no worse that
94 // we did before, which was to compute it always.
95 if (static_cast<unsigned>(HashVal) == 0)
96 HashVal = getHashValue();
97 return HashVal;
98 }
99
100 virtual bool equals(const Expression &Other) const { return true; }
101
102 // Return true if the two expressions are exactly the same, including the
103 // normally ignored fields.
104 virtual bool exactlyEquals(const Expression &Other) const {
105 return getExpressionType() == Other.getExpressionType() && equals(Other);
106 }
107
108 unsigned getOpcode() const { return Opcode; }
109 void setOpcode(unsigned opcode) { Opcode = opcode; }
110 ExpressionType getExpressionType() const { return EType; }
111
112 // We deliberately leave the expression type out of the hash value.
113 virtual hash_code getHashValue() const { return getOpcode(); }
114
115 // Debugging support
116 virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
117 if (PrintEType)
118 OS << "etype = " << getExpressionType() << ",";
119 OS << "opcode = " << getOpcode() << ", ";
120 }
121
122 void print(raw_ostream &OS) const {
123 OS << "{ ";
124 printInternal(OS, PrintEType: true);
125 OS << "}";
126 }
127
128 LLVM_DUMP_METHOD void dump() const;
129};
130
131inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
132 E.print(OS);
133 return OS;
134}
135
136class BasicExpression : public Expression {
137private:
138 using RecyclerType = ArrayRecycler<Value *>;
139 using RecyclerCapacity = RecyclerType::Capacity;
140
141 Value **Operands = nullptr;
142 unsigned MaxOperands;
143 unsigned NumOperands = 0;
144 Type *ValueType = nullptr;
145
146public:
147 BasicExpression(unsigned NumOperands)
148 : BasicExpression(NumOperands, ET_Basic) {}
149 BasicExpression(unsigned NumOperands, ExpressionType ET)
150 : Expression(ET), MaxOperands(NumOperands) {}
151 BasicExpression() = delete;
152 BasicExpression(const BasicExpression &) = delete;
153 BasicExpression &operator=(const BasicExpression &) = delete;
154 ~BasicExpression() override;
155
156 static bool classof(const Expression *EB) {
157 ExpressionType ET = EB->getExpressionType();
158 return ET > ET_BasicStart && ET < ET_BasicEnd;
159 }
160
161 /// Swap two operands. Used during GVN to put commutative operands in
162 /// order.
163 void swapOperands(unsigned First, unsigned Second) {
164 std::swap(a&: Operands[First], b&: Operands[Second]);
165 }
166
167 Value *getOperand(unsigned N) const {
168 assert(Operands && "Operands not allocated");
169 assert(N < NumOperands && "Operand out of range");
170 return Operands[N];
171 }
172
173 void setOperand(unsigned N, Value *V) {
174 assert(Operands && "Operands not allocated before setting");
175 assert(N < NumOperands && "Operand out of range");
176 Operands[N] = V;
177 }
178
179 unsigned getNumOperands() const { return NumOperands; }
180
181 using op_iterator = Value **;
182 using const_op_iterator = Value *const *;
183
184 op_iterator op_begin() { return Operands; }
185 op_iterator op_end() { return Operands + NumOperands; }
186 const_op_iterator op_begin() const { return Operands; }
187 const_op_iterator op_end() const { return Operands + NumOperands; }
188 iterator_range<op_iterator> operands() {
189 return iterator_range<op_iterator>(op_begin(), op_end());
190 }
191 iterator_range<const_op_iterator> operands() const {
192 return iterator_range<const_op_iterator>(op_begin(), op_end());
193 }
194
195 void op_push_back(Value *Arg) {
196 assert(NumOperands < MaxOperands && "Tried to add too many operands");
197 assert(Operands && "Operandss not allocated before pushing");
198 Operands[NumOperands++] = Arg;
199 }
200 bool op_empty() const { return getNumOperands() == 0; }
201
202 void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
203 assert(!Operands && "Operands already allocated");
204 Operands = Recycler.allocate(Cap: RecyclerCapacity::get(N: MaxOperands), Allocator);
205 }
206 void deallocateOperands(RecyclerType &Recycler) {
207 Recycler.deallocate(Cap: RecyclerCapacity::get(N: MaxOperands), Ptr: Operands);
208 }
209
210 void setType(Type *T) { ValueType = T; }
211 Type *getType() const { return ValueType; }
212
213 bool equals(const Expression &Other) const override {
214 if (getOpcode() != Other.getOpcode())
215 return false;
216
217 const auto &OE = cast<BasicExpression>(Val: Other);
218 return getType() == OE.getType() && NumOperands == OE.NumOperands &&
219 std::equal(first1: op_begin(), last1: op_end(), first2: OE.op_begin());
220 }
221
222 hash_code getHashValue() const override {
223 return hash_combine(args: this->Expression::getHashValue(), args: ValueType,
224 args: hash_combine_range(first: op_begin(), last: op_end()));
225 }
226
227 // Debugging support
228 void printInternal(raw_ostream &OS, bool PrintEType) const override {
229 if (PrintEType)
230 OS << "ExpressionTypeBasic, ";
231
232 this->Expression::printInternal(OS, PrintEType: false);
233 OS << "operands = {";
234 for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
235 OS << "[" << i << "] = ";
236 Operands[i]->printAsOperand(O&: OS);
237 OS << " ";
238 }
239 OS << "} ";
240 }
241};
242
243class op_inserter {
244private:
245 using Container = BasicExpression;
246
247 Container *BE;
248
249public:
250 using iterator_category = std::output_iterator_tag;
251 using value_type = void;
252 using difference_type = void;
253 using pointer = void;
254 using reference = void;
255
256 explicit op_inserter(BasicExpression &E) : BE(&E) {}
257 explicit op_inserter(BasicExpression *E) : BE(E) {}
258
259 op_inserter &operator=(Value *val) {
260 BE->op_push_back(Arg: val);
261 return *this;
262 }
263 op_inserter &operator*() { return *this; }
264 op_inserter &operator++() { return *this; }
265 op_inserter &operator++(int) { return *this; }
266};
267
268class MemoryExpression : public BasicExpression {
269private:
270 const MemoryAccess *MemoryLeader;
271
272public:
273 MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
274 const MemoryAccess *MemoryLeader)
275 : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
276 MemoryExpression() = delete;
277 MemoryExpression(const MemoryExpression &) = delete;
278 MemoryExpression &operator=(const MemoryExpression &) = delete;
279
280 static bool classof(const Expression *EB) {
281 return EB->getExpressionType() > ET_MemoryStart &&
282 EB->getExpressionType() < ET_MemoryEnd;
283 }
284
285 hash_code getHashValue() const override {
286 return hash_combine(args: this->BasicExpression::getHashValue(), args: MemoryLeader);
287 }
288
289 bool equals(const Expression &Other) const override {
290 if (!this->BasicExpression::equals(Other))
291 return false;
292 const MemoryExpression &OtherMCE = cast<MemoryExpression>(Val: Other);
293
294 return MemoryLeader == OtherMCE.MemoryLeader;
295 }
296
297 const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
298 void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
299};
300
301class CallExpression final : public MemoryExpression {
302private:
303 CallInst *Call;
304
305public:
306 CallExpression(unsigned NumOperands, CallInst *C,
307 const MemoryAccess *MemoryLeader)
308 : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
309 CallExpression() = delete;
310 CallExpression(const CallExpression &) = delete;
311 CallExpression &operator=(const CallExpression &) = delete;
312 ~CallExpression() override;
313
314 static bool classof(const Expression *EB) {
315 return EB->getExpressionType() == ET_Call;
316 }
317
318 // Debugging support
319 void printInternal(raw_ostream &OS, bool PrintEType) const override {
320 if (PrintEType)
321 OS << "ExpressionTypeCall, ";
322 this->BasicExpression::printInternal(OS, PrintEType: false);
323 OS << " represents call at ";
324 Call->printAsOperand(O&: OS);
325 }
326};
327
328class LoadExpression final : public MemoryExpression {
329private:
330 LoadInst *Load;
331
332public:
333 LoadExpression(unsigned NumOperands, LoadInst *L,
334 const MemoryAccess *MemoryLeader)
335 : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
336
337 LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
338 const MemoryAccess *MemoryLeader)
339 : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {}
340
341 LoadExpression() = delete;
342 LoadExpression(const LoadExpression &) = delete;
343 LoadExpression &operator=(const LoadExpression &) = delete;
344 ~LoadExpression() override;
345
346 static bool classof(const Expression *EB) {
347 return EB->getExpressionType() == ET_Load;
348 }
349
350 LoadInst *getLoadInst() const { return Load; }
351 void setLoadInst(LoadInst *L) { Load = L; }
352
353 bool equals(const Expression &Other) const override;
354 bool exactlyEquals(const Expression &Other) const override {
355 return Expression::exactlyEquals(Other) &&
356 cast<LoadExpression>(Val: Other).getLoadInst() == getLoadInst();
357 }
358
359 // Debugging support
360 void printInternal(raw_ostream &OS, bool PrintEType) const override {
361 if (PrintEType)
362 OS << "ExpressionTypeLoad, ";
363 this->BasicExpression::printInternal(OS, PrintEType: false);
364 OS << " represents Load at ";
365 Load->printAsOperand(O&: OS);
366 OS << " with MemoryLeader " << *getMemoryLeader();
367 }
368};
369
370class StoreExpression final : public MemoryExpression {
371private:
372 StoreInst *Store;
373 Value *StoredValue;
374
375public:
376 StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
377 const MemoryAccess *MemoryLeader)
378 : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
379 StoredValue(StoredValue) {}
380 StoreExpression() = delete;
381 StoreExpression(const StoreExpression &) = delete;
382 StoreExpression &operator=(const StoreExpression &) = delete;
383 ~StoreExpression() override;
384
385 static bool classof(const Expression *EB) {
386 return EB->getExpressionType() == ET_Store;
387 }
388
389 StoreInst *getStoreInst() const { return Store; }
390 Value *getStoredValue() const { return StoredValue; }
391
392 bool equals(const Expression &Other) const override;
393
394 bool exactlyEquals(const Expression &Other) const override {
395 return Expression::exactlyEquals(Other) &&
396 cast<StoreExpression>(Val: Other).getStoreInst() == getStoreInst();
397 }
398
399 // Debugging support
400 void printInternal(raw_ostream &OS, bool PrintEType) const override {
401 if (PrintEType)
402 OS << "ExpressionTypeStore, ";
403 this->BasicExpression::printInternal(OS, PrintEType: false);
404 OS << " represents Store " << *Store;
405 OS << " with StoredValue ";
406 StoredValue->printAsOperand(O&: OS);
407 OS << " and MemoryLeader " << *getMemoryLeader();
408 }
409};
410
411class AggregateValueExpression final : public BasicExpression {
412private:
413 unsigned MaxIntOperands;
414 unsigned NumIntOperands = 0;
415 unsigned *IntOperands = nullptr;
416
417public:
418 AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
419 : BasicExpression(NumOperands, ET_AggregateValue),
420 MaxIntOperands(NumIntOperands) {}
421 AggregateValueExpression() = delete;
422 AggregateValueExpression(const AggregateValueExpression &) = delete;
423 AggregateValueExpression &
424 operator=(const AggregateValueExpression &) = delete;
425 ~AggregateValueExpression() override;
426
427 static bool classof(const Expression *EB) {
428 return EB->getExpressionType() == ET_AggregateValue;
429 }
430
431 using int_arg_iterator = unsigned *;
432 using const_int_arg_iterator = const unsigned *;
433
434 int_arg_iterator int_op_begin() { return IntOperands; }
435 int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
436 const_int_arg_iterator int_op_begin() const { return IntOperands; }
437 const_int_arg_iterator int_op_end() const {
438 return IntOperands + NumIntOperands;
439 }
440 unsigned int_op_size() const { return NumIntOperands; }
441 bool int_op_empty() const { return NumIntOperands == 0; }
442 void int_op_push_back(unsigned IntOperand) {
443 assert(NumIntOperands < MaxIntOperands &&
444 "Tried to add too many int operands");
445 assert(IntOperands && "Operands not allocated before pushing");
446 IntOperands[NumIntOperands++] = IntOperand;
447 }
448
449 virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
450 assert(!IntOperands && "Operands already allocated");
451 IntOperands = Allocator.Allocate<unsigned>(Num: MaxIntOperands);
452 }
453
454 bool equals(const Expression &Other) const override {
455 if (!this->BasicExpression::equals(Other))
456 return false;
457 const AggregateValueExpression &OE = cast<AggregateValueExpression>(Val: Other);
458 return NumIntOperands == OE.NumIntOperands &&
459 std::equal(first1: int_op_begin(), last1: int_op_end(), first2: OE.int_op_begin());
460 }
461
462 hash_code getHashValue() const override {
463 return hash_combine(args: this->BasicExpression::getHashValue(),
464 args: hash_combine_range(first: int_op_begin(), last: int_op_end()));
465 }
466
467 // Debugging support
468 void printInternal(raw_ostream &OS, bool PrintEType) const override {
469 if (PrintEType)
470 OS << "ExpressionTypeAggregateValue, ";
471 this->BasicExpression::printInternal(OS, PrintEType: false);
472 OS << ", intoperands = {";
473 for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
474 OS << "[" << i << "] = " << IntOperands[i] << " ";
475 }
476 OS << "}";
477 }
478};
479
480class int_op_inserter {
481private:
482 using Container = AggregateValueExpression;
483
484 Container *AVE;
485
486public:
487 using iterator_category = std::output_iterator_tag;
488 using value_type = void;
489 using difference_type = void;
490 using pointer = void;
491 using reference = void;
492
493 explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
494 explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
495
496 int_op_inserter &operator=(unsigned int val) {
497 AVE->int_op_push_back(IntOperand: val);
498 return *this;
499 }
500 int_op_inserter &operator*() { return *this; }
501 int_op_inserter &operator++() { return *this; }
502 int_op_inserter &operator++(int) { return *this; }
503};
504
505class PHIExpression final : public BasicExpression {
506private:
507 BasicBlock *BB;
508
509public:
510 PHIExpression(unsigned NumOperands, BasicBlock *B)
511 : BasicExpression(NumOperands, ET_Phi), BB(B) {}
512 PHIExpression() = delete;
513 PHIExpression(const PHIExpression &) = delete;
514 PHIExpression &operator=(const PHIExpression &) = delete;
515 ~PHIExpression() override;
516
517 static bool classof(const Expression *EB) {
518 return EB->getExpressionType() == ET_Phi;
519 }
520
521 bool equals(const Expression &Other) const override {
522 if (!this->BasicExpression::equals(Other))
523 return false;
524 const PHIExpression &OE = cast<PHIExpression>(Val: Other);
525 return BB == OE.BB;
526 }
527
528 hash_code getHashValue() const override {
529 return hash_combine(args: this->BasicExpression::getHashValue(), args: BB);
530 }
531
532 // Debugging support
533 void printInternal(raw_ostream &OS, bool PrintEType) const override {
534 if (PrintEType)
535 OS << "ExpressionTypePhi, ";
536 this->BasicExpression::printInternal(OS, PrintEType: false);
537 OS << "bb = " << BB;
538 }
539};
540
541class DeadExpression final : public Expression {
542public:
543 DeadExpression() : Expression(ET_Dead) {}
544 DeadExpression(const DeadExpression &) = delete;
545 DeadExpression &operator=(const DeadExpression &) = delete;
546
547 static bool classof(const Expression *E) {
548 return E->getExpressionType() == ET_Dead;
549 }
550};
551
552class VariableExpression final : public Expression {
553private:
554 Value *VariableValue;
555
556public:
557 VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
558 VariableExpression() = delete;
559 VariableExpression(const VariableExpression &) = delete;
560 VariableExpression &operator=(const VariableExpression &) = delete;
561
562 static bool classof(const Expression *EB) {
563 return EB->getExpressionType() == ET_Variable;
564 }
565
566 Value *getVariableValue() const { return VariableValue; }
567 void setVariableValue(Value *V) { VariableValue = V; }
568
569 bool equals(const Expression &Other) const override {
570 const VariableExpression &OC = cast<VariableExpression>(Val: Other);
571 return VariableValue == OC.VariableValue;
572 }
573
574 hash_code getHashValue() const override {
575 return hash_combine(args: this->Expression::getHashValue(),
576 args: VariableValue->getType(), args: VariableValue);
577 }
578
579 // Debugging support
580 void printInternal(raw_ostream &OS, bool PrintEType) const override {
581 if (PrintEType)
582 OS << "ExpressionTypeVariable, ";
583 this->Expression::printInternal(OS, PrintEType: false);
584 OS << " variable = " << *VariableValue;
585 }
586};
587
588class ConstantExpression final : public Expression {
589private:
590 Constant *ConstantValue = nullptr;
591
592public:
593 ConstantExpression() : Expression(ET_Constant) {}
594 ConstantExpression(Constant *constantValue)
595 : Expression(ET_Constant), ConstantValue(constantValue) {}
596 ConstantExpression(const ConstantExpression &) = delete;
597 ConstantExpression &operator=(const ConstantExpression &) = delete;
598
599 static bool classof(const Expression *EB) {
600 return EB->getExpressionType() == ET_Constant;
601 }
602
603 Constant *getConstantValue() const { return ConstantValue; }
604 void setConstantValue(Constant *V) { ConstantValue = V; }
605
606 bool equals(const Expression &Other) const override {
607 const ConstantExpression &OC = cast<ConstantExpression>(Val: Other);
608 return ConstantValue == OC.ConstantValue;
609 }
610
611 hash_code getHashValue() const override {
612 return hash_combine(args: this->Expression::getHashValue(),
613 args: ConstantValue->getType(), args: ConstantValue);
614 }
615
616 // Debugging support
617 void printInternal(raw_ostream &OS, bool PrintEType) const override {
618 if (PrintEType)
619 OS << "ExpressionTypeConstant, ";
620 this->Expression::printInternal(OS, PrintEType: false);
621 OS << " constant = " << *ConstantValue;
622 }
623};
624
625class UnknownExpression final : public Expression {
626private:
627 Instruction *Inst;
628
629public:
630 UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
631 UnknownExpression() = delete;
632 UnknownExpression(const UnknownExpression &) = delete;
633 UnknownExpression &operator=(const UnknownExpression &) = delete;
634
635 static bool classof(const Expression *EB) {
636 return EB->getExpressionType() == ET_Unknown;
637 }
638
639 Instruction *getInstruction() const { return Inst; }
640 void setInstruction(Instruction *I) { Inst = I; }
641
642 bool equals(const Expression &Other) const override {
643 const auto &OU = cast<UnknownExpression>(Val: Other);
644 return Inst == OU.Inst;
645 }
646
647 hash_code getHashValue() const override {
648 return hash_combine(args: this->Expression::getHashValue(), args: Inst);
649 }
650
651 // Debugging support
652 void printInternal(raw_ostream &OS, bool PrintEType) const override {
653 if (PrintEType)
654 OS << "ExpressionTypeUnknown, ";
655 this->Expression::printInternal(OS, PrintEType: false);
656 OS << " inst = " << *Inst;
657 }
658};
659
660} // end namespace GVNExpression
661
662} // end namespace llvm
663
664#endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
665

// Source: llvm/include/llvm/Transforms/Scalar/GVNExpression.h