//===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
//
// Currently it converts
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
//===---------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVISelDAGToDAG.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-fold-masks"

namespace {

class RISCVFoldMasks : public MachineFunctionPass {
public:
  static char ID;
  const TargetInstrInfo *TII;
  MachineRegisterInfo *MRI;
  const TargetRegisterInfo *TRI;
  RISCVFoldMasks() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::IsSSA);
  }

  StringRef getPassName() const override { return "RISC-V Fold Masks"; }

private:
  bool convertToUnmasked(MachineInstr &MI) const;
  bool convertVMergeToVMv(MachineInstr &MI) const;

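  /// Returns true if the mask copied to V0 by \p MaskDef is known to be all
  /// ones, i.e. it was produced by one of the PseudoVMSET_M_B* pseudos.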
  bool isAllOnesMask(const MachineInstr *MaskDef) const;

  /// Maps uses of V0 to the corresponding def of V0.
  DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
};

} // namespace

char RISCVFoldMasks::ID = 0;

INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false)

bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
  assert(MaskDef && MaskDef->isCopy() &&
         MaskDef->getOperand(0).getReg() == RISCV::V0);
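  // V0 is defined by a COPY from a virtual register; look through any
  // copy-like instructions to find the instruction that actually produces
  // the mask value.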
  Register SrcReg =
      TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
  if (!SrcReg.isVirtual())
    return false;
  MaskDef = MRI->getVRegDef(SrcReg);
  if (!MaskDef)
    return false;

  // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
  // undefined behaviour if it's the wrong bitwidth, so we could choose to
  // assume that it's all-ones? Same applies to its VL.
  switch (MaskDef->getOpcode()) {
  case RISCV::PseudoVMSET_M_B1:
  case RISCV::PseudoVMSET_M_B2:
  case RISCV::PseudoVMSET_M_B4:
  case RISCV::PseudoVMSET_M_B8:
  case RISCV::PseudoVMSET_M_B16:
  case RISCV::PseudoVMSET_M_B32:
  case RISCV::PseudoVMSET_M_B64:
    return true;
  default:
    return false;
  }
}

// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const {
#define CASE_VMERGE_TO_VMV(lmul)                                               \
  case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
    NewOpc = RISCV::PseudoVMV_V_V_##lmul;                                      \
    break;
  unsigned NewOpc;
  switch (MI.getOpcode()) {
  default:
    return false;
    CASE_VMERGE_TO_VMV(MF8)
    CASE_VMERGE_TO_VMV(MF4)
    CASE_VMERGE_TO_VMV(MF2)
    CASE_VMERGE_TO_VMV(M1)
    CASE_VMERGE_TO_VMV(M2)
    CASE_VMERGE_TO_VMV(M4)
    CASE_VMERGE_TO_VMV(M8)
  }
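  // Operand layout of the masked merge pseudo (see the header comment):
  // operand 0 is the destination, 1 the merge value, 2 the false operand,
  // 3 the true operand and 4 the mask register $v0.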

  Register MergeReg = MI.getOperand(1).getReg();
  Register FalseReg = MI.getOperand(2).getReg();
  // Check merge == false (or merge == undef)
  if (MergeReg != RISCV::NoRegister &&
      TRI->lookThruCopyLike(MergeReg, MRI) !=
          TRI->lookThruCopyLike(FalseReg, MRI))
    return false;

  assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
  if (!isAllOnesMask(V0Defs.lookup(&MI)))
    return false;

  MI.setDesc(TII->get(NewOpc));
  MI.removeOperand(1);  // Merge operand
  MI.tieOperands(0, 1); // Tie false to dest
  MI.removeOperand(3);  // Mask operand
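  // vmv.v.v carries an explicit policy operand while the vmerge pseudo does
  // not; a tail/mask undisturbed policy keeps the tail coming from the tied
  // false operand, which conservatively matches the vmerge's behaviour.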
  MI.addOperand(MachineOperand::CreateImm(
      RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));

  // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
  // register class for the destination and merge operands e.g. VRNoV0 -> VR
  MRI->recomputeRegClass(MI.getOperand(0).getReg());
  MRI->recomputeRegClass(MI.getOperand(1).getReg());
  return true;
}

bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const {
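  // Only pseudos with an entry in the tablegen-generated RISCVMaskedPseudo
  // table have a known unmasked equivalent.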
  const RISCV::RISCVMaskedPseudoInfo *I =
      RISCV::getMaskedPseudoInfo(MI.getOpcode());
  if (!I)
    return false;

  if (!isAllOnesMask(V0Defs.lookup(&MI)))
    return false;

  // There are two classes of pseudos in the table - compares and
  // everything else. See the comment on RISCVMaskedPseudo for details.
  const unsigned Opc = I->UnmaskedPseudo;
  const MCInstrDesc &MCID = TII->get(Opc);
  [[maybe_unused]] const bool HasPolicyOp =
      RISCVII::hasVecPolicyOp(MCID.TSFlags);
  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
#ifndef NDEBUG
  const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
  assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
             RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
         "Masked and unmasked pseudos are inconsistent");
  assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
#endif
  (void)HasPolicyOp;

  MI.setDesc(MCID);

  // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
  unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
  MI.removeOperand(MaskOpIdx);

  // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
  // so try and relax it to vr.
  MRI->recomputeRegClass(MI.getOperand(0).getReg());
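  // The passthru, when the unmasked pseudo still takes one, sits immediately
  // after the explicit defs and can likewise leave the VRNoV0 class; if the
  // unmasked form has no passthru (the compare case), drop the operand.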
  unsigned PassthruOpIdx = MI.getNumExplicitDefs();
  if (HasPassthru) {
    if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
      MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
  } else
    MI.removeOperand(PassthruOpIdx);

  return true;
}

bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  // Skip if the vector extension is not enabled.
  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  if (!ST.hasVInstructions())
    return false;

  TII = ST.getInstrInfo();
  MRI = &MF.getRegInfo();
  TRI = MRI->getTargetRegisterInfo();

  bool Changed = false;

  // Masked pseudos coming out of isel will have their mask operand in the form:
  //
  // $v0:vr = COPY %mask:vr
  // %x:vr = Pseudo_MASK %a:vr, %b:vr, $v0:vr
  //
  // Because $v0 isn't in SSA, keep track of its definition at each use so we
  // can check mask operands.
  for (const MachineBasicBlock &MBB : MF) {
    const MachineInstr *CurrentV0Def = nullptr;
    for (const MachineInstr &MI : MBB) {
      if (MI.readsRegister(RISCV::V0, TRI))
        V0Defs[&MI] = CurrentV0Def;

      if (MI.definesRegister(RISCV::V0, TRI))
        CurrentV0Def = &MI;
    }
  }

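  // With every V0 use mapped to its reaching def, apply both peepholes to
  // each instruction.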
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      Changed |= convertToUnmasked(MI);
      Changed |= convertVMergeToVMv(MI);
    }
  }

  return Changed;
}

FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); }
