1 | //===-- SPIRVStripConvergentIntrinsics.cpp ----------------------*- C++ -*-===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass trims convergence intrinsics as those were only useful when |
10 | // modifying the CFG during IR passes. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "SPIRV.h" |
15 | #include "SPIRVSubtarget.h" |
16 | #include "SPIRVTargetMachine.h" |
17 | #include "SPIRVUtils.h" |
18 | #include "llvm/CodeGen/IntrinsicLowering.h" |
19 | #include "llvm/IR/IRBuilder.h" |
20 | #include "llvm/IR/IntrinsicInst.h" |
21 | #include "llvm/IR/Intrinsics.h" |
22 | #include "llvm/IR/IntrinsicsSPIRV.h" |
23 | #include "llvm/Transforms/Utils/Cloning.h" |
24 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
25 | |
26 | using namespace llvm; |
27 | |
28 | namespace llvm { |
29 | void initializeSPIRVStripConvergentIntrinsicsPass(PassRegistry &); |
30 | } |
31 | |
32 | class SPIRVStripConvergentIntrinsics : public FunctionPass { |
33 | public: |
34 | static char ID; |
35 | |
36 | SPIRVStripConvergentIntrinsics() : FunctionPass(ID) { |
37 | initializeSPIRVStripConvergentIntrinsicsPass( |
38 | *PassRegistry::getPassRegistry()); |
39 | }; |
40 | |
41 | virtual bool runOnFunction(Function &F) override { |
42 | DenseSet<Instruction *> ToRemove; |
43 | |
44 | for (BasicBlock &BB : F) { |
45 | for (Instruction &I : BB) { |
46 | if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) { |
47 | if (II->getIntrinsicID() != |
48 | Intrinsic::experimental_convergence_entry && |
49 | II->getIntrinsicID() != |
50 | Intrinsic::experimental_convergence_loop && |
51 | II->getIntrinsicID() != |
52 | Intrinsic::experimental_convergence_anchor) { |
53 | continue; |
54 | } |
55 | |
56 | II->replaceAllUsesWith(V: UndefValue::get(T: II->getType())); |
57 | ToRemove.insert(V: II); |
58 | } else if (auto *CI = dyn_cast<CallInst>(Val: &I)) { |
59 | auto OB = CI->getOperandBundle(ID: LLVMContext::OB_convergencectrl); |
60 | if (!OB.has_value()) |
61 | continue; |
62 | |
63 | auto *NewCall = CallBase::removeOperandBundle( |
64 | CB: CI, ID: LLVMContext::OB_convergencectrl, InsertPt: CI); |
65 | NewCall->copyMetadata(SrcInst: *CI); |
66 | CI->replaceAllUsesWith(V: NewCall); |
67 | ToRemove.insert(V: CI); |
68 | } |
69 | } |
70 | } |
71 | |
72 | // All usages must be removed before their definition is removed. |
73 | for (Instruction *I : ToRemove) |
74 | I->eraseFromParent(); |
75 | |
76 | return ToRemove.size() != 0; |
77 | } |
78 | }; |
79 | |
80 | char SPIRVStripConvergentIntrinsics::ID = 0; |
81 | INITIALIZE_PASS(SPIRVStripConvergentIntrinsics, "strip-convergent-intrinsics", |
82 | "SPIRV strip convergent intrinsics", false, false) |
83 | |
84 | FunctionPass *llvm::createSPIRVStripConvergenceIntrinsicsPass() { |
85 | return new SPIRVStripConvergentIntrinsics(); |
86 | } |
87 |