1 | //===- Target/X86/X86LowerAMXType.cpp - -------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | /// \file Pass to transform <256 x i32> load/store |
10 | /// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only |
11 | /// provides simple operation on x86_amx. The basic elementwise operation |
12 | /// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32> |
13 | /// and only AMX intrinsics can operate on the type, we need transform |
14 | /// load/store <256 x i32> instruction to AMX load/store. If the bitcast can |
15 | /// not be combined with load/store, we transform the bitcast to amx load/store |
16 | /// and <256 x i32> store/load. |
17 | /// |
18 | /// If Front End not use O0 but the Mid/Back end use O0, (e.g. "Clang -O2 -S |
19 | /// -emit-llvm t.c" + "llc t.ll") we should make sure the amx data is volatile, |
20 | /// because that is necessary for AMX fast register allocation. (In Fast |
21 | /// register allocation, register will be allocated before spill/reload, so |
22 | /// there is no additional register for amx to identify the step in spill.) |
23 | /// The volatileTileData() will handle this case. |
24 | /// e.g. |
25 | /// ---------------------------------------------------------- |
26 | /// | def %td = ... | |
27 | /// | ... | |
28 | /// | "use %td" | |
29 | /// ---------------------------------------------------------- |
30 | /// will transfer to --> |
31 | /// ---------------------------------------------------------- |
32 | /// | def %td = ... | |
33 | /// | call void @llvm.x86.tilestored64.internal(mem, %td) | |
34 | /// | ... | |
35 | /// | %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem)| |
36 | /// | "use %td2" | |
37 | /// ---------------------------------------------------------- |
38 | // |
39 | //===----------------------------------------------------------------------===// |
40 | // |
41 | #include "X86.h" |
42 | #include "llvm/ADT/PostOrderIterator.h" |
43 | #include "llvm/ADT/SetVector.h" |
44 | #include "llvm/ADT/SmallSet.h" |
45 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
46 | #include "llvm/Analysis/TargetLibraryInfo.h" |
47 | #include "llvm/Analysis/TargetTransformInfo.h" |
48 | #include "llvm/CodeGen/Passes.h" |
49 | #include "llvm/CodeGen/TargetPassConfig.h" |
50 | #include "llvm/CodeGen/ValueTypes.h" |
51 | #include "llvm/IR/DataLayout.h" |
52 | #include "llvm/IR/Function.h" |
53 | #include "llvm/IR/IRBuilder.h" |
54 | #include "llvm/IR/Instructions.h" |
55 | #include "llvm/IR/IntrinsicInst.h" |
56 | #include "llvm/IR/IntrinsicsX86.h" |
57 | #include "llvm/IR/PatternMatch.h" |
58 | #include "llvm/InitializePasses.h" |
59 | #include "llvm/Pass.h" |
60 | #include "llvm/Target/TargetMachine.h" |
61 | #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" |
62 | #include "llvm/Transforms/Utils/Local.h" |
63 | |
64 | #include <map> |
65 | |
66 | using namespace llvm; |
67 | using namespace PatternMatch; |
68 | |
69 | #define DEBUG_TYPE "lower-amx-type" |
70 | |
71 | static bool isAMXCast(Instruction *II) { |
72 | return match(II, |
73 | m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) || |
74 | match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value())); |
75 | } |
76 | |
77 | static bool isAMXIntrinsic(Value *I) { |
78 | auto *II = dyn_cast<IntrinsicInst>(Val: I); |
79 | if (!II) |
80 | return false; |
81 | if (isAMXCast(II)) |
82 | return false; |
83 | // Check if return type or parameter is x86_amx. If it is x86_amx |
84 | // the intrinsic must be x86 amx intrinsics. |
85 | if (II->getType()->isX86_AMXTy()) |
86 | return true; |
87 | for (Value *V : II->args()) { |
88 | if (V->getType()->isX86_AMXTy()) |
89 | return true; |
90 | } |
91 | |
92 | return false; |
93 | } |
94 | |
95 | static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB, |
96 | Type *Ty) { |
97 | Function &F = *BB->getParent(); |
98 | Module *M = BB->getModule(); |
99 | const DataLayout &DL = M->getDataLayout(); |
100 | |
101 | LLVMContext &Ctx = Builder.getContext(); |
102 | auto AllocaAlignment = DL.getPrefTypeAlign(Ty: Type::getX86_AMXTy(C&: Ctx)); |
103 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
104 | AllocaInst *AllocaRes = |
105 | new AllocaInst(Ty, AllocaAS, "" , F.getEntryBlock().begin()); |
106 | AllocaRes->setAlignment(AllocaAlignment); |
107 | return AllocaRes; |
108 | } |
109 | |
110 | static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) { |
111 | for (Instruction &I : F.getEntryBlock()) |
112 | if (!isa<AllocaInst>(Val: &I)) |
113 | return &I; |
114 | llvm_unreachable("No terminator in the entry block!" ); |
115 | } |
116 | |
117 | static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) { |
118 | IRBuilder<> Builder(II); |
119 | Value *Row = nullptr, *Col = nullptr; |
120 | switch (II->getIntrinsicID()) { |
121 | default: |
122 | llvm_unreachable("Expect amx intrinsics" ); |
123 | case Intrinsic::x86_tileloadd64_internal: |
124 | case Intrinsic::x86_tileloaddt164_internal: |
125 | case Intrinsic::x86_tilestored64_internal: { |
126 | Row = II->getArgOperand(i: 0); |
127 | Col = II->getArgOperand(i: 1); |
128 | break; |
129 | } |
130 | // a * b + c |
131 | // The shape depends on which operand. |
132 | case Intrinsic::x86_tcmmimfp16ps_internal: |
133 | case Intrinsic::x86_tcmmrlfp16ps_internal: |
134 | case Intrinsic::x86_tdpbssd_internal: |
135 | case Intrinsic::x86_tdpbsud_internal: |
136 | case Intrinsic::x86_tdpbusd_internal: |
137 | case Intrinsic::x86_tdpbuud_internal: |
138 | case Intrinsic::x86_tdpbf16ps_internal: |
139 | case Intrinsic::x86_tdpfp16ps_internal: { |
140 | switch (OpNo) { |
141 | case 3: |
142 | Row = II->getArgOperand(i: 0); |
143 | Col = II->getArgOperand(i: 1); |
144 | break; |
145 | case 4: |
146 | Row = II->getArgOperand(i: 0); |
147 | Col = II->getArgOperand(i: 2); |
148 | break; |
149 | case 5: |
150 | if (isa<ConstantInt>(Val: II->getArgOperand(i: 2))) |
151 | Row = Builder.getInt16( |
152 | C: (cast<ConstantInt>(Val: II->getOperand(i_nocapture: 2))->getSExtValue()) / 4); |
153 | else if (isa<Instruction>(Val: II->getArgOperand(i: 2))) { |
154 | // When it is not a const value and it is not a function argument, we |
155 | // create Row after the definition of II->getOperand(2) instead of |
156 | // before II. For example, II is %118, we try to getshape for %117: |
157 | // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x |
158 | // i32> %115). |
159 | // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 |
160 | // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx |
161 | // %117). |
162 | // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its |
163 | // definition is after its user(new tileload for %117). |
164 | // So, the best choice is to create %row right after the definition of |
165 | // %106. |
166 | Builder.SetInsertPoint(cast<Instruction>(Val: II->getOperand(i_nocapture: 2))); |
167 | Row = Builder.CreateUDiv(LHS: II->getOperand(i_nocapture: 2), RHS: Builder.getInt16(C: 4)); |
168 | cast<Instruction>(Val: Row)->moveAfter(MovePos: cast<Instruction>(Val: II->getOperand(i_nocapture: 2))); |
169 | } else { |
170 | // When it is not a const value and it is a function argument, we create |
171 | // Row at the entry bb. |
172 | IRBuilder<> NewBuilder( |
173 | getFirstNonAllocaInTheEntryBlock(F&: *II->getFunction())); |
174 | Row = NewBuilder.CreateUDiv(LHS: II->getOperand(i_nocapture: 2), RHS: NewBuilder.getInt16(C: 4)); |
175 | } |
176 | Col = II->getArgOperand(i: 1); |
177 | break; |
178 | } |
179 | break; |
180 | } |
181 | } |
182 | |
183 | return std::make_pair(x&: Row, y&: Col); |
184 | } |
185 | |
186 | static std::pair<Value *, Value *> getShape(PHINode *Phi) { |
187 | Use &U = *(Phi->use_begin()); |
188 | unsigned OpNo = U.getOperandNo(); |
189 | User *V = U.getUser(); |
190 | // TODO We don't traverse all users. To make the algorithm simple, here we |
191 | // just traverse the first user. If we can find shape, then return the shape, |
192 | // otherwise just return nullptr and the optimization for undef/zero will be |
193 | // abandoned. |
194 | while (V) { |
195 | if (isAMXCast(II: dyn_cast<Instruction>(Val: V))) { |
196 | if (V->use_empty()) |
197 | break; |
198 | Use &U = *(V->use_begin()); |
199 | OpNo = U.getOperandNo(); |
200 | V = U.getUser(); |
201 | } else if (isAMXIntrinsic(I: V)) { |
202 | return getShape(II: cast<IntrinsicInst>(Val: V), OpNo); |
203 | } else if (isa<PHINode>(Val: V)) { |
204 | if (V->use_empty()) |
205 | break; |
206 | Use &U = *(V->use_begin()); |
207 | V = U.getUser(); |
208 | } else { |
209 | break; |
210 | } |
211 | } |
212 | |
213 | return std::make_pair(x: nullptr, y: nullptr); |
214 | } |
215 | |
216 | namespace { |
217 | class X86LowerAMXType { |
218 | Function &Func; |
219 | |
220 | // In AMX intrinsics we let Shape = {Row, Col}, but the |
221 | // RealCol = Col / ElementSize. We may use the RealCol |
222 | // as a new Row for other new created AMX intrinsics. |
223 | std::map<Value *, Value *> Col2Row; |
224 | |
225 | public: |
226 | X86LowerAMXType(Function &F) : Func(F) {} |
227 | bool visit(); |
228 | void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast); |
229 | void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST); |
230 | bool transformBitcast(BitCastInst *Bitcast); |
231 | }; |
232 | |
233 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
234 | // %2 = bitcast <256 x i32> %src to x86_amx |
235 | // --> |
236 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
237 | // i8* %addr, i64 %stride64) |
238 | void X86LowerAMXType::combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { |
239 | Value *Row = nullptr, *Col = nullptr; |
240 | Use &U = *(Bitcast->use_begin()); |
241 | unsigned OpNo = U.getOperandNo(); |
242 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
243 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
244 | IRBuilder<> Builder(Bitcast); |
245 | // Use the maximun column as stride. |
246 | Value *Stride = Builder.getInt64(C: 64); |
247 | Value *I8Ptr = LD->getOperand(i_nocapture: 0); |
248 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
249 | |
250 | Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, |
251 | std::nullopt, Args); |
252 | Bitcast->replaceAllUsesWith(V: NewInst); |
253 | } |
254 | |
255 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
256 | // %stride); |
257 | // %13 = bitcast x86_amx %src to <256 x i32> |
258 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
259 | // --> |
260 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
261 | // %stride64, %13) |
262 | void X86LowerAMXType::combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { |
263 | |
264 | Value *Tile = Bitcast->getOperand(i_nocapture: 0); |
265 | auto *II = cast<IntrinsicInst>(Val: Tile); |
266 | // Tile is output from AMX intrinsic. The first operand of the |
267 | // intrinsic is row, the second operand of the intrinsic is column. |
268 | Value *Row = II->getOperand(i_nocapture: 0); |
269 | Value *Col = II->getOperand(i_nocapture: 1); |
270 | IRBuilder<> Builder(ST); |
271 | // Use the maximum column as stride. It must be the same with load |
272 | // stride. |
273 | Value *Stride = Builder.getInt64(C: 64); |
274 | Value *I8Ptr = ST->getOperand(i_nocapture: 1); |
275 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
276 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt, |
277 | Args); |
278 | if (Bitcast->hasOneUse()) |
279 | return; |
280 | // %13 = bitcast x86_amx %src to <256 x i32> |
281 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
282 | // %add = <256 x i32> %13, <256 x i32> %src2 |
283 | // --> |
284 | // %13 = bitcast x86_amx %src to <256 x i32> |
285 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
286 | // %stride64, %13) |
287 | // %14 = load <256 x i32>, %addr |
288 | // %add = <256 x i32> %14, <256 x i32> %src2 |
289 | Value *Vec = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: ST->getOperand(i_nocapture: 1)); |
290 | Bitcast->replaceAllUsesWith(V: Vec); |
291 | } |
292 | |
293 | // transform bitcast to <store, load> instructions. |
294 | bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) { |
295 | IRBuilder<> Builder(Bitcast); |
296 | AllocaInst *AllocaAddr; |
297 | Value *I8Ptr, *Stride; |
298 | auto *Src = Bitcast->getOperand(i_nocapture: 0); |
299 | |
300 | auto Prepare = [&](Type *MemTy) { |
301 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: Bitcast->getParent(), Ty: MemTy); |
302 | I8Ptr = AllocaAddr; |
303 | Stride = Builder.getInt64(C: 64); |
304 | }; |
305 | |
306 | if (Bitcast->getType()->isX86_AMXTy()) { |
307 | // %2 = bitcast <256 x i32> %src to x86_amx |
308 | // --> |
309 | // %addr = alloca <256 x i32>, align 64 |
310 | // store <256 x i32> %src, <256 x i32>* %addr, align 64 |
311 | // %addr2 = bitcast <256 x i32>* to i8* |
312 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
313 | // i8* %addr2, |
314 | // i64 64) |
315 | Use &U = *(Bitcast->use_begin()); |
316 | unsigned OpNo = U.getOperandNo(); |
317 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
318 | if (!II) |
319 | return false; // May be bitcast from x86amx to <256 x i32>. |
320 | Prepare(Bitcast->getOperand(i_nocapture: 0)->getType()); |
321 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
322 | // TODO we can pick an constant operand for the shape. |
323 | Value *Row = nullptr, *Col = nullptr; |
324 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
325 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
326 | Value *NewInst = Builder.CreateIntrinsic( |
327 | Intrinsic::x86_tileloadd64_internal, std::nullopt, Args); |
328 | Bitcast->replaceAllUsesWith(V: NewInst); |
329 | } else { |
330 | // %2 = bitcast x86_amx %src to <256 x i32> |
331 | // --> |
332 | // %addr = alloca <256 x i32>, align 64 |
333 | // %addr2 = bitcast <256 x i32>* to i8* |
334 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
335 | // i8* %addr2, i64 %stride) |
336 | // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 |
337 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
338 | if (!II) |
339 | return false; // May be bitcast from <256 x i32> to x86amx. |
340 | Prepare(Bitcast->getType()); |
341 | Value *Row = II->getOperand(i_nocapture: 0); |
342 | Value *Col = II->getOperand(i_nocapture: 1); |
343 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src}; |
344 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt, |
345 | Args); |
346 | Value *NewInst = Builder.CreateLoad(Ty: Bitcast->getType(), Ptr: AllocaAddr); |
347 | Bitcast->replaceAllUsesWith(V: NewInst); |
348 | } |
349 | |
350 | return true; |
351 | } |
352 | |
353 | bool X86LowerAMXType::visit() { |
354 | SmallVector<Instruction *, 8> DeadInsts; |
355 | Col2Row.clear(); |
356 | |
357 | for (BasicBlock *BB : post_order(G: &Func)) { |
358 | for (Instruction &Inst : llvm::make_early_inc_range(Range: llvm::reverse(C&: *BB))) { |
359 | auto *Bitcast = dyn_cast<BitCastInst>(Val: &Inst); |
360 | if (!Bitcast) |
361 | continue; |
362 | |
363 | Value *Src = Bitcast->getOperand(i_nocapture: 0); |
364 | if (Bitcast->getType()->isX86_AMXTy()) { |
365 | if (Bitcast->user_empty()) { |
366 | DeadInsts.push_back(Elt: Bitcast); |
367 | continue; |
368 | } |
369 | LoadInst *LD = dyn_cast<LoadInst>(Val: Src); |
370 | if (!LD) { |
371 | if (transformBitcast(Bitcast)) |
372 | DeadInsts.push_back(Elt: Bitcast); |
373 | continue; |
374 | } |
375 | // If load has mutli-user, duplicate a vector load. |
376 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
377 | // %2 = bitcast <256 x i32> %src to x86_amx |
378 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
379 | // --> |
380 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
381 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
382 | // i8* %addr, i64 %stride64) |
383 | // %add = add <256 x i32> %src, <256 x i32> %src2 |
384 | |
385 | // If load has one user, the load will be eliminated in DAG ISel. |
386 | // %src = load <256 x i32>, <256 x i32>* %addr, align 64 |
387 | // %2 = bitcast <256 x i32> %src to x86_amx |
388 | // --> |
389 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
390 | // i8* %addr, i64 %stride64) |
391 | combineLoadBitcast(LD, Bitcast); |
392 | DeadInsts.push_back(Elt: Bitcast); |
393 | if (LD->hasOneUse()) |
394 | DeadInsts.push_back(Elt: LD); |
395 | } else if (Src->getType()->isX86_AMXTy()) { |
396 | if (Bitcast->user_empty()) { |
397 | DeadInsts.push_back(Elt: Bitcast); |
398 | continue; |
399 | } |
400 | StoreInst *ST = nullptr; |
401 | for (Use &U : Bitcast->uses()) { |
402 | ST = dyn_cast<StoreInst>(Val: U.getUser()); |
403 | if (ST) |
404 | break; |
405 | } |
406 | if (!ST) { |
407 | if (transformBitcast(Bitcast)) |
408 | DeadInsts.push_back(Elt: Bitcast); |
409 | continue; |
410 | } |
411 | // If bitcast (%13) has one use, combine bitcast and store to amx store. |
412 | // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, |
413 | // %stride); |
414 | // %13 = bitcast x86_amx %src to <256 x i32> |
415 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
416 | // --> |
417 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
418 | // %stride64, %13) |
419 | // |
420 | // If bitcast (%13) has multi-use, transform as below. |
421 | // %13 = bitcast x86_amx %src to <256 x i32> |
422 | // store <256 x i32> %13, <256 x i32>* %addr, align 64 |
423 | // %add = <256 x i32> %13, <256 x i32> %src2 |
424 | // --> |
425 | // %13 = bitcast x86_amx %src to <256 x i32> |
426 | // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, |
427 | // %stride64, %13) |
428 | // %14 = load <256 x i32>, %addr |
429 | // %add = <256 x i32> %14, <256 x i32> %src2 |
430 | // |
431 | combineBitcastStore(Bitcast, ST); |
432 | // Delete user first. |
433 | DeadInsts.push_back(Elt: ST); |
434 | DeadInsts.push_back(Elt: Bitcast); |
435 | } |
436 | } |
437 | } |
438 | |
439 | bool C = !DeadInsts.empty(); |
440 | |
441 | for (auto *Inst : DeadInsts) |
442 | Inst->eraseFromParent(); |
443 | |
444 | return C; |
445 | } |
446 | } // anonymous namespace |
447 | |
448 | static Value *getAllocaPos(BasicBlock *BB) { |
449 | Module *M = BB->getModule(); |
450 | Function *F = BB->getParent(); |
451 | IRBuilder<> Builder(&F->getEntryBlock().front()); |
452 | const DataLayout &DL = M->getDataLayout(); |
453 | unsigned AllocaAS = DL.getAllocaAddrSpace(); |
454 | Type *V256I32Ty = VectorType::get(ElementType: Builder.getInt32Ty(), NumElements: 256, Scalable: false); |
455 | AllocaInst *AllocaRes = |
456 | new AllocaInst(V256I32Ty, AllocaAS, "" , F->getEntryBlock().begin()); |
457 | BasicBlock::iterator Iter = AllocaRes->getIterator(); |
458 | ++Iter; |
459 | Builder.SetInsertPoint(&*Iter); |
460 | Value *I8Ptr = Builder.CreateBitCast(V: AllocaRes, DestTy: Builder.getPtrTy()); |
461 | return I8Ptr; |
462 | } |
463 | |
464 | static Instruction *createTileStore(Instruction *TileDef, Value *Ptr) { |
465 | assert(TileDef->getType()->isX86_AMXTy() && "Not define tile!" ); |
466 | auto *II = cast<IntrinsicInst>(Val: TileDef); |
467 | assert(II && "Not tile intrinsic!" ); |
468 | Value *Row = II->getOperand(i_nocapture: 0); |
469 | Value *Col = II->getOperand(i_nocapture: 1); |
470 | |
471 | BasicBlock *BB = TileDef->getParent(); |
472 | BasicBlock::iterator Iter = TileDef->getIterator(); |
473 | IRBuilder<> Builder(BB, ++Iter); |
474 | Value *Stride = Builder.getInt64(C: 64); |
475 | std::array<Value *, 5> Args = {Row, Col, Ptr, Stride, TileDef}; |
476 | |
477 | Instruction *TileStore = Builder.CreateIntrinsic( |
478 | Intrinsic::x86_tilestored64_internal, std::nullopt, Args); |
479 | return TileStore; |
480 | } |
481 | |
482 | static void replaceWithTileLoad(Use &U, Value *Ptr, bool IsPHI = false) { |
483 | Value *V = U.get(); |
484 | assert(V->getType()->isX86_AMXTy() && "Not define tile!" ); |
485 | |
486 | // Get tile shape. |
487 | IntrinsicInst *II = nullptr; |
488 | if (IsPHI) { |
489 | Value *PhiOp = cast<PHINode>(Val: V)->getIncomingValue(i: 0); |
490 | II = cast<IntrinsicInst>(Val: PhiOp); |
491 | } else { |
492 | II = cast<IntrinsicInst>(Val: V); |
493 | } |
494 | Value *Row = II->getOperand(i_nocapture: 0); |
495 | Value *Col = II->getOperand(i_nocapture: 1); |
496 | |
497 | Instruction *UserI = cast<Instruction>(Val: U.getUser()); |
498 | IRBuilder<> Builder(UserI); |
499 | Value *Stride = Builder.getInt64(C: 64); |
500 | std::array<Value *, 4> Args = {Row, Col, Ptr, Stride}; |
501 | |
502 | Value *TileLoad = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, |
503 | std::nullopt, Args); |
504 | UserI->replaceUsesOfWith(From: V, To: TileLoad); |
505 | } |
506 | |
507 | static bool isIncomingOfPHI(Instruction *I) { |
508 | for (Use &U : I->uses()) { |
509 | User *V = U.getUser(); |
510 | if (isa<PHINode>(Val: V)) |
511 | return true; |
512 | } |
513 | return false; |
514 | } |
515 | |
516 | // Let all AMX tile data become volatile data, shorten the life range |
517 | // of each tile register before fast register allocation. |
518 | namespace { |
519 | class X86VolatileTileData { |
520 | Function &F; |
521 | |
522 | public: |
523 | X86VolatileTileData(Function &Func) : F(Func) {} |
524 | Value *updatePhiIncomings(BasicBlock *BB, |
525 | SmallVector<Instruction *, 2> &Incomings); |
526 | void replacePhiDefWithLoad(Instruction *PHI, Value *StorePtr); |
527 | bool volatileTileData(); |
528 | void volatileTilePHI(PHINode *PHI); |
529 | void volatileTileNonPHI(Instruction *I); |
530 | }; |
531 | |
532 | Value *X86VolatileTileData::updatePhiIncomings( |
533 | BasicBlock *BB, SmallVector<Instruction *, 2> &Incomings) { |
534 | Value *I8Ptr = getAllocaPos(BB); |
535 | |
536 | for (auto *I : Incomings) { |
537 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
538 | |
539 | // All its uses (except phi) should load from stored mem. |
540 | for (Use &U : I->uses()) { |
541 | User *V = U.getUser(); |
542 | if (isa<PHINode>(Val: V) || V == Store) |
543 | continue; |
544 | replaceWithTileLoad(U, Ptr: I8Ptr); |
545 | } |
546 | } |
547 | return I8Ptr; |
548 | } |
549 | |
550 | void X86VolatileTileData::replacePhiDefWithLoad(Instruction *PHI, |
551 | Value *StorePtr) { |
552 | for (Use &U : PHI->uses()) |
553 | replaceWithTileLoad(U, Ptr: StorePtr, IsPHI: true); |
554 | PHI->eraseFromParent(); |
555 | } |
556 | |
557 | // Similar to volatileTileNonPHI, but this function only handles PHI nodes |
558 | // and their related AMX intrinsics. |
559 | // 1) PHI Def should change to tileload. |
560 | // 2) PHI Incoming Values should tilestored in just after their def. |
561 | // 3) The mem of these tileload and tilestores should be same. |
562 | // e.g. |
563 | // ------------------------------------------------------ |
564 | // bb_dom: |
565 | // ... |
566 | // br i1 %bool.cond, label %if.else, label %if.then |
567 | // |
568 | // if.then: |
569 | // def %t0 = ... |
570 | // ... |
571 | // use %t0 |
572 | // ... |
573 | // br label %if.end |
574 | // |
575 | // if.else: |
576 | // def %t1 = ... |
577 | // br label %if.end |
578 | // |
579 | // if.end: |
580 | // %td = phi x86_amx [ %t1, %if.else ], [ %t0, %if.then ] |
581 | // ... |
582 | // use %td |
583 | // ------------------------------------------------------ |
584 | // --> |
585 | // ------------------------------------------------------ |
586 | // bb_entry: |
587 | // %mem = alloca <256 x i32>, align 1024 * |
588 | // ... |
589 | // bb_dom: |
590 | // ... |
591 | // br i1 %bool.cond, label %if.else, label %if.then |
592 | // |
593 | // if.then: |
594 | // def %t0 = ... |
595 | // call void @llvm.x86.tilestored64.internal(mem, %t0) * |
596 | // ... |
597 | // %t0` = call x86_amx @llvm.x86.tileloadd64.internal(mem)* |
598 | // use %t0` * |
599 | // ... |
600 | // br label %if.end |
601 | // |
602 | // if.else: |
603 | // def %t1 = ... |
604 | // call void @llvm.x86.tilestored64.internal(mem, %t1) * |
605 | // br label %if.end |
606 | // |
607 | // if.end: |
608 | // ... |
609 | // %td = call x86_amx @llvm.x86.tileloadd64.internal(mem) * |
610 | // use %td |
611 | // ------------------------------------------------------ |
612 | void X86VolatileTileData::volatileTilePHI(PHINode *PHI) { |
613 | BasicBlock *BB = PHI->getParent(); |
614 | SmallVector<Instruction *, 2> Incomings; |
615 | |
616 | for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { |
617 | Value *Op = PHI->getIncomingValue(i: I); |
618 | Instruction *Inst = dyn_cast<Instruction>(Val: Op); |
619 | assert(Inst && "We shouldn't fold AMX instrution!" ); |
620 | Incomings.push_back(Elt: Inst); |
621 | } |
622 | |
623 | Value *StorePtr = updatePhiIncomings(BB, Incomings); |
624 | replacePhiDefWithLoad(PHI, StorePtr); |
625 | } |
626 | |
627 | // Store the defined tile and load it before use. |
628 | // All its users are not PHI. |
629 | // e.g. |
630 | // ------------------------------------------------------ |
631 | // def %td = ... |
632 | // ... |
633 | // "use %td" |
634 | // ------------------------------------------------------ |
635 | // --> |
636 | // ------------------------------------------------------ |
637 | // def %td = ... |
638 | // call void @llvm.x86.tilestored64.internal(mem, %td) |
639 | // ... |
640 | // %td2 = call x86_amx @llvm.x86.tileloadd64.internal(mem) |
641 | // "use %td2" |
642 | // ------------------------------------------------------ |
643 | void X86VolatileTileData::volatileTileNonPHI(Instruction *I) { |
644 | BasicBlock *BB = I->getParent(); |
645 | Value *I8Ptr = getAllocaPos(BB); |
646 | User *Store = createTileStore(TileDef: I, Ptr: I8Ptr); |
647 | |
648 | // All its uses should load from stored mem. |
649 | for (Use &U : I->uses()) { |
650 | User *V = U.getUser(); |
651 | assert(!isa<PHINode>(V) && "PHI Nodes should be excluded!" ); |
652 | if (V != Store) |
653 | replaceWithTileLoad(U, Ptr: I8Ptr); |
654 | } |
655 | } |
656 | |
657 | // Volatile Tile Model: |
658 | // 1) All the uses of tile data comes from tileload in time. |
659 | // 2) All the defs of tile data tilestore into mem immediately. |
660 | // For example: |
661 | // -------------------------------------------------------------------------- |
662 | // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key |
663 | // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) |
664 | // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx |
665 | // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) |
666 | // call void @llvm.x86.tilestored64.internal(... td) area |
667 | // -------------------------------------------------------------------------- |
668 | // 3) No terminator, call or other amx instructions in the key amx area. |
669 | bool X86VolatileTileData::volatileTileData() { |
670 | bool Changed = false; |
671 | for (BasicBlock &BB : F) { |
672 | SmallVector<Instruction *, 2> PHIInsts; |
673 | SmallVector<Instruction *, 8> AMXDefInsts; |
674 | |
675 | for (Instruction &I : BB) { |
676 | if (!I.getType()->isX86_AMXTy()) |
677 | continue; |
678 | if (isa<PHINode>(Val: &I)) |
679 | PHIInsts.push_back(Elt: &I); |
680 | else |
681 | AMXDefInsts.push_back(Elt: &I); |
682 | } |
683 | |
684 | // First we "volatile" the non-phi related amx intrinsics. |
685 | for (Instruction *I : AMXDefInsts) { |
686 | if (isIncomingOfPHI(I)) |
687 | continue; |
688 | volatileTileNonPHI(I); |
689 | Changed = true; |
690 | } |
691 | |
692 | for (Instruction *I : PHIInsts) { |
693 | volatileTilePHI(PHI: dyn_cast<PHINode>(Val: I)); |
694 | Changed = true; |
695 | } |
696 | } |
697 | return Changed; |
698 | } |
699 | |
700 | } // anonymous namespace |
701 | |
702 | namespace { |
703 | |
704 | class X86LowerAMXCast { |
705 | Function &Func; |
706 | std::unique_ptr<DominatorTree> DT; |
707 | |
708 | public: |
709 | X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {} |
710 | bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST); |
711 | bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); |
712 | bool combineLdSt(SmallVectorImpl<Instruction *> &Casts); |
713 | bool combineAMXcast(TargetLibraryInfo *TLI); |
714 | bool transformAMXCast(IntrinsicInst *AMXCast); |
715 | bool transformAllAMXCast(); |
716 | bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN, |
717 | SmallSetVector<Instruction *, 16> &DeadInst); |
718 | }; |
719 | |
720 | static bool DCEInstruction(Instruction *I, |
721 | SmallSetVector<Instruction *, 16> &WorkList, |
722 | const TargetLibraryInfo *TLI) { |
723 | if (isInstructionTriviallyDead(I, TLI)) { |
724 | salvageDebugInfo(I&: *I); |
725 | salvageKnowledge(I); |
726 | |
727 | // Null out all of the instruction's operands to see if any operand becomes |
728 | // dead as we go. |
729 | for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { |
730 | Value *OpV = I->getOperand(i); |
731 | I->setOperand(i, Val: nullptr); |
732 | |
733 | if (!OpV->use_empty() || I == OpV) |
734 | continue; |
735 | |
736 | // If the operand is an instruction that became dead as we nulled out the |
737 | // operand, and if it is 'trivially' dead, delete it in a future loop |
738 | // iteration. |
739 | if (Instruction *OpI = dyn_cast<Instruction>(Val: OpV)) { |
740 | if (isInstructionTriviallyDead(I: OpI, TLI)) { |
741 | WorkList.insert(X: OpI); |
742 | } |
743 | } |
744 | } |
745 | I->eraseFromParent(); |
746 | return true; |
747 | } |
748 | return false; |
749 | } |
750 | |
751 | /// This function handles following case |
752 | /// |
753 | /// A -> B amxcast |
754 | /// PHI |
755 | /// B -> A amxcast |
756 | /// |
757 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. |
758 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. |
759 | bool X86LowerAMXCast::optimizeAMXCastFromPhi( |
760 | IntrinsicInst *CI, PHINode *PN, |
761 | SmallSetVector<Instruction *, 16> &DeadInst) { |
762 | IRBuilder<> Builder(CI); |
763 | Value *Src = CI->getOperand(i_nocapture: 0); |
764 | Type *SrcTy = Src->getType(); // Type B |
765 | Type *DestTy = CI->getType(); // Type A |
766 | |
767 | SmallVector<PHINode *, 4> PhiWorklist; |
768 | SmallSetVector<PHINode *, 4> OldPhiNodes; |
769 | |
770 | // Find all of the A->B casts and PHI nodes. |
771 | // We need to inspect all related PHI nodes, but PHIs can be cyclic, so |
772 | // OldPhiNodes is used to track all known PHI nodes, before adding a new |
773 | // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. |
774 | PhiWorklist.push_back(Elt: PN); |
775 | OldPhiNodes.insert(X: PN); |
776 | while (!PhiWorklist.empty()) { |
777 | auto *OldPN = PhiWorklist.pop_back_val(); |
778 | for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) { |
779 | Value *IncValue = OldPN->getIncomingValue(i: I); |
780 | // TODO: currently, We ignore cases where it is a const. In the future, we |
781 | // might support const. |
782 | if (isa<Constant>(Val: IncValue)) { |
783 | auto *IncConst = dyn_cast<Constant>(Val: IncValue); |
784 | if (!isa<UndefValue>(Val: IncValue) && !IncConst->isZeroValue()) |
785 | return false; |
786 | Value *Row = nullptr, *Col = nullptr; |
787 | std::tie(args&: Row, args&: Col) = getShape(Phi: OldPN); |
788 | // TODO: If it is not constant the Row and Col must domoniate tilezero |
789 | // that we are going to create. |
790 | if (!Row || !Col || !isa<Constant>(Val: Row) || !isa<Constant>(Val: Col)) |
791 | return false; |
792 | // Create tilezero at the end of incoming block. |
793 | auto *Block = OldPN->getIncomingBlock(i: I); |
794 | BasicBlock::iterator Iter = Block->getTerminator()->getIterator(); |
795 | Instruction *NewInst = Builder.CreateIntrinsic( |
796 | Intrinsic::x86_tilezero_internal, std::nullopt, {Row, Col}); |
797 | NewInst->moveBefore(MovePos: &*Iter); |
798 | NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector, |
799 | {IncValue->getType()}, {NewInst}); |
800 | NewInst->moveBefore(MovePos: &*Iter); |
801 | // Replace InValue with new Value. |
802 | OldPN->setIncomingValue(i: I, V: NewInst); |
803 | IncValue = NewInst; |
804 | } |
805 | |
806 | if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) { |
807 | if (OldPhiNodes.insert(X: PNode)) |
808 | PhiWorklist.push_back(Elt: PNode); |
809 | continue; |
810 | } |
811 | Instruction *ACI = dyn_cast<Instruction>(Val: IncValue); |
812 | if (ACI && isAMXCast(II: ACI)) { |
813 | // Verify it's a A->B cast. |
814 | Type *TyA = ACI->getOperand(i: 0)->getType(); |
815 | Type *TyB = ACI->getType(); |
816 | if (TyA != DestTy || TyB != SrcTy) |
817 | return false; |
818 | continue; |
819 | } |
820 | return false; |
821 | } |
822 | } |
823 | |
824 | // Check that each user of each old PHI node is something that we can |
825 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. |
826 | for (auto *OldPN : OldPhiNodes) { |
827 | for (User *V : OldPN->users()) { |
828 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
829 | if (ACI && isAMXCast(II: ACI)) { |
830 | // Verify it's a B->A cast. |
831 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
832 | Type *TyA = ACI->getType(); |
833 | if (TyA != DestTy || TyB != SrcTy) |
834 | return false; |
835 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
836 | // As long as the user is another old PHI node, then even if we don't |
837 | // rewrite it, the PHI web we're considering won't have any users |
838 | // outside itself, so it'll be dead. |
839 | // example: |
840 | // bb.0: |
841 | // %0 = amxcast ... |
842 | // bb.1: |
843 | // %1 = amxcast ... |
844 | // bb.2: |
845 | // %goodphi = phi %0, %1 |
846 | // %3 = amxcast %goodphi |
847 | // bb.3: |
848 | // %goodphi2 = phi %0, %goodphi |
849 | // %4 = amxcast %goodphi2 |
850 | // When optimizeAMXCastFromPhi process %3 and %goodphi, %goodphi2 is |
851 | // outside the phi-web, so the combination stop When |
852 | // optimizeAMXCastFromPhi process %4 and %goodphi2, the optimization |
853 | // will be done. |
854 | if (OldPhiNodes.count(key: PHI) == 0) |
855 | return false; |
856 | } else |
857 | return false; |
858 | } |
859 | } |
860 | |
861 | // For each old PHI node, create a corresponding new PHI node with a type A. |
862 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; |
863 | for (auto *OldPN : OldPhiNodes) { |
864 | Builder.SetInsertPoint(OldPN); |
865 | PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands()); |
866 | NewPNodes[OldPN] = NewPN; |
867 | } |
868 | |
869 | // Fill in the operands of new PHI nodes. |
870 | for (auto *OldPN : OldPhiNodes) { |
871 | PHINode *NewPN = NewPNodes[OldPN]; |
872 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { |
873 | Value *V = OldPN->getOperand(i_nocapture: j); |
874 | Value *NewV = nullptr; |
875 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
876 | // There should not be a AMXcast from a const. |
877 | if (ACI && isAMXCast(II: ACI)) |
878 | NewV = ACI->getOperand(i: 0); |
879 | else if (auto *PrevPN = dyn_cast<PHINode>(Val: V)) |
880 | NewV = NewPNodes[PrevPN]; |
881 | assert(NewV); |
882 | NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j)); |
883 | } |
884 | } |
885 | |
886 | // Traverse all accumulated PHI nodes and process its users, |
887 | // which are Stores and BitcCasts. Without this processing |
888 | // NewPHI nodes could be replicated and could lead to extra |
889 | // moves generated after DeSSA. |
890 | // If there is a store with type B, change it to type A. |
891 | |
892 | // Replace users of BitCast B->A with NewPHI. These will help |
893 | // later to get rid of a closure formed by OldPHI nodes. |
894 | for (auto *OldPN : OldPhiNodes) { |
895 | PHINode *NewPN = NewPNodes[OldPN]; |
896 | for (User *V : make_early_inc_range(Range: OldPN->users())) { |
897 | Instruction *ACI = dyn_cast<Instruction>(Val: V); |
898 | if (ACI && isAMXCast(II: ACI)) { |
899 | Type *TyB = ACI->getOperand(i: 0)->getType(); |
900 | Type *TyA = ACI->getType(); |
901 | assert(TyA == DestTy && TyB == SrcTy); |
902 | (void)TyA; |
903 | (void)TyB; |
904 | ACI->replaceAllUsesWith(V: NewPN); |
905 | DeadInst.insert(X: ACI); |
906 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
907 | // We don't need to push PHINode into DeadInst since they are operands |
908 | // of rootPN DCE can safely delete rootPN's operands if rootPN is dead. |
909 | assert(OldPhiNodes.contains(PHI)); |
910 | (void)PHI; |
911 | } else |
912 | llvm_unreachable("all uses should be handled" ); |
913 | } |
914 | } |
915 | return true; |
916 | } |
917 | |
918 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42) |
919 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
920 | // --> |
921 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
922 | // i64 64, x86_amx %42) |
923 | bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) { |
924 | Value *Tile = Cast->getOperand(i_nocapture: 0); |
925 | // TODO: If it is cast intrinsic or phi node, we can propagate the |
926 | // shape information through def-use chain. |
927 | if (!isAMXIntrinsic(I: Tile)) |
928 | return false; |
929 | auto *II = cast<IntrinsicInst>(Val: Tile); |
930 | // Tile is output from AMX intrinsic. The first operand of the |
931 | // intrinsic is row, the second operand of the intrinsic is column. |
932 | Value *Row = II->getOperand(i_nocapture: 0); |
933 | Value *Col = II->getOperand(i_nocapture: 1); |
934 | IRBuilder<> Builder(ST); |
935 | // Stride should be equal to col(measured by bytes) |
936 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
937 | Value *I8Ptr = Builder.CreateBitCast(V: ST->getOperand(i_nocapture: 1), DestTy: Builder.getPtrTy()); |
938 | std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile}; |
939 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt, |
940 | Args); |
941 | return true; |
942 | } |
943 | |
944 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
945 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
946 | // --> |
947 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
948 | // i8* %p, i64 64) |
949 | bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { |
950 | bool EraseLoad = true; |
951 | Value *Row = nullptr, *Col = nullptr; |
952 | Use &U = *(Cast->use_begin()); |
953 | unsigned OpNo = U.getOperandNo(); |
954 | auto *II = cast<IntrinsicInst>(Val: U.getUser()); |
955 | // TODO: If it is cast intrinsic or phi node, we can propagate the |
956 | // shape information through def-use chain. |
957 | if (!isAMXIntrinsic(I: II)) |
958 | return false; |
959 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
960 | IRBuilder<> Builder(LD); |
961 | // Stride should be equal to col(measured by bytes) |
962 | Value *Stride = Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()); |
963 | Value *I8Ptr; |
964 | |
965 | // To save compiling time, we create doninator tree when it is really |
966 | // needed. |
967 | if (!DT) |
968 | DT.reset(p: new DominatorTree(Func)); |
969 | if (!DT->dominates(Def: Row, User: LD) || !DT->dominates(Def: Col, User: LD)) { |
970 | // store the value to stack and reload it from stack before cast. |
971 | auto *AllocaAddr = |
972 | createAllocaInstAtEntry(Builder, BB: Cast->getParent(), Ty: LD->getType()); |
973 | Builder.SetInsertPoint(&*std::next(x: LD->getIterator())); |
974 | Builder.CreateStore(Val: LD, Ptr: AllocaAddr); |
975 | |
976 | Builder.SetInsertPoint(Cast); |
977 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
978 | EraseLoad = false; |
979 | } else { |
980 | I8Ptr = Builder.CreateBitCast(V: LD->getOperand(i_nocapture: 0), DestTy: Builder.getPtrTy()); |
981 | } |
982 | std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; |
983 | |
984 | Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, |
985 | std::nullopt, Args); |
986 | Cast->replaceAllUsesWith(V: NewInst); |
987 | |
988 | return EraseLoad; |
989 | } |
990 | |
991 | bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) { |
992 | bool Change = false; |
993 | for (auto *Cast : Casts) { |
994 | auto *II = cast<IntrinsicInst>(Val: Cast); |
995 | // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42) |
996 | // store <256 x i32> %43, <256 x i32>* %p, align 64 |
997 | // --> |
998 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, |
999 | // i64 64, x86_amx %42) |
1000 | if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) { |
1001 | SmallVector<Instruction *, 2> DeadStores; |
1002 | for (User *U : Cast->users()) { |
1003 | StoreInst *Store = dyn_cast<StoreInst>(Val: U); |
1004 | if (!Store) |
1005 | continue; |
1006 | if (combineCastStore(Cast: cast<IntrinsicInst>(Val: Cast), ST: Store)) { |
1007 | DeadStores.push_back(Elt: Store); |
1008 | Change = true; |
1009 | } |
1010 | } |
1011 | for (auto *Store : DeadStores) |
1012 | Store->eraseFromParent(); |
1013 | } else { // x86_cast_vector_to_tile |
1014 | SmallVector<Instruction *, 2> DeadLoads; |
1015 | auto *Load = dyn_cast<LoadInst>(Val: Cast->getOperand(i: 0)); |
1016 | if (!Load || !Load->hasOneUse()) |
1017 | continue; |
1018 | // %65 = load <256 x i32>, <256 x i32>* %p, align 64 |
1019 | // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65) |
1020 | // --> |
1021 | // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, |
1022 | // i8* %p, i64 64) |
1023 | if (combineLoadCast(Cast: cast<IntrinsicInst>(Val: Cast), LD: Load)) { |
1024 | // Set the operand is null so that load instruction can be erased. |
1025 | Cast->setOperand(i: 0, Val: nullptr); |
1026 | Load->eraseFromParent(); |
1027 | } |
1028 | } |
1029 | } |
1030 | return Change; |
1031 | } |
1032 | |
1033 | bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) { |
1034 | bool Change = false; |
1035 | // Collect tile cast instruction. |
1036 | SmallVector<Instruction *, 8> Vec2TileInsts; |
1037 | SmallVector<Instruction *, 8> Tile2VecInsts; |
1038 | SmallVector<Instruction *, 8> PhiCastWorkList; |
1039 | SmallSetVector<Instruction *, 16> DeadInst; |
1040 | for (BasicBlock &BB : Func) { |
1041 | for (Instruction &I : BB) { |
1042 | Value *Vec; |
1043 | if (match(&I, |
1044 | m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec)))) |
1045 | Vec2TileInsts.push_back(Elt: &I); |
1046 | else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>( |
1047 | m_Value(Vec)))) |
1048 | Tile2VecInsts.push_back(Elt: &I); |
1049 | } |
1050 | } |
1051 | |
1052 | auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) { |
1053 | for (auto *Inst : Insts) { |
1054 | for (User *U : Inst->users()) { |
1055 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U); |
1056 | if (!II || II->getIntrinsicID() != IID) |
1057 | continue; |
1058 | // T1 = vec2tile V0 |
1059 | // V2 = tile2vec T1 |
1060 | // V3 = OP V2 |
1061 | // --> |
1062 | // T1 = vec2tile V0 |
1063 | // V2 = tile2vec T1 |
1064 | // V3 = OP V0 |
1065 | II->replaceAllUsesWith(V: Inst->getOperand(i: 0)); |
1066 | Change = true; |
1067 | } |
1068 | } |
1069 | }; |
1070 | |
1071 | Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector); |
1072 | Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile); |
1073 | |
1074 | SmallVector<Instruction *, 8> LiveCasts; |
1075 | auto EraseInst = [&](SmallVectorImpl<Instruction *> &Insts) { |
1076 | for (auto *Inst : Insts) { |
1077 | if (Inst->use_empty()) { |
1078 | Inst->eraseFromParent(); |
1079 | Change = true; |
1080 | } else { |
1081 | LiveCasts.push_back(Elt: Inst); |
1082 | } |
1083 | } |
1084 | }; |
1085 | |
1086 | EraseInst(Vec2TileInsts); |
1087 | EraseInst(Tile2VecInsts); |
1088 | LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine " |
1089 | "Vec2Tile and Tile2Vec:\n" ; |
1090 | Func.dump()); |
1091 | Change |= combineLdSt(Casts&: LiveCasts); |
1092 | EraseInst(LiveCasts); |
1093 | LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after combine " |
1094 | "AMXCast and load/store:\n" ; |
1095 | Func.dump()); |
1096 | |
1097 | // Handle the A->B->A cast, and there is an intervening PHI node. |
1098 | for (BasicBlock &BB : Func) { |
1099 | for (Instruction &I : BB) { |
1100 | if (isAMXCast(II: &I)) { |
1101 | if (isa<PHINode>(Val: I.getOperand(i: 0))) |
1102 | PhiCastWorkList.push_back(Elt: &I); |
1103 | } |
1104 | } |
1105 | } |
1106 | for (auto *I : PhiCastWorkList) { |
1107 | // We skip the dead Amxcast. |
1108 | if (DeadInst.contains(key: I)) |
1109 | continue; |
1110 | PHINode *PN = cast<PHINode>(Val: I->getOperand(i: 0)); |
1111 | if (optimizeAMXCastFromPhi(CI: cast<IntrinsicInst>(Val: I), PN, DeadInst)) { |
1112 | DeadInst.insert(X: PN); |
1113 | Change = true; |
1114 | } |
1115 | } |
1116 | |
1117 | // Since we create new phi and merge AMXCast, some old phis and AMXCast might |
1118 | // have no uses. We do some DeadCodeElimination for them. |
1119 | while (!DeadInst.empty()) { |
1120 | Instruction *I = DeadInst.pop_back_val(); |
1121 | Change |= DCEInstruction(I, WorkList&: DeadInst, TLI); |
1122 | } |
1123 | LLVM_DEBUG(dbgs() << "[LowerAMXTYpe][combineAMXcast] IR dump after " |
1124 | "optimizeAMXCastFromPhi:\n" ; |
1125 | Func.dump()); |
1126 | return Change; |
1127 | } |
1128 | |
1129 | // There might be remaining AMXcast after combineAMXcast and they should be |
1130 | // handled elegantly. |
1131 | bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) { |
1132 | IRBuilder<> Builder(AMXCast); |
1133 | AllocaInst *AllocaAddr; |
1134 | Value *I8Ptr, *Stride; |
1135 | auto *Src = AMXCast->getOperand(i_nocapture: 0); |
1136 | |
1137 | auto Prepare = [&](Type *MemTy) { |
1138 | AllocaAddr = createAllocaInstAtEntry(Builder, BB: AMXCast->getParent(), Ty: MemTy); |
1139 | I8Ptr = Builder.CreateBitCast(V: AllocaAddr, DestTy: Builder.getPtrTy()); |
1140 | Stride = Builder.getInt64(C: 64); |
1141 | }; |
1142 | |
1143 | if (AMXCast->getType()->isX86_AMXTy()) { |
1144 | // %2 = amxcast <225 x i32> %src to x86_amx |
1145 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
1146 | // i8* %addr3, i64 60, x86_amx %2) |
1147 | // --> |
1148 | // %addr = alloca <225 x i32>, align 64 |
1149 | // store <225 x i32> %src, <225 x i32>* %addr, align 64 |
1150 | // %addr2 = bitcast <225 x i32>* %addr to i8* |
1151 | // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60, |
1152 | // i8* %addr2, |
1153 | // i64 60) |
1154 | // call void @llvm.x86.tilestored64.internal(i16 15, i16 60, |
1155 | // i8* %addr3, i64 60, x86_amx %2) |
1156 | if (AMXCast->use_empty()) { |
1157 | AMXCast->eraseFromParent(); |
1158 | return true; |
1159 | } |
1160 | Use &U = *(AMXCast->use_begin()); |
1161 | unsigned OpNo = U.getOperandNo(); |
1162 | auto *II = dyn_cast<IntrinsicInst>(Val: U.getUser()); |
1163 | if (!II) |
1164 | return false; // May be bitcast from x86amx to <256 x i32>. |
1165 | Prepare(AMXCast->getOperand(i_nocapture: 0)->getType()); |
1166 | Builder.CreateStore(Val: Src, Ptr: AllocaAddr); |
1167 | // TODO we can pick an constant operand for the shape. |
1168 | Value *Row = nullptr, *Col = nullptr; |
1169 | std::tie(args&: Row, args&: Col) = getShape(II, OpNo); |
1170 | std::array<Value *, 4> Args = { |
1171 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty())}; |
1172 | Value *NewInst = Builder.CreateIntrinsic( |
1173 | Intrinsic::x86_tileloadd64_internal, std::nullopt, Args); |
1174 | AMXCast->replaceAllUsesWith(V: NewInst); |
1175 | AMXCast->eraseFromParent(); |
1176 | } else { |
1177 | // %2 = amxcast x86_amx %src to <225 x i32> |
1178 | // --> |
1179 | // %addr = alloca <225 x i32>, align 64 |
1180 | // %addr2 = bitcast <225 x i32>* to i8* |
1181 | // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, |
1182 | // i8* %addr2, i64 %stride) |
1183 | // %2 = load <225 x i32>, <225 x i32>* %addr, align 64 |
1184 | auto *II = dyn_cast<IntrinsicInst>(Val: Src); |
1185 | if (!II) |
1186 | return false; // May be bitcast from <256 x i32> to x86amx. |
1187 | Prepare(AMXCast->getType()); |
1188 | Value *Row = II->getOperand(i_nocapture: 0); |
1189 | Value *Col = II->getOperand(i_nocapture: 1); |
1190 | std::array<Value *, 5> Args = { |
1191 | Row, Col, I8Ptr, Builder.CreateSExt(V: Col, DestTy: Builder.getInt64Ty()), Src}; |
1192 | Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt, |
1193 | Args); |
1194 | Value *NewInst = Builder.CreateLoad(Ty: AMXCast->getType(), Ptr: AllocaAddr); |
1195 | AMXCast->replaceAllUsesWith(V: NewInst); |
1196 | AMXCast->eraseFromParent(); |
1197 | } |
1198 | |
1199 | return true; |
1200 | } |
1201 | |
1202 | bool X86LowerAMXCast::transformAllAMXCast() { |
1203 | bool Change = false; |
1204 | // Collect tile cast instruction. |
1205 | SmallVector<Instruction *, 8> WorkLists; |
1206 | for (BasicBlock &BB : Func) { |
1207 | for (Instruction &I : BB) { |
1208 | if (isAMXCast(II: &I)) |
1209 | WorkLists.push_back(Elt: &I); |
1210 | } |
1211 | } |
1212 | |
1213 | for (auto *Inst : WorkLists) { |
1214 | Change |= transformAMXCast(AMXCast: cast<IntrinsicInst>(Val: Inst)); |
1215 | } |
1216 | |
1217 | return Change; |
1218 | } |
1219 | |
1220 | } // anonymous namespace |
1221 | |
1222 | namespace { |
1223 | |
1224 | class X86LowerAMXTypeLegacyPass : public FunctionPass { |
1225 | public: |
1226 | static char ID; |
1227 | |
1228 | X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { |
1229 | initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); |
1230 | } |
1231 | |
1232 | bool runOnFunction(Function &F) override { |
1233 | bool C = false; |
1234 | TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>(); |
1235 | TargetLibraryInfo *TLI = |
1236 | &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); |
1237 | |
1238 | X86LowerAMXCast LAC(F); |
1239 | C |= LAC.combineAMXcast(TLI); |
1240 | // There might be remaining AMXcast after combineAMXcast and they should be |
1241 | // handled elegantly. |
1242 | C |= LAC.transformAllAMXCast(); |
1243 | |
1244 | X86LowerAMXType LAT(F); |
1245 | C |= LAT.visit(); |
1246 | |
1247 | // Prepare for fast register allocation at O0. |
1248 | // Todo: May better check the volatile model of AMX code, not just |
1249 | // by checking Attribute::OptimizeNone and CodeGenOptLevel::None. |
1250 | if (TM->getOptLevel() == CodeGenOptLevel::None) { |
1251 | // If Front End not use O0 but the Mid/Back end use O0, (e.g. |
1252 | // "Clang -O2 -S -emit-llvm t.c" + "llc t.ll") we should make |
1253 | // sure the amx data is volatile, that is nessary for AMX fast |
1254 | // register allocation. |
1255 | if (!F.hasFnAttribute(Attribute::OptimizeNone)) { |
1256 | X86VolatileTileData VTD(F); |
1257 | C = VTD.volatileTileData() || C; |
1258 | } |
1259 | } |
1260 | |
1261 | return C; |
1262 | } |
1263 | |
1264 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
1265 | AU.setPreservesCFG(); |
1266 | AU.addRequired<TargetPassConfig>(); |
1267 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
1268 | } |
1269 | }; |
1270 | |
1271 | } // anonymous namespace |
1272 | |
// Human-readable pass name handed to the INITIALIZE_PASS_* macros below.
static const char PassName[] = "Lower AMX type for load/store";
// Definition of the static pass ID member declared in the pass class.
char X86LowerAMXTypeLegacyPass::ID = 0;
// Pass initialization boilerplate. The two dependencies listed here are the
// same analyses that getAnalysisUsage() marks as required.
INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                    false)
1281 | |
/// Factory for the legacy AMX type lowering pass. Returns a newly allocated
/// X86LowerAMXTypeLegacyPass.
FunctionPass *llvm::createX86LowerAMXTypePass() {
  return new X86LowerAMXTypeLegacyPass();
}
1285 | |