//===---- llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "llvm/Analysis/OptimizationRemarkEmitter.h"
10#include "llvm/AsmParser/Parser.h"
11#include "llvm/CodeGen/MachineModuleInfo.h"
12#include "llvm/CodeGen/SDPatternMatch.h"
13#include "llvm/CodeGen/TargetLowering.h"
14#include "llvm/MC/TargetRegistry.h"
15#include "llvm/Support/SourceMgr.h"
16#include "llvm/Support/TargetSelect.h"
17#include "llvm/Target/TargetMachine.h"
18#include "gtest/gtest.h"
19
20using namespace llvm;
21
22class SelectionDAGPatternMatchTest : public testing::Test {
23protected:
24 static void SetUpTestCase() {
25 InitializeAllTargets();
26 InitializeAllTargetMCs();
27 }
28
29 void SetUp() override {
30 StringRef Assembly = "@g = global i32 0\n"
31 "@g_alias = alias i32, i32* @g\n"
32 "define i32 @f() {\n"
33 " %1 = load i32, i32* @g\n"
34 " ret i32 %1\n"
35 "}";
36
37 Triple TargetTriple("riscv64--");
38 std::string Error;
39 const Target *T = TargetRegistry::lookupTarget(ArchName: "", TheTriple&: TargetTriple, Error);
40 // FIXME: These tests do not depend on RISCV specifically, but we have to
41 // initialize a target. A skeleton Target for unittests would allow us to
42 // always run these tests.
43 if (!T)
44 GTEST_SKIP();
45
46 TargetOptions Options;
47 TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
48 T->createTargetMachine(TT: "riscv64", CPU: "", Features: "+m,+f,+d,+v", Options,
49 RM: std::nullopt, CM: std::nullopt,
50 OL: CodeGenOptLevel::Aggressive)));
51 if (!TM)
52 GTEST_SKIP();
53
54 SMDiagnostic SMError;
55 M = parseAssemblyString(AsmString: Assembly, Err&: SMError, Context);
56 if (!M)
57 report_fatal_error(reason: SMError.getMessage());
58 M->setDataLayout(TM->createDataLayout());
59
60 F = M->getFunction(Name: "f");
61 if (!F)
62 report_fatal_error(reason: "F?");
63 G = M->getGlobalVariable(Name: "g");
64 if (!G)
65 report_fatal_error(reason: "G?");
66 AliasedG = M->getNamedAlias(Name: "g_alias");
67 if (!AliasedG)
68 report_fatal_error(reason: "AliasedG?");
69
70 MachineModuleInfo MMI(TM.get());
71
72 MF = std::make_unique<MachineFunction>(args&: *F, args&: *TM, args: *TM->getSubtargetImpl(*F),
73 args: 0, args&: MMI);
74
75 DAG = std::make_unique<SelectionDAG>(args&: *TM, args: CodeGenOptLevel::None);
76 if (!DAG)
77 report_fatal_error(reason: "DAG?");
78 OptimizationRemarkEmitter ORE(F);
79 DAG->init(NewMF&: *MF, NewORE&: ORE, PassPtr: nullptr, LibraryInfo: nullptr, UA: nullptr, PSIin: nullptr, BFIin: nullptr, FnVarLocs: nullptr);
80 }
81
82 TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
83 return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
84 }
85
86 EVT getTypeToTransformTo(EVT VT) {
87 return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
88 }
89
90 LLVMContext Context;
91 std::unique_ptr<LLVMTargetMachine> TM;
92 std::unique_ptr<Module> M;
93 Function *F;
94 GlobalVariable *G;
95 GlobalAlias *AliasedG;
96 std::unique_ptr<MachineFunction> MF;
97 std::unique_ptr<SelectionDAG> DAG;
98};
99
100TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
101 SDLoc DL;
102 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
103 auto Float32VT = EVT::getFloatingPointVT(BitWidth: 32);
104 auto VInt32VT = EVT::getVectorVT(Context, VT: Int32VT, NumElements: 4);
105
106 SDValue Op0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
107 SDValue Op1 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: Float32VT);
108 SDValue Op2 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: VInt32VT);
109
110 using namespace SDPatternMatch;
111 EXPECT_TRUE(sd_match(Op0, m_SpecificVT(Int32VT)));
112 EVT BindVT;
113 EXPECT_TRUE(sd_match(Op1, m_VT(BindVT)));
114 EXPECT_EQ(BindVT, Float32VT);
115 EXPECT_TRUE(sd_match(Op0, m_IntegerVT()));
116 EXPECT_TRUE(sd_match(Op1, m_FloatingPointVT()));
117 EXPECT_TRUE(sd_match(Op2, m_VectorVT()));
118 EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
119}
120
121TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
122 SDLoc DL;
123 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
124 auto Float32VT = EVT::getFloatingPointVT(BitWidth: 32);
125
126 SDValue Op0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
127 SDValue Op1 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: Int32VT);
128 SDValue Op2 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 3, VT: Float32VT);
129
130 SDValue Add = DAG->getNode(Opcode: ISD::ADD, DL, VT: Int32VT, N1: Op0, N2: Op1);
131 SDValue Sub = DAG->getNode(Opcode: ISD::SUB, DL, VT: Int32VT, N1: Add, N2: Op0);
132 SDValue Mul = DAG->getNode(Opcode: ISD::MUL, DL, VT: Int32VT, N1: Add, N2: Sub);
133 SDValue And = DAG->getNode(Opcode: ISD::AND, DL, VT: Int32VT, N1: Op0, N2: Op1);
134 SDValue Xor = DAG->getNode(Opcode: ISD::XOR, DL, VT: Int32VT, N1: Op1, N2: Op0);
135 SDValue Or = DAG->getNode(Opcode: ISD::OR, DL, VT: Int32VT, N1: Op0, N2: Op1);
136 SDValue SMax = DAG->getNode(Opcode: ISD::SMAX, DL, VT: Int32VT, N1: Op0, N2: Op1);
137 SDValue SMin = DAG->getNode(Opcode: ISD::SMIN, DL, VT: Int32VT, N1: Op1, N2: Op0);
138 SDValue UMax = DAG->getNode(Opcode: ISD::UMAX, DL, VT: Int32VT, N1: Op0, N2: Op1);
139 SDValue UMin = DAG->getNode(Opcode: ISD::UMIN, DL, VT: Int32VT, N1: Op1, N2: Op0);
140
141 SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other},
142 {DAG->getEntryNode(), Op2, Op2});
143
144 using namespace SDPatternMatch;
145 EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value())));
146 EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
147 EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
148 EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
149 EXPECT_TRUE(sd_match(
150 Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
151 EXPECT_TRUE(
152 sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT),
153 m_SpecificVT(Float32VT))));
154
155 EXPECT_TRUE(sd_match(And, m_c_BinOp(ISD::AND, m_Value(), m_Value())));
156 EXPECT_TRUE(sd_match(And, m_And(m_Value(), m_Value())));
157 EXPECT_TRUE(sd_match(Xor, m_c_BinOp(ISD::XOR, m_Value(), m_Value())));
158 EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
159 EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
160 EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
161
162 EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
163 EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
164 EXPECT_TRUE(sd_match(SMin, m_c_BinOp(ISD::SMIN, m_Value(), m_Value())));
165 EXPECT_TRUE(sd_match(SMin, m_SMin(m_Value(), m_Value())));
166 EXPECT_TRUE(sd_match(UMax, m_c_BinOp(ISD::UMAX, m_Value(), m_Value())));
167 EXPECT_TRUE(sd_match(UMax, m_UMax(m_Value(), m_Value())));
168 EXPECT_TRUE(sd_match(UMin, m_c_BinOp(ISD::UMIN, m_Value(), m_Value())));
169 EXPECT_TRUE(sd_match(UMin, m_UMin(m_Value(), m_Value())));
170
171 SDValue BindVal;
172 EXPECT_TRUE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_Value(BindVal),
173 m_Deferred(BindVal))));
174 EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(),
175 m_SpecificVT(Float32VT))));
176}
177
178TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
179 SDLoc DL;
180 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
181 auto Int64VT = EVT::getIntegerVT(Context, BitWidth: 64);
182
183 SDValue Op0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
184 SDValue Op1 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int64VT);
185
186 SDValue ZExt = DAG->getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: Int64VT, Operand: Op0);
187 SDValue SExt = DAG->getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: Int64VT, Operand: Op0);
188 SDValue Trunc = DAG->getNode(Opcode: ISD::TRUNCATE, DL, VT: Int32VT, Operand: Op1);
189
190 SDValue Sub = DAG->getNode(Opcode: ISD::SUB, DL, VT: Int32VT, N1: Trunc, N2: Op0);
191 SDValue Neg = DAG->getNegative(Val: Op0, DL, VT: Int32VT);
192 SDValue Not = DAG->getNOT(DL, Val: Op0, VT: Int32VT);
193
194 using namespace SDPatternMatch;
195 EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
196 EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
197 EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
198
199 EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));
200 EXPECT_TRUE(sd_match(Not, m_Not(m_Value())));
201 EXPECT_FALSE(sd_match(ZExt, m_Neg(m_Value())));
202 EXPECT_FALSE(sd_match(Sub, m_Neg(m_Value())));
203 EXPECT_FALSE(sd_match(Neg, m_Not(m_Value())));
204}
205
206TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
207 SDLoc DL;
208 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
209 auto VInt32VT = EVT::getVectorVT(Context, VT: Int32VT, NumElements: 4);
210
211 SDValue Arg0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
212
213 SDValue Const3 = DAG->getConstant(Val: 3, DL, VT: Int32VT);
214 SDValue Const87 = DAG->getConstant(Val: 87, DL, VT: Int32VT);
215 SDValue Splat = DAG->getSplat(VT: VInt32VT, DL, Op: Arg0);
216 SDValue ConstSplat = DAG->getSplat(VT: VInt32VT, DL, Op: Const3);
217 SDValue Zero = DAG->getConstant(Val: 0, DL, VT: Int32VT);
218 SDValue One = DAG->getConstant(Val: 1, DL, VT: Int32VT);
219 SDValue AllOnes = DAG->getConstant(Val: APInt::getAllOnes(numBits: 32), DL, VT: Int32VT);
220
221 using namespace SDPatternMatch;
222 EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
223 EXPECT_FALSE(sd_match(Arg0, m_ConstInt()));
224 APInt ConstVal;
225 EXPECT_TRUE(sd_match(ConstSplat, m_ConstInt(ConstVal)));
226 EXPECT_EQ(ConstVal, 3);
227 EXPECT_FALSE(sd_match(Splat, m_ConstInt()));
228
229 EXPECT_TRUE(sd_match(Const87, m_SpecificInt(87)));
230 EXPECT_TRUE(sd_match(Const3, m_SpecificInt(ConstVal)));
231 EXPECT_TRUE(sd_match(AllOnes, m_AllOnes()));
232
233 EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
234 EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
235 EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
236}
237
238TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
239 SDLoc DL;
240 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
241
242 SDValue Op0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
243 SDValue Op1 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: Int32VT);
244
245 SDValue Add = DAG->getNode(Opcode: ISD::ADD, DL, VT: Int32VT, N1: Op0, N2: Op1);
246 SDValue Sub = DAG->getNode(Opcode: ISD::SUB, DL, VT: Int32VT, N1: Add, N2: Op0);
247
248 using namespace SDPatternMatch;
249 EXPECT_TRUE(sd_match(
250 Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
251 EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
252}
253
254TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
255 SDLoc DL;
256 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
257 auto Int64VT = EVT::getIntegerVT(Context, BitWidth: 64);
258
259 SDValue Op32 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
260 SDValue Op64 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int64VT);
261 SDValue ZExt = DAG->getNode(Opcode: ISD::ZERO_EXTEND, DL, VT: Int64VT, Operand: Op32);
262 SDValue SExt = DAG->getNode(Opcode: ISD::SIGN_EXTEND, DL, VT: Int64VT, Operand: Op32);
263 SDValue AExt = DAG->getNode(Opcode: ISD::ANY_EXTEND, DL, VT: Int64VT, Operand: Op32);
264 SDValue Trunc = DAG->getNode(Opcode: ISD::TRUNCATE, DL, VT: Int32VT, Operand: Op64);
265
266 using namespace SDPatternMatch;
267 SDValue A;
268 EXPECT_TRUE(sd_match(Op32, m_ZExtOrSelf(m_Value(A))));
269 EXPECT_TRUE(A == Op32);
270 EXPECT_TRUE(sd_match(ZExt, m_ZExtOrSelf(m_Value(A))));
271 EXPECT_TRUE(A == Op32);
272 EXPECT_TRUE(sd_match(Op64, m_SExtOrSelf(m_Value(A))));
273 EXPECT_TRUE(A == Op64);
274 EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
275 EXPECT_TRUE(A == Op32);
276 EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
277 EXPECT_TRUE(A == Op32);
278 EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));
279 EXPECT_TRUE(A == Op32);
280 EXPECT_TRUE(sd_match(Op64, m_TruncOrSelf(m_Value(A))));
281 EXPECT_TRUE(A == Op64);
282 EXPECT_TRUE(sd_match(Trunc, m_TruncOrSelf(m_Value(A))));
283 EXPECT_TRUE(A == Op64);
284}
285
286TEST_F(SelectionDAGPatternMatchTest, matchNode) {
287 SDLoc DL;
288 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
289
290 SDValue Op0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
291 SDValue Op1 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: Int32VT);
292
293 SDValue Add = DAG->getNode(Opcode: ISD::ADD, DL, VT: Int32VT, N1: Op0, N2: Op1);
294
295 using namespace SDPatternMatch;
296 EXPECT_TRUE(sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value())));
297 EXPECT_FALSE(sd_match(Add, m_Node(ISD::SUB, m_Value(), m_Value())));
298 EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_Value())));
299 EXPECT_FALSE(
300 sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value(), m_Value())));
301 EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value())));
302}
303
304namespace {
305struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
306 using SDPatternMatch::BasicMatchContext::BasicMatchContext;
307
308 bool match(SDValue OpVal, unsigned Opc) const {
309 if (!OpVal->isVPOpcode())
310 return OpVal->getOpcode() == Opc;
311
312 auto BaseOpc = ISD::getBaseOpcodeForVP(Opcode: OpVal->getOpcode(), hasFPExcept: false);
313 return BaseOpc.has_value() && *BaseOpc == Opc;
314 }
315};
316} // anonymous namespace
317TEST_F(SelectionDAGPatternMatchTest, matchContext) {
318 SDLoc DL;
319 auto BoolVT = EVT::getIntegerVT(Context, BitWidth: 1);
320 auto Int32VT = EVT::getIntegerVT(Context, BitWidth: 32);
321 auto VInt32VT = EVT::getVectorVT(Context, VT: Int32VT, NumElements: 4);
322 auto MaskVT = EVT::getVectorVT(Context, VT: BoolVT, NumElements: 4);
323
324 SDValue Scalar0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int32VT);
325 SDValue Vector0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: VInt32VT);
326 SDValue Mask0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 3, VT: MaskVT);
327
328 SDValue VPAdd = DAG->getNode(Opcode: ISD::VP_ADD, DL, VT: VInt32VT,
329 Ops: {Vector0, Vector0, Mask0, Scalar0});
330 SDValue VPReduceAdd = DAG->getNode(Opcode: ISD::VP_REDUCE_ADD, DL, VT: Int32VT,
331 Ops: {Scalar0, VPAdd, Mask0, Scalar0});
332
333 using namespace SDPatternMatch;
334 VPMatchContext VPCtx(DAG.get());
335 EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
336 // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
337 // sd_match before switching to VPMatchContext when checking VPAdd.
338 EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
339 m_Context(VPCtx, m_Opc(ISD::ADD)),
340 m_Value(), m_Value())));
341}
342
343TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
344 SDLoc DL;
345 auto Int16VT = EVT::getIntegerVT(Context, BitWidth: 16);
346 auto Int64VT = EVT::getIntegerVT(Context, BitWidth: 64);
347
348 SDValue Op0 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 1, VT: Int64VT);
349 SDValue Op1 = DAG->getCopyFromReg(Chain: DAG->getEntryNode(), dl: DL, Reg: 2, VT: Int16VT);
350
351 SDValue Add = DAG->getNode(Opcode: ISD::ADD, DL, VT: Int64VT, N1: Op0, N2: Op0);
352
353 using namespace SDPatternMatch;
354 EXPECT_TRUE(sd_match(Op0, DAG.get(), m_LegalType(m_Value())));
355 EXPECT_FALSE(sd_match(Op1, DAG.get(), m_LegalType(m_Value())));
356 EXPECT_TRUE(sd_match(Add, DAG.get(),
357 m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
358}
359

source code of llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp