//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Lower matrix intrinsics to vector operations.
//
// TODO:
//  * Improve fusion:
//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results
//     transposed.
//   * Improve cost-modeling, e.g. choose different number of rows/columns
//     for tiles, consider cost of copies on alias.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Alignment.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"

#include <cmath>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "lower-matrix-intrinsics"

static cl::opt<bool>
    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
               cl::desc("Enable/disable fusing matrix instructions."));
// TODO: Allow and use non-square tiles.
static cl::opt<unsigned> TileSize(
    "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
    cl::desc(
        "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
                                  cl::Hidden,
                                  cl::desc("Generate loop nest for tiling."));
static cl::opt<bool> ForceFusion(
    "force-fuse-matrix", cl::init(false), cl::Hidden,
    cl::desc("Force matrix instruction fusion even if not profitable."));
static cl::opt<bool> AllowContractEnabled(
    "matrix-allow-contract", cl::init(false), cl::Hidden,
    cl::desc("Allow the use of FMAs if available and profitable. This may "
             "result in different results, due to less rounding error."));

static cl::opt<bool>
    VerifyShapeInfo("verify-matrix-shapes", cl::Hidden,
                    cl::desc("Enable/disable matrix shape verification."),
                    cl::init(false));

enum class MatrixLayoutTy { ColumnMajor, RowMajor };

static cl::opt<MatrixLayoutTy> MatrixLayout(
    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
    cl::desc("Sets the default matrix layout"),
    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
                          "Use column-major layout"),
               clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
                          "Use row-major layout")));

static cl::opt<bool> PrintAfterTransposeOpt("matrix-print-after-transpose-opt",
                                            cl::init(false));

/// Helper function to either return Scope, if it is a subprogram, or the
/// attached subprogram for a local scope.
static DISubprogram *getSubprogram(DIScope *Scope) {
  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
    return Subprogram;
  return cast<DILocalScope>(Scope)->getSubprogram();
}

/// Erase \p V from \p BB and move \p II forward to avoid invalidating
/// iterators.
static void eraseFromParentAndMove(Value *V, BasicBlock::reverse_iterator &II,
                                   BasicBlock &BB) {
  auto *Inst = cast<Instruction>(V);
  // Still used, don't erase.
  if (!Inst->use_empty())
    return;
  if (II != BB.rend() && Inst == &*II)
    ++II;
  Inst->eraseFromParent();
}

/// Return true if V is a splat of a value (which is used when multiplying a
/// matrix with a scalar).
static bool isSplat(Value *V) {
  if (auto *SV = dyn_cast<ShuffleVectorInst>(V))
    return SV->isZeroEltSplat();
  return false;
}

/// Match any mul operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyMul(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Mul(L, R), m_FMul(L, R));
}

/// Match any add operation (fp or integer).
template <typename LTy, typename RTy>
auto m_AnyAdd(const LTy &L, const RTy &R) {
  return m_CombineOr(m_Add(L, R), m_FAdd(L, R));
}

namespace {

// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
// assuming \p Stride elements between the starts of two consecutive vectors.
// \p Stride must be >= \p NumElements.
// For column-major matrixes, the function computes the address of a column
// vector and \p NumElements must be set to the number of elements in a column
// (= number of rows of the matrix). For row-major matrixes, the function
// computes the address of a row vector and \p NumElements must be set to the
// number of elements in a row (= number of columns of the matrix).
//
// Consider a 4x4 matrix in column-major layout like below
//
//      0       1      2      3
// 0   v_0_0  v_0_1  v_0_2  v_0_3
// 1   v_1_0  v_1_1  v_1_2  v_1_3
// 2   v_2_0  v_2_1  v_2_2  v_2_3
// 3   v_3_0  v_3_1  v_3_2  v_3_3

// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
// we need a pointer to the first element of the submatrix as base pointer.
// Then we can use computeVectorAddr to compute the addresses for the columns
// of the sub-matrix.
//
// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
//           -> just returns Base
// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (1 * 4)
// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
//           -> returns Base + (2 * 4)
//
// The graphic below illustrates the number of elements in a column (marked
// with |) and the number of skipped elements (marked with }).
//
//         v_0_0  v_0_1 {v_0_2 {v_0_3
//          Base  Col 1  Col 2
//            |     |      |
//         v_1_0 |v_1_1 |v_1_2 |v_1_3
//         v_2_0 |v_2_1 |v_2_2 |v_2_3
//         v_3_0 {v_3_1 {v_3_2  v_3_3
//
Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
                         unsigned NumElements, Type *EltType,
                         IRBuilder<> &Builder) {

  assert((!isa<ConstantInt>(Stride) ||
          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
         "Stride must be >= the number of elements in the result vector.");

  // Compute the start of the vector with index VecIdx as VecIdx * Stride.
  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");

  // Get pointer to the start of the selected vector. Skip GEP creation,
  // if we select vector 0.
  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
    VecStart = BasePtr;
  else
    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");

  return VecStart;
}

namespace {
struct ShapeInfo {
  unsigned NumRows;
  unsigned NumColumns;

  bool IsColumnMajor;

  ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
      : NumRows(NumRows), NumColumns(NumColumns),
        IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}

  ShapeInfo(Value *NumRows, Value *NumColumns)
      : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
                  cast<ConstantInt>(NumColumns)->getZExtValue()) {}

  bool operator==(const ShapeInfo &other) {
    return NumRows == other.NumRows && NumColumns == other.NumColumns;
  }
  bool operator!=(const ShapeInfo &other) { return !(*this == other); }

  /// Returns true if shape-information is defined, meaning both dimensions
  /// are != 0.
  operator bool() const {
    assert(NumRows == 0 || NumColumns != 0);
    return NumRows != 0;
  }

  unsigned getStride() const {
    if (IsColumnMajor)
      return NumRows;
    return NumColumns;
  }

  unsigned getNumVectors() const {
    if (IsColumnMajor)
      return NumColumns;
    return NumRows;
  }

  /// Returns the transposed shape.
  ShapeInfo t() const { return ShapeInfo(NumColumns, NumRows); }
};
} // namespace
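
// Example (illustrative only, not used by the pass): with the default
// column-major layout, a 2x3 matrix has ShapeInfo{NumRows = 2, NumColumns = 3}.
// getStride() returns 2, the length of each column vector and the distance (in
// elements) between the starts of consecutive columns in the flat embedding,
// and getNumVectors() returns 3, one vector per column. With
// -matrix-default-layout=row-major the roles flip: getStride() returns 3 and
// getNumVectors() returns 2.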

static bool isUniformShape(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return true;

  switch (I->getOpcode()) {
  case Instruction::FAdd:
  case Instruction::FSub:
  case Instruction::FMul: // Scalar multiply.
  case Instruction::FNeg:
  case Instruction::Add:
  case Instruction::Mul:
  case Instruction::Sub:
    return true;
  default:
    return false;
  }
}

/// Return the ShapeInfo for the result of \p I, if it can be determined.
static std::optional<ShapeInfo>
computeShapeInfoForInst(Instruction *I,
                        const ValueMap<Value *, ShapeInfo> &ShapeMap) {
  Value *M;
  Value *N;
  Value *K;
  if (match(I, m_Intrinsic<Intrinsic::matrix_multiply>(
                   m_Value(), m_Value(), m_Value(M), m_Value(N), m_Value(K))))
    return ShapeInfo(M, K);
  if (match(I, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(), m_Value(M),
                                                        m_Value(N)))) {
    // Flip dimensions.
    return ShapeInfo(N, M);
  }
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                   m_Value(), m_Value(), m_Value(), m_Value(), m_Value(M),
                   m_Value(N))))
    return ShapeInfo(N, M);
  if (match(I, m_Intrinsic<Intrinsic::matrix_column_major_load>(
                   m_Value(), m_Value(), m_Value(), m_Value(M), m_Value(N))))
    return ShapeInfo(M, N);
  Value *MatrixA;
  if (match(I, m_Store(m_Value(MatrixA), m_Value()))) {
    auto OpShape = ShapeMap.find(MatrixA);
    if (OpShape != ShapeMap.end())
      return OpShape->second;
  }

  if (isUniformShape(I)) {
    // Find the first operand that has a known shape and use that.
    for (auto &Op : I->operands()) {
      auto OpShape = ShapeMap.find(Op.get());
      if (OpShape != ShapeMap.end())
        return OpShape->second;
    }
  }
  return std::nullopt;
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
///
/// Currently, the lowering for each matrix intrinsic is done as follows:
/// 1. Propagate the shape information from intrinsics to connected
///    instructions.
/// 2. Lower instructions with shape information (assuming column-major layout).
///    The lowering works similarly using row-major layout.
///  2.1. Get column vectors for each argument. If we already lowered the
///       definition of an argument, use the produced column vectors directly.
///       If not, split the operand vector containing an embedded matrix into
///       a set of column vectors.
///  2.2. Lower the instruction in terms of column major operations, which
///       yields a set of column vectors containing the result matrix. Note
///       that we lower all instructions that have shape information. Besides
///       the intrinsics, this includes stores for example.
///  2.3. Update uses of the lowered instruction. If we have shape information
///       for a user, there is nothing to do, as we will look up the result
///       column matrix when lowering the user. For other uses, we embed the
///       result matrix in a flat vector and update the use.
///  2.4. Cache the result column matrix for the instruction we lowered.
/// 3. After we lowered all instructions in a function, remove the now
///    obsolete instructions.
///
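/// As an illustrative sketch (IR value names are made up), a column-major 2x3
/// transpose
///   %t = call <6 x double> @llvm.matrix.transpose.v6f64(<6 x double> %in,
///                                                       i32 2, i32 3)
/// is lowered by splitting %in into its 3 column vectors of <2 x double> via
/// shufflevector, building the 2 result columns (each <3 x double>) element by
/// element, and concatenating them back into a flat <6 x double> only for
/// users without shape information.
///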
class LowerMatrixIntrinsics {
  Function &Func;
  const DataLayout &DL;
  const TargetTransformInfo &TTI;
  AliasAnalysis *AA;
  DominatorTree *DT;
  LoopInfo *LI;
  OptimizationRemarkEmitter *ORE;

  /// Contains estimates of the number of operations (loads, stores, compute)
  /// required to lower a matrix operation.
  struct OpInfoTy {
    /// Number of stores emitted to generate this matrix.
    unsigned NumStores = 0;
    /// Number of loads emitted to generate this matrix.
    unsigned NumLoads = 0;
    /// Number of compute operations emitted to generate this matrix.
    unsigned NumComputeOps = 0;
    /// Most of the time transposes can be fused with matrix multiplies or can
    /// be folded away via algebraic simplifications. This is the number of
    /// transposes that we failed to make "free" via such optimizations.
    unsigned NumExposedTransposes = 0;

    OpInfoTy &operator+=(const OpInfoTy &RHS) {
      NumStores += RHS.NumStores;
      NumLoads += RHS.NumLoads;
      NumComputeOps += RHS.NumComputeOps;
      NumExposedTransposes += RHS.NumExposedTransposes;
      return *this;
    }
  };

  /// Wrapper class representing a matrix as a set of vectors, either in row or
  /// column major layout. All vectors must have the same vector type.
  class MatrixTy {
    SmallVector<Value *, 16> Vectors;

    OpInfoTy OpInfo;

    bool IsColumnMajor = true;

  public:
    MatrixTy() : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(ArrayRef<Value *> Vectors)
        : Vectors(Vectors.begin(), Vectors.end()),
          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {

      unsigned D = isColumnMajor() ? NumColumns : NumRows;
      for (unsigned J = 0; J < D; ++J)
        addVector(PoisonValue::get(FixedVectorType::get(
            EltTy, isColumnMajor() ? NumRows : NumColumns)));
    }

    Value *getVector(unsigned i) const { return Vectors[i]; }
    Value *getColumn(unsigned i) const {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return Vectors[i];
    }
    Value *getRow(unsigned i) const {
      assert(!isColumnMajor() && "only supported for row-major matrixes");
      return Vectors[i];
    }

    void setVector(unsigned i, Value *V) { Vectors[i] = V; }

    Type *getElementType() const { return getVectorTy()->getElementType(); }

    unsigned getNumVectors() const {
      if (isColumnMajor())
        return getNumColumns();
      return getNumRows();
    }

    unsigned getNumColumns() const {
      if (isColumnMajor())
        return Vectors.size();
      else {
        assert(Vectors.size() > 0 &&
               "Cannot call getNumColumns without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      }
    }
    unsigned getNumRows() const {
      if (isColumnMajor()) {
        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
      } else
        return Vectors.size();
    }

    void addVector(Value *V) { Vectors.push_back(V); }
    VectorType *getColumnTy() {
      assert(isColumnMajor() && "only supported for column-major matrixes");
      return getVectorTy();
    }

    VectorType *getVectorTy() const {
      return cast<VectorType>(Vectors[0]->getType());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> columns() {
      assert(isColumnMajor() &&
             "columns() only supported for column-major matrixes");
      return make_range(Vectors.begin(), Vectors.end());
    }

    iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
      return make_range(Vectors.begin(), Vectors.end());
    }

    /// Embed the vectors of the matrix into a flat vector by concatenating
    /// them.
    Value *embedInVector(IRBuilder<> &Builder) const {
      return Vectors.size() == 1 ? Vectors[0]
                                 : concatenateVectors(Builder, Vectors);
    }

    MatrixTy &addNumLoads(unsigned N) {
      OpInfo.NumLoads += N;
      return *this;
    }

    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }

    MatrixTy &addNumStores(unsigned N) {
      OpInfo.NumStores += N;
      return *this;
    }

    MatrixTy &addNumExposedTransposes(unsigned N) {
      OpInfo.NumExposedTransposes += N;
      return *this;
    }

    MatrixTy &addNumComputeOps(unsigned N) {
      OpInfo.NumComputeOps += N;
      return *this;
    }

    unsigned getNumStores() const { return OpInfo.NumStores; }
    unsigned getNumLoads() const { return OpInfo.NumLoads; }
    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }

    const OpInfoTy &getOpInfo() const { return OpInfo; }

    bool isColumnMajor() const { return IsColumnMajor; }

    unsigned getStride() const {
      if (isColumnMajor())
        return getNumRows();
      return getNumColumns();
    }

    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
    /// matrix is column-major, the result vector is extracted from a column
    /// vector, otherwise from a row vector.
    Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
                         IRBuilder<> &Builder) const {
      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
      assert(cast<FixedVectorType>(Vec->getType())->getNumElements() >=
                 NumElts &&
             "Extracted vector will contain poison values");
      return Builder.CreateShuffleVector(
          Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
          "block");
    }
  };

  /// Maps instructions to their shape information. The shape information
  /// describes the shape to be used while lowering. This matches the shape of
  /// the result value of the instruction, with the only exceptions being store
  /// instructions and the matrix_column_major_store intrinsics. For those, the
  /// shape information indicates that those instructions should be lowered
  /// using shape information as well. A ValueMap is used so that when
  /// sub-passes like optimizeTransposes perform RAUW the map stays
  /// up-to-date.
  ValueMap<Value *, ShapeInfo> ShapeMap;

  /// List of instructions to remove. While lowering, we do not replace all
  /// users of a lowered instruction if shape information is available; those
  /// users need to be removed after we finish lowering.
  SmallVector<Instruction *, 16> ToRemove;

  /// Map from instructions to their produced column matrix.
  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;

private:
  static FastMathFlags getFastMathFlags(Instruction *Inst) {
    FastMathFlags FMF;

    if (isa<FPMathOperator>(*Inst))
      FMF = Inst->getFastMathFlags();

    FMF.setAllowContract(AllowContractEnabled || FMF.allowContract());

    return FMF;
  }

public:
  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
                        AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
                        OptimizationRemarkEmitter *ORE)
      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
        LI(LI), ORE(ORE) {}

  unsigned getNumOps(Type *VT) {
    assert(isa<VectorType>(VT) && "Expected vector type");
    return getNumOps(VT->getScalarType(),
                     cast<FixedVectorType>(VT)->getNumElements());
  }

  /// Is this the minimal version executed in the backend pipelines.
  bool isMinimal() const {
    return !DT;
  }

  /// Return the estimated number of vector ops required for an operation on
  /// \p VT * N.
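  ///
  /// For example (illustrative numbers): with 128-bit fixed-width vector
  /// registers, an operation on 8 x float (256 bits in total) is estimated as
  /// ceil(256 / 128) = 2 vector ops.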
  unsigned getNumOps(Type *ST, unsigned N) {
    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedValue() /
                     double(TTI.getRegisterBitWidth(
                                TargetTransformInfo::RGK_FixedWidthVector)
                                .getFixedValue()));
  }

  /// Return the set of vectors that a matrix value is lowered to.
  ///
  /// If we lowered \p MatrixVal, just return the cached result matrix.
  /// Otherwise split the flat vector \p MatrixVal containing a matrix with
  /// shape \p SI into vectors.
  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
                     IRBuilder<> &Builder) {
    VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
    assert(VType && "MatrixVal must be a vector type");
    assert(cast<FixedVectorType>(VType)->getNumElements() ==
               SI.NumRows * SI.NumColumns &&
           "The vector size must match the number of matrix elements");

    // Check if we lowered MatrixVal using shape information. In that case,
    // return the existing matrix, if it matches the requested shape
    // information. If there is a mis-match, embed the result in a flat
    // vector and split it later.
    auto Found = Inst2ColumnMatrix.find(MatrixVal);
    if (Found != Inst2ColumnMatrix.end()) {
      MatrixTy &M = Found->second;
      // Return the found matrix, if its shape matches the requested shape
      // information.
      if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
        return M;

      MatrixVal = M.embedInVector(Builder);
    }

    // Otherwise split MatrixVal.
    SmallVector<Value *, 16> SplitVecs;
    for (unsigned MaskStart = 0;
         MaskStart < cast<FixedVectorType>(VType)->getNumElements();
         MaskStart += SI.getStride()) {
      Value *V = Builder.CreateShuffleVector(
          MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
          "split");
      SplitVecs.push_back(V);
    }

    return {SplitVecs};
  }

  /// If \p V already has a known shape return false. Otherwise set the shape
  /// for instructions that support it.
  bool setShapeInfo(Value *V, ShapeInfo Shape) {
    assert(Shape && "Shape not set");
    if (isa<UndefValue>(V) || !supportsShapeInfo(V))
      return false;

    auto SIter = ShapeMap.find(V);
    if (SIter != ShapeMap.end()) {
      if (VerifyShapeInfo && (SIter->second.NumRows != Shape.NumRows ||
                              SIter->second.NumColumns != Shape.NumColumns)) {
        errs() << "Conflicting shapes (" << SIter->second.NumRows << "x"
               << SIter->second.NumColumns << " vs " << Shape.NumRows << "x"
               << Shape.NumColumns << ") for " << *V << "\n";
        report_fatal_error(
            "Matrix shape verification failed, compilation aborted!");
      }

      LLVM_DEBUG(dbgs() << " not overriding existing shape: "
                        << SIter->second.NumRows << " "
                        << SIter->second.NumColumns << " for " << *V << "\n");
      return false;
    }

    ShapeMap.insert({V, Shape});
    LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
                      << " for " << *V << "\n");
    return true;
  }

  /// Returns true if shape information can be used for \p V. The supported
  /// instructions must match the instructions that can be lowered by this
  /// pass.
  bool supportsShapeInfo(Value *V) {
    Instruction *Inst = dyn_cast<Instruction>(V);
    if (!Inst)
      return false;

    IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
    if (II)
      switch (II->getIntrinsicID()) {
      case Intrinsic::matrix_multiply:
      case Intrinsic::matrix_transpose:
      case Intrinsic::matrix_column_major_load:
      case Intrinsic::matrix_column_major_store:
        return true;
      default:
        return false;
      }
    return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
  }

  /// Propagate the shape information of instructions to their users.
  /// The work list contains instructions for which we can compute the shape,
  /// either based on the information provided by matrix intrinsics or known
  /// shapes of operands.
  SmallVector<Instruction *, 32>
  propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;
    // Pop an element for which we are guaranteed to have at least one of the
    // operand shapes. Add the shape for this and then add users to the work
    // list.
    LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Instruction *Inst = WorkList.pop_back_val();

      // New entry, set the value and insert operands.
      bool Propagate = false;
      if (auto SI = computeShapeInfoForInst(Inst, ShapeMap))
        Propagate = setShapeInfo(Inst, *SI);

      if (Propagate) {
        NewWorkList.push_back(Inst);
        for (auto *User : Inst->users())
          if (ShapeMap.count(User) == 0)
            WorkList.push_back(cast<Instruction>(User));
      }
    }

    return NewWorkList;
  }

  /// Propagate the shape to operands of instructions with shape information.
  /// \p WorkList contains the instructions for which we already know the
  /// shape.
  SmallVector<Instruction *, 32>
  propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
    SmallVector<Instruction *, 32> NewWorkList;

    auto pushInstruction = [](Value *V,
                              SmallVectorImpl<Instruction *> &WorkList) {
      Instruction *I = dyn_cast<Instruction>(V);
      if (I)
        WorkList.push_back(I);
    };
    // Pop an element with known shape. Traverse the operands, if their shape
    // derives from the result shape and is unknown, add it and add them to
    // the worklist.
    LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
    while (!WorkList.empty()) {
      Value *V = WorkList.pop_back_val();

      size_t BeforeProcessingV = WorkList.size();
      if (!isa<Instruction>(V))
        continue;

      Value *MatrixA;
      Value *MatrixB;
      Value *M;
      Value *N;
      Value *K;
      if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
                       m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                       m_Value(N), m_Value(K)))) {
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);

        if (setShapeInfo(MatrixB, {N, K}))
          pushInstruction(MatrixB, WorkList);

      } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                              m_Value(MatrixA), m_Value(M), m_Value(N)))) {
        // Flip dimensions.
        if (setShapeInfo(MatrixA, {M, N}))
          pushInstruction(MatrixA, WorkList);
      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
                              m_Value(MatrixA), m_Value(), m_Value(),
                              m_Value(), m_Value(M), m_Value(N)))) {
        if (setShapeInfo(MatrixA, {M, N})) {
          pushInstruction(MatrixA, WorkList);
        }
      } else if (isa<LoadInst>(V) ||
                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
        // Nothing to do, no matrix input.
      } else if (isa<StoreInst>(V)) {
        // Nothing to do. We forward-propagated to this so we would just
        // backward propagate to an instruction with an already known shape.
      } else if (isUniformShape(V)) {
        // Propagate to all operands.
        ShapeInfo Shape = ShapeMap[V];
        for (Use &U : cast<Instruction>(V)->operands()) {
          if (setShapeInfo(U.get(), Shape))
            pushInstruction(U.get(), WorkList);
        }
      }
      // After we discovered new shape info for new instructions in the
      // worklist, we use their users as seeds for the next round of forward
      // propagation.
      for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
        for (User *U : WorkList[I]->users())
          if (isa<Instruction>(U) && V != U)
            NewWorkList.push_back(cast<Instruction>(U));
    }
    return NewWorkList;
  }

  /// (Op0 op Op1)^T -> Op0^T op Op1^T
  /// Transpose \p Op0 and \p Op1 of shape \p Shape0 and \p Shape1, then use
  /// them on both sides of \p Operation.
  Instruction *distributeTransposes(
      Value *Op0, ShapeInfo Shape0, Value *Op1, ShapeInfo Shape1,
      MatrixBuilder &Builder,
      function_ref<Instruction *(Value *, ShapeInfo, Value *, ShapeInfo)>
          Operation) {
    Value *T0 = Builder.CreateMatrixTranspose(
        Op0, Shape0.NumRows, Shape0.NumColumns, Op0->getName() + "_t");
    // We are being run after shape prop, add shape for newly created
    // instructions so that we lower them later.
    setShapeInfo(T0, Shape0.t());
    Value *T1 = Builder.CreateMatrixTranspose(
        Op1, Shape1.NumRows, Shape1.NumColumns, Op1->getName() + "_t");
    setShapeInfo(T1, Shape1.t());
    return Operation(T0, Shape0.t(), T1, Shape1.t());
  }

  void updateShapeAndReplaceAllUsesWith(Instruction &Old, Value *New) {
    // We need to remove Old from the ShapeMap otherwise RAUW will replace it
    // with New. We should only add New if it supportsShapeInfo so we insert
    // it conditionally instead.
    auto S = ShapeMap.find(&Old);
    if (S != ShapeMap.end()) {
      ShapeMap.erase(S);
      if (supportsShapeInfo(New))
        ShapeMap.insert({New, S->second});
    }
    Old.replaceAllUsesWith(New);
  }

  /// Sink a top-level transpose inside matmuls and adds.
  /// This creates and erases instructions as needed, and returns the newly
  /// created instruction while updating the iterator to avoid invalidation. If
  /// this returns nullptr, no new instruction was created.
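  ///
  /// As an illustrative sketch (value names and shapes are made up), with A a
  /// 2x4 matrix in <8 x double> and B a 4x3 matrix in <12 x double>:
  ///   %ab  = call <6 x double> @llvm.matrix.multiply.v6f64.v8f64.v12f64(
  ///              <8 x double> %a, <12 x double> %b, i32 2, i32 4, i32 3)
  ///   %abt = call <6 x double> @llvm.matrix.transpose.v6f64(
  ///              <6 x double> %ab, i32 2, i32 3)
  /// is rewritten into transposes of the operands feeding a swapped multiply:
  ///   %bt  = call <12 x double> @llvm.matrix.transpose.v12f64(
  ///              <12 x double> %b, i32 4, i32 3)
  ///   %at  = call <8 x double> @llvm.matrix.transpose.v8f64(
  ///              <8 x double> %a, i32 2, i32 4)
  ///   %abt = call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(
  ///              <12 x double> %bt, <8 x double> %at, i32 3, i32 4, i32 2)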
  Instruction *sinkTranspose(Instruction &I, BasicBlock::reverse_iterator &II) {
    BasicBlock &BB = *I.getParent();
    IRBuilder<> IB(&I);
    MatrixBuilder Builder(IB);

    Value *TA, *TAMA, *TAMB;
    ConstantInt *R, *K, *C;
    if (!match(&I, m_Intrinsic<Intrinsic::matrix_transpose>(
                       m_Value(TA), m_ConstantInt(R), m_ConstantInt(C))))
      return nullptr;

    // Transpose of a transpose is a nop.
    Value *TATA;
    if (match(TA, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(TATA)))) {
      updateShapeAndReplaceAllUsesWith(I, TATA);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return nullptr;
    }

    // k^T -> k
    if (isSplat(TA)) {
      updateShapeAndReplaceAllUsesWith(I, TA);
      eraseFromParentAndMove(&I, II, BB);
      return nullptr;
    }

    // (A * B)^t -> B^t * A^t
    // RxK KxC      CxK KxR
    if (match(TA, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(TAMA), m_Value(TAMB), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C)))) {
      auto NewInst = distributeTransposes(
          TAMB, {K, C}, TAMA, {R, K}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            return Builder.CreateMatrixMultiply(T0, T1, Shape0.NumRows,
                                                Shape0.NumColumns,
                                                Shape1.NumColumns, "mmul");
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // Same as above, but with a mul, which occurs when multiplied
    // with a scalar.
    // (A * k)^t -> A^t * k
    //  R  x  C      RxC
    if (match(TA, m_AnyMul(m_Value(TAMA), m_Value(TAMB))) &&
        (isSplat(TAMA) || isSplat(TAMB))) {
      IRBuilder<> LocalBuilder(&I);
      // We know that the transposed operand is of shape RxC.
      // And when multiplied with a scalar, the shape is preserved.
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Mul = IsFP ? LocalBuilder.CreateFMul(T0, T1, "mmul")
                             : LocalBuilder.CreateMul(T0, T1, "mmul");
            auto *Result = cast<Instruction>(Mul);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    // (A + B)^t -> A^t + B^t
    //  RxC  RxC     CxR   CxR
    if (match(TA, m_AnyAdd(m_Value(TAMA), m_Value(TAMB)))) {
      IRBuilder<> LocalBuilder(&I);
      auto NewInst = distributeTransposes(
          TAMA, {R, C}, TAMB, {R, C}, Builder,
          [&](Value *T0, ShapeInfo Shape0, Value *T1, ShapeInfo Shape1) {
            bool IsFP = I.getType()->isFPOrFPVectorTy();
            auto *Add = IsFP ? LocalBuilder.CreateFAdd(T0, T1, "madd")
                             : LocalBuilder.CreateAdd(T0, T1, "madd");

            auto *Result = cast<Instruction>(Add);
            setShapeInfo(Result, Shape0);
            return Result;
          });
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      eraseFromParentAndMove(&I, II, BB);
      eraseFromParentAndMove(TA, II, BB);
      return NewInst;
    }

    return nullptr;
  }

  void liftTranspose(Instruction &I) {
    // Erase dead Instructions after lifting transposes from binops.
    auto CleanupBinOp = [](Instruction &T, Value *A, Value *B) {
      if (T.use_empty())
        T.eraseFromParent();
      if (A->use_empty())
        cast<Instruction>(A)->eraseFromParent();
      if (A != B && B->use_empty())
        cast<Instruction>(B)->eraseFromParent();
    };

    Value *A, *B, *AT, *BT;
    ConstantInt *R, *K, *C;
    // A^t * B^t -> (B * A)^t
    if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>(
                      m_Value(A), m_Value(B), m_ConstantInt(R),
                      m_ConstantInt(K), m_ConstantInt(C))) &&
        match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(AT))) &&
        match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(BT)))) {
      IRBuilder<> IB(&I);
      MatrixBuilder Builder(IB);
      Value *M = Builder.CreateMatrixMultiply(
          BT, AT, C->getZExtValue(), K->getZExtValue(), R->getZExtValue());
      setShapeInfo(M, {C, R});
      Instruction *NewInst = Builder.CreateMatrixTranspose(M, C->getZExtValue(),
                                                           R->getZExtValue());
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      CleanupBinOp(I, A, B);
    }
    // A^t + B^t -> (A + B)^t. Pick rows and columns from the first transpose.
    // If the shape of the second transpose is different, there's a shape
    // conflict which gets resolved by picking the shape of the first operand.
    else if (match(&I, m_FAdd(m_Value(A), m_Value(B))) &&
             match(A, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(AT), m_ConstantInt(R), m_ConstantInt(C))) &&
             match(B, m_Intrinsic<Intrinsic::matrix_transpose>(
                          m_Value(BT), m_ConstantInt(), m_ConstantInt()))) {
      IRBuilder<> Builder(&I);
      auto *Add = cast<Instruction>(Builder.CreateFAdd(AT, BT, "mfadd"));
      setShapeInfo(Add, {R, C});
      MatrixBuilder MBuilder(Builder);
      Instruction *NewInst = MBuilder.CreateMatrixTranspose(
          Add, R->getZExtValue(), C->getZExtValue(), "mfadd_t");
      updateShapeAndReplaceAllUsesWith(I, NewInst);
      assert(computeShapeInfoForInst(NewInst, ShapeMap) ==
                 computeShapeInfoForInst(&I, ShapeMap) &&
             "Shape of new instruction doesn't match original shape.");
      CleanupBinOp(I, A, B);
      assert(computeShapeInfoForInst(Add, ShapeMap).value_or(ShapeMap[Add]) ==
                 ShapeMap[Add] &&
             "Shape of updated addition doesn't match cached shape.");
    }
  }

  /// Try moving transposes in order to fold them away or into multiplies.
  void optimizeTransposes() {
    // First sink all transposes inside matmuls and adds, hoping that we end up
    // with NN, NT or TN variants.
    for (BasicBlock &BB : reverse(Func)) {
      for (auto II = BB.rbegin(); II != BB.rend();) {
        Instruction &I = *II;
        // We may remove II. By default continue on the next/prev instruction.
        ++II;
        if (Instruction *NewInst = sinkTranspose(I, II))
          II = std::next(BasicBlock::reverse_iterator(NewInst));
      }
    }

    // If we have a TT matmul or a TT add, lift the transpose. We may be able
    // to fold into consuming multiply or add.
    for (BasicBlock &BB : Func) {
      for (Instruction &I : llvm::make_early_inc_range(BB)) {
        liftTranspose(I);
      }
    }
  }

  bool Visit() {
    SmallVector<Instruction *, 32> WorkList;

    // Initially only the shape of matrix intrinsics is known.
    // Initialize the work list with ops carrying shape information.
    for (BasicBlock &BB : Func)
      for (Instruction &Inst : BB) {
        IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
        if (!II)
          continue;

        switch (II->getIntrinsicID()) {
        case Intrinsic::matrix_multiply:
        case Intrinsic::matrix_transpose:
        case Intrinsic::matrix_column_major_load:
        case Intrinsic::matrix_column_major_store:
          WorkList.push_back(&Inst);
          break;
        default:
          break;
        }
      }

    // Avoid unnecessary work if there are no matrix intrinsics in the
    // function.
    if (WorkList.empty())
      return false;

    // Propagate shapes until nothing changes any longer.
    while (!WorkList.empty()) {
      WorkList = propagateShapeForward(WorkList);
      WorkList = propagateShapeBackward(WorkList);
    }

    if (!isMinimal()) {
      optimizeTransposes();
      if (PrintAfterTransposeOpt) {
        dbgs() << "Dump after matrix transpose optimization:\n";
        Func.print(dbgs());
      }
    }

    bool Changed = false;
    SmallVector<CallInst *, 16> MaybeFusableInsts;
    SmallVector<Instruction *, 16> MatrixInsts;

    // First, collect all instructions with shape information and candidates
    // for fusion (currently only matrix multiplies).
    ReversePostOrderTraversal<Function *> RPOT(&Func);
    for (auto *BB : RPOT)
      for (Instruction &I : *BB) {
        if (ShapeMap.find(&I) == ShapeMap.end())
          continue;
        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
          MaybeFusableInsts.push_back(cast<CallInst>(&I));
        MatrixInsts.push_back(&I);
      }

    // Second, try to lower any dot products.
    SmallPtrSet<Instruction *, 16> FusedInsts;
    for (CallInst *CI : MaybeFusableInsts)
      lowerDotProduct(CI, FusedInsts, getFastMathFlags(CI));

    // Third, try to fuse candidates.
    for (CallInst *CI : MaybeFusableInsts)
      LowerMatrixMultiplyFused(CI, FusedInsts);

    Changed = !FusedInsts.empty();

    // Fourth, lower remaining instructions with shape information.
    for (Instruction *Inst : MatrixInsts) {
      if (FusedInsts.count(Inst))
        continue;

      IRBuilder<> Builder(Inst);

      if (CallInst *CInst = dyn_cast<CallInst>(Inst))
        Changed |= VisitCallInst(CInst);

      Value *Op1;
      Value *Op2;
      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
        Changed |= VisitBinaryOperator(BinOp);
      if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
        Changed |= VisitUnaryOperator(UnOp);
      if (match(Inst, m_Load(m_Value(Op1))))
        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
    }

    if (ORE) {
      RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
      RemarkGen.emitRemarks();
    }

    // Delete the instructions backwards, as it has a reduced likelihood of
    // having to update as many def-use and use-def chains.
    //
    // Because we add to ToRemove during fusion we can't guarantee that defs
    // are before uses. Change uses to poison temporarily as these should get
    // removed as well.
    //
    // For verification, we keep track of where we changed uses to poison in
    // PoisonedInsts and then check that we in fact remove them.
    SmallSet<Instruction *, 16> PoisonedInsts;
    for (auto *Inst : reverse(ToRemove)) {
      for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
        if (auto *Poisoned = dyn_cast<Instruction>(U.getUser()))
          PoisonedInsts.insert(Poisoned);
        U.set(PoisonValue::get(Inst->getType()));
      }
      Inst->eraseFromParent();
      PoisonedInsts.erase(Inst);
    }
    if (!PoisonedInsts.empty()) {
      // If we didn't remove all poisoned instructions, it's a hard error.
      dbgs() << "Poisoned but present instructions:\n";
      for (auto *I : PoisonedInsts)
        dbgs() << *I << "\n";
      llvm_unreachable("Poisoned but instruction not removed");
    }

    return Changed;
  }

  /// Replace intrinsic calls.
  bool VisitCallInst(CallInst *Inst) {
    if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
      return false;

    switch (Inst->getCalledFunction()->getIntrinsicID()) {
    case Intrinsic::matrix_multiply:
      LowerMultiply(Inst);
      break;
    case Intrinsic::matrix_transpose:
      LowerTranspose(Inst);
      break;
    case Intrinsic::matrix_column_major_load:
      LowerColumnMajorLoad(Inst);
      break;
    case Intrinsic::matrix_column_major_store:
      LowerColumnMajorStore(Inst);
      break;
    default:
      return false;
    }
    return true;
  }

  /// Compute the alignment for a column/row \p Idx with \p Stride between
  /// them. The address at \p Idx == 0 has alignment \p A. If \p Stride is a
  /// ConstantInt, reduce the initial alignment based on the byte offset. For
  /// non-ConstantInt strides, return the common alignment of the initial
  /// alignment and the element size in bytes.
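  ///
  /// For example (illustrative numbers): for columns of double (8 bytes) with
  /// a constant stride of 3 elements and an initial alignment of 16, the
  /// column at \p Idx == 1 starts at byte offset 24, so its alignment is
  /// commonAlignment(16, 24) = 8.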
  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
                         MaybeAlign A) const {
    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
    if (Idx == 0)
      return InitialAlign;

    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
      uint64_t StrideInBytes =
          ConstStride->getZExtValue() * ElementSizeInBits / 8;
      return commonAlignment(InitialAlign, Idx * StrideInBytes);
    }
    return commonAlignment(InitialAlign, ElementSizeInBits / 8);
  }

  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride
  /// between vectors.
  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
    auto *VType = cast<VectorType>(Ty);
    Type *EltTy = VType->getElementType();
    Type *VecTy = FixedVectorType::get(EltTy, Shape.getStride());
    Value *EltPtr = Ptr;
    MatrixTy Result;
    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
      Value *GEP = computeVectorAddr(
          EltPtr, Builder.getIntN(Stride->getType()->getScalarSizeInBits(), I),
          Stride, Shape.getStride(), EltTy, Builder);
      Value *Vector = Builder.CreateAlignedLoad(
          VecTy, GEP, getAlignForIndex(I, Stride, EltTy, MAlign),
          IsVolatile, "col.load");

      Result.addVector(Vector);
    }
    return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
                              Result.getNumVectors());
  }

  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
  /// starting at \p MatrixPtr[I][J].
  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
                      ShapeInfo MatrixShape, Value *I, Value *J,
                      ShapeInfo ResultShape, Type *EltTy,
                      IRBuilder<> &Builder) {

    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
                                                   ResultShape.NumColumns);

    return loadMatrix(TileTy, TileStart, Align,
                      Builder.getInt64(MatrixShape.getStride()), IsVolatile,
                      ResultShape, Builder);
  }

  /// Lower a load instruction with shape information.
  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
                 bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    finalizeLowering(Inst,
                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
                                Shape, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.load.
  ///
  /// The intrinsic loads a matrix from memory using a stride between columns.
  void LowerColumnMajorLoad(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Ptr = Inst->getArgOperand(0);
    Value *Stride = Inst->getArgOperand(1);
    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
              {Inst->getArgOperand(3), Inst->getArgOperand(4)});
  }

  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at
  /// \p MatrixPtr[I][J].
  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
    Value *Offset = Builder.CreateAdd(
        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);

    Value *TileStart = Builder.CreateGEP(EltTy, MatrixPtr, Offset);
    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
                                                   StoreVal.getNumColumns());

    storeMatrix(TileTy, StoreVal, TileStart, MAlign,
                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
  }

  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
  /// vectors.
  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
                       MaybeAlign MAlign, Value *Stride, bool IsVolatile,
                       IRBuilder<> &Builder) {
    auto VType = cast<VectorType>(Ty);
    Value *EltPtr = Ptr;
    for (auto Vec : enumerate(StoreVal.vectors())) {
      Value *GEP = computeVectorAddr(
          EltPtr,
          Builder.getIntN(Stride->getType()->getScalarSizeInBits(),
                          Vec.index()),
          Stride, StoreVal.getStride(), VType->getElementType(), Builder);
      Builder.CreateAlignedStore(Vec.value(), GEP,
                                 getAlignForIndex(Vec.index(), Stride,
                                                  VType->getElementType(),
                                                  MAlign),
                                 IsVolatile);
    }
    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
                                   StoreVal.getNumVectors());
  }

  /// Lower a store instruction with shape information.
  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
                  Value *Stride, bool IsVolatile, ShapeInfo Shape) {
    IRBuilder<> Builder(Inst);
    auto StoreVal = getMatrix(Matrix, Shape, Builder);
    finalizeLowering(Inst,
                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
                                 IsVolatile, Builder),
                     Builder);
  }

  /// Lowers llvm.matrix.column.major.store.
  ///
  /// The intrinsic stores a matrix back to memory using a stride between
  /// columns.
  void LowerColumnMajorStore(CallInst *Inst) {
    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
           "Intrinsic only supports column-major layout!");
    Value *Matrix = Inst->getArgOperand(0);
    Value *Ptr = Inst->getArgOperand(1);
    Value *Stride = Inst->getArgOperand(2);
    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
               {Inst->getArgOperand(4), Inst->getArgOperand(5)});
  }

  // Set elements I..I+NumElts-1 to Block.
  Value *insertVector(Value *Col, unsigned I, Value *Block,
                      IRBuilder<> &Builder) {

    // First, bring Block to the same size as Col.
    unsigned BlockNumElts =
        cast<FixedVectorType>(Block->getType())->getNumElements();
    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
    assert(NumElts >= BlockNumElts && "Too few elements for current block");

    Block = Builder.CreateShuffleVector(
        Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));

    // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
    // 8, 4, 5, 6
    SmallVector<int, 16> Mask;
    unsigned i;
    for (i = 0; i < I; i++)
      Mask.push_back(i);

    unsigned VecNumElts =
        cast<FixedVectorType>(Col->getType())->getNumElements();
    for (; i < I + BlockNumElts; i++)
      Mask.push_back(i - I + VecNumElts);

    for (; i < VecNumElts; i++)
      Mask.push_back(i);

    return Builder.CreateShuffleVector(Col, Block, Mask);
  }

  Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
                      IRBuilder<> &Builder, bool AllowContraction,
                      unsigned &NumComputeOps) {
    NumComputeOps += getNumOps(A->getType());
    if (!Sum)
      return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);

    if (UseFPOp) {
      if (AllowContraction) {
        // Use fmuladd for floating point operations and let the backend decide
        // if that's profitable.
        Function *FMulAdd = Intrinsic::getDeclaration(
            Func.getParent(), Intrinsic::fmuladd, A->getType());
        return Builder.CreateCall(FMulAdd, {A, B, Sum});
      }
      NumComputeOps += getNumOps(A->getType());
      Value *Mul = Builder.CreateFMul(A, B);
      return Builder.CreateFAdd(Sum, Mul);
    }

    NumComputeOps += getNumOps(A->getType());
    Value *Mul = Builder.CreateMul(A, B);
    return Builder.CreateAdd(Sum, Mul);
  }

  /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
  /// users with shape information, there's nothing to do: they will use the
  /// cached value when they are lowered. For other users, \p Matrix is
  /// flattened and the uses are updated to use it. Also marks \p Inst for
  /// deletion.
  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
                        IRBuilder<> &Builder) {
    auto inserted = Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
    (void)inserted;
    assert(inserted.second && "multiple matrix lowering mapping");

    ToRemove.push_back(Inst);
    Value *Flattened = nullptr;
    for (Use &U : llvm::make_early_inc_range(Inst->uses())) {
      if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
        if (!Flattened)
          Flattened = Matrix.embedInVector(Builder);
        U.set(Flattened);
      }
    }
  }

  /// Special case for MatMul lowering. Prevents scalar loads of row-major
  /// vectors. Lowers to a vector reduction add instead of sequential adds if
  /// reassociation is enabled.
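  ///
  /// As an illustrative sketch (value names are made up), a 1x3 by 3x1
  /// multiply
  ///   %dot = call <1 x double> @llvm.matrix.multiply.v1f64.v3f64.v3f64(
  ///              <3 x double> %a, <3 x double> %b, i32 1, i32 3, i32 1)
  /// is lowered to an element-wise multiply followed by a reduction:
  ///   %mul = fmul <3 x double> %a, %b
  ///   %red = call reassoc double @llvm.vector.reduce.fadd.v3f64(
  ///              double 0.0, <3 x double> %mul)
  ///   %dot = insertelement <1 x double> poison, double %red, i64 0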
1328 void lowerDotProduct(CallInst *MatMul,
1329 SmallPtrSet<Instruction *, 16> &FusedInsts,
1330 FastMathFlags FMF) {
1331 if (FusedInsts.contains(Ptr: MatMul) ||
1332 MatrixLayout != MatrixLayoutTy::ColumnMajor)
1333 return;
1334 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1335 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1336
1337 if (LShape.NumRows != 1 || RShape.NumColumns != 1) // not a dot product
1338 return;
1339
1340 Value *LHS = MatMul->getArgOperand(i: 0);
1341 Value *RHS = MatMul->getArgOperand(i: 1);
1342
1343 Type *ElementType = cast<VectorType>(Val: LHS->getType())->getElementType();
1344 bool IsIntVec = ElementType->isIntegerTy();
1345
1346 // Floating point reductions require reassocation.
1347 if (!IsIntVec && !FMF.allowReassoc())
1348 return;
1349
1350 auto CanBeFlattened = [](Value *Op) {
1351 if (match(V: Op, P: m_BinOp()))
1352 return true;
1353 return match(
1354 Op, m_OneUse(m_CombineOr(
1355 m_Load(m_Value()),
1356 m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1357 m_Intrinsic<Intrinsic::matrix_column_major_load>(
1358 m_Value(), m_SpecificInt(1))))));
1359 };
1360 // Returns the cost benefit of using \p Op with the dot product lowering. If
1361 // the returned cost is < 0, the argument is cheaper to use in the
1362 // dot-product lowering.
1363 auto GetCostForArg = [this, &CanBeFlattened](Value *Op, unsigned N) {
1364 if (ShapeMap.find(Val: Op) == ShapeMap.end())
1365 return InstructionCost::getInvalid();
1366
1367 if (!isa<Instruction>(Val: Op))
1368 return InstructionCost(0);
1369
1370 FixedVectorType *VecTy = cast<FixedVectorType>(Val: Op->getType());
1371 Type *EltTy = VecTy->getElementType();
1372
1373 if (!CanBeFlattened(Op)) {
1374 InstructionCost EmbedCost(0);
1375 // Roughly estimate the cost for embedding the columns into a vector.
1376 for (unsigned I = 1; I < N; ++I)
1377 EmbedCost +=
1378 TTI.getShuffleCost(Kind: TTI::SK_Splice, Tp: FixedVectorType::get(ElementType: EltTy, NumElts: 1),
1379 Mask: std::nullopt, CostKind: TTI::TCK_RecipThroughput);
1380 return EmbedCost;
1381 }
1382
1383 if (match(V: Op, P: m_BinOp()) && ShapeMap.find(Val: Op) != ShapeMap.end()) {
1384 InstructionCost OriginalCost =
1385 TTI.getArithmeticInstrCost(Opcode: cast<Instruction>(Val: Op)->getOpcode(),
1386 Ty: EltTy) *
1387 N;
1388 InstructionCost NewCost = TTI.getArithmeticInstrCost(
1389 Opcode: cast<Instruction>(Val: Op)->getOpcode(), Ty: VecTy);
1390 return NewCost - OriginalCost;
1391 }
1392
1393 if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1394 // The transpose can be skipped for the dot product lowering, roughly
1395 // estimate the savings as the cost of embedding the columns in a
1396 // vector.
1397 InstructionCost EmbedCost(0);
1398 for (unsigned I = 1; I < N; ++I)
1399 EmbedCost -=
1400 TTI.getShuffleCost(Kind: TTI::SK_Splice, Tp: FixedVectorType::get(ElementType: EltTy, NumElts: 1),
1401 Mask: std::nullopt, CostKind: TTI::TCK_RecipThroughput);
1402 return EmbedCost;
1403 }
1404
1405 // Costs for loads.
1406 if (N == 1)
1407 return InstructionCost(0);
1408
1409 return TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: Align(1), AddressSpace: 0) -
1410 N * TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: EltTy, Alignment: Align(1), AddressSpace: 0);
1411 };
1412
1413 // Iterate over LHS and operations feeding LHS and check if it is profitable
1414 // to flatten the visited ops. For each op, we compute the difference
1415 // between the flattened and matrix versions.
1416 SmallPtrSet<Value *, 4> Seen;
1417 SmallVector<Value *> WorkList;
1418 SmallVector<Value *> ToFlatten;
1419 WorkList.push_back(Elt: LHS);
1420 InstructionCost LHSCost(0);
1421 while (!WorkList.empty()) {
1422 Value *Op = WorkList.pop_back_val();
1423 if (!Seen.insert(Ptr: Op).second)
1424 continue;
1425
1426 InstructionCost OpCost = GetCostForArg(Op, LShape.NumColumns);
1427 if (OpCost + LHSCost >= LHSCost)
1428 continue;
1429
1430 LHSCost += OpCost;
1431 ToFlatten.push_back(Elt: Op);
1432 if (auto *I = dyn_cast<Instruction>(Val: Op))
1433 WorkList.append(in_start: I->op_begin(), in_end: I->op_end());
1434 }
1435
1436 // We compare the costs of a vector.reduce.add to sequential add.
1437 int AddOpCode = IsIntVec ? Instruction::Add : Instruction::FAdd;
1438 int MulOpCode = IsIntVec ? Instruction::Mul : Instruction::FMul;
1439 InstructionCost ReductionCost =
1440 TTI.getArithmeticReductionCost(
1441 Opcode: AddOpCode, Ty: cast<VectorType>(Val: LHS->getType()),
1442 FMF: IsIntVec ? std::nullopt : std::optional(FMF)) +
1443 TTI.getArithmeticInstrCost(Opcode: MulOpCode, Ty: LHS->getType());
1444 InstructionCost SequentialAddCost =
1445 TTI.getArithmeticInstrCost(Opcode: AddOpCode, Ty: ElementType) *
1446 (LShape.NumColumns - 1) +
1447 TTI.getArithmeticInstrCost(Opcode: MulOpCode, Ty: ElementType) *
1448 (LShape.NumColumns);
1449 if ((LHSCost + ReductionCost - SequentialAddCost) > InstructionCost(0))
1450 return;
1451
1452 FusedInsts.insert(Ptr: MatMul);
1453 IRBuilder<> Builder(MatMul);
1454 auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1455 this](Value *Op) {
1456 // Matmul must be the only user of loads because we don't use LowerLoad
1457 // for row vectors (LowerLoad results in scalar loads and shufflevectors
1458 // instead of a single vector load).
1459 if (!CanBeFlattened(Op))
1460 return;
1461
1462 if (match(V: Op, P: m_BinOp()) && ShapeMap.find(Val: Op) != ShapeMap.end()) {
1463 ShapeMap[Op] = ShapeMap[Op].t();
1464 return;
1465 }
1466
1467 FusedInsts.insert(Ptr: cast<Instruction>(Val: Op));
1468 // If the operand uses the column-major-load builtin, lower it to a LoadInst.
1469 Value *Arg;
1470 if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1471 m_Value(Arg)))) {
1472 auto *NewLoad = Builder.CreateLoad(Ty: Op->getType(), Ptr: Arg);
1473 Op->replaceAllUsesWith(V: NewLoad);
1474 cast<Instruction>(Val: Op)->eraseFromParent();
1475 return;
1476 } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1477 m_Value(Arg)))) {
1478 ToRemove.push_back(Elt: cast<Instruction>(Val: Op));
1479 Op->replaceAllUsesWith(V: Arg);
1480 return;
1481 }
1482 };
1483
1484 for (auto *V : ToFlatten)
1485 FlattenArg(V);
1486
1487 LHS = MatMul->getArgOperand(i: 0);
1488
1489 // Insert mul/fmul and llvm.vector.reduce.fadd
1490 Value *Mul =
1491 IsIntVec ? Builder.CreateMul(LHS, RHS) : Builder.CreateFMul(L: LHS, R: RHS);
1492
1493 Value *Result;
1494 if (IsIntVec)
1495 Result = Builder.CreateAddReduce(Src: Mul);
1496 else {
1497 Result = Builder.CreateFAddReduce(
1498 Acc: ConstantFP::get(Ty: cast<VectorType>(Val: LHS->getType())->getElementType(),
1499 V: 0.0),
1500 Src: Mul);
1501 cast<Instruction>(Val: Result)->setFastMathFlags(FMF);
1502 }
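// For example (illustrative IR, value names invented), a 1x4 * 4x1 float
// multiply lowers roughly to:
//   %mul = fmul <4 x float> %lhs, %rhs
//   %dot = call float @llvm.vector.reduce.fadd.v4f32(float 0.0,
//                                                    <4 x float> %mul)
//   %res = insertelement <1 x float> poison, float %dot, i64 0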
1503
1504 // Pack the scalar result back into a matrix and then replace the matmul inst.
1505 Result = Builder.CreateInsertElement(Vec: PoisonValue::get(T: MatMul->getType()),
1506 NewElt: Result, Idx: uint64_t(0));
1507 MatMul->replaceAllUsesWith(V: Result);
1508 FusedInsts.insert(Ptr: MatMul);
1509 ToRemove.push_back(Elt: MatMul);
1510 }
1511
1512 /// Compute \p Result += \p A * \p B for input matrices with left-associating
1513 /// addition.
1514 ///
1515 /// We can fold a transpose into the operand that is used to extract scalars.
1516 /// This is the first operand for row-major layouts and the second for
1517 /// column-major. If \p IsScalarMatrixTransposed we assume the appropriate
1518 /// operand is transposed.
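/// Illustrative sketch (column-major case, not the exact emitted IR): for a
/// 2x2 * 2x2 multiply, each result column J is built as
///   Sum = A.col(0) * splat(B[0][J])
///   Sum = mul-add(A.col(1), splat(B[1][J]), Sum)
/// i.e. columns of A are scaled by scalars extracted from B and accumulated
/// along K; the mul-add becomes an fmuladd when contraction is allowed.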
1519 void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
1520 const MatrixTy &B, IRBuilder<> &Builder, bool IsTiled,
1521 bool IsScalarMatrixTransposed, FastMathFlags FMF) {
1522 const unsigned VF = std::max<unsigned>(
1523 a: TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector)
1524 .getFixedValue() /
1525 Result.getElementType()->getPrimitiveSizeInBits().getFixedValue(),
1526 b: 1U);
1527 unsigned R = Result.getNumRows();
1528 unsigned C = Result.getNumColumns();
1529 unsigned M = A.getNumColumns();
1530
1531 bool IsFP = Result.getElementType()->isFloatingPointTy();
1532 assert(A.isColumnMajor() == B.isColumnMajor() &&
1533 Result.isColumnMajor() == A.isColumnMajor() &&
1534 "operands must agree on matrix layout");
1535 unsigned NumComputeOps = 0;
1536
1537 Builder.setFastMathFlags(FMF);
1538
1539 if (A.isColumnMajor()) {
1540 // Multiply columns from the first operand with scalars from the second
1541 // operand. Then move along the K axis and accumulate the columns. With
1542 // this the adds can be vectorized without reassociation.
1543 for (unsigned J = 0; J < C; ++J) {
1544 unsigned BlockSize = VF;
1545 // If Result is zero, we don't need to accumulate in the K==0 iteration.
1546 bool isSumZero = isa<ConstantAggregateZero>(Val: Result.getColumn(i: J));
1547
1548 for (unsigned I = 0; I < R; I += BlockSize) {
1549 // Gradually lower the vectorization factor to cover the remainder.
1550 while (I + BlockSize > R)
1551 BlockSize /= 2;
1552
1553 Value *Sum = IsTiled ? Result.extractVector(I, J, NumElts: BlockSize, Builder)
1554 : nullptr;
1555 for (unsigned K = 0; K < M; ++K) {
1556 Value *L = A.extractVector(I, J: K, NumElts: BlockSize, Builder);
1557 Value *RH = Builder.CreateExtractElement(
1558 Vec: B.getColumn(i: IsScalarMatrixTransposed ? K : J),
1559 Idx: IsScalarMatrixTransposed ? J : K);
1560 Value *Splat = Builder.CreateVectorSplat(NumElts: BlockSize, V: RH, Name: "splat");
1561 Sum =
1562 createMulAdd(Sum: isSumZero && K == 0 ? nullptr : Sum, A: L, B: Splat,
1563 UseFPOp: IsFP, Builder, AllowContraction: FMF.allowContract(), NumComputeOps);
1564 }
1565 Result.setVector(i: J,
1566 V: insertVector(Col: Result.getVector(i: J), I, Block: Sum, Builder));
1567 }
1568 }
1569 } else {
1570 // Multiply rows from the second operand with scalars from the first
1571 // operand. Then move along the K axis and accumulate the rows. With this
1572 // the adds can be vectorized without reassociation.
1573 for (unsigned I = 0; I < R; ++I) {
1574 unsigned BlockSize = VF;
1575 bool isSumZero = isa<ConstantAggregateZero>(Val: Result.getRow(i: I));
1576 for (unsigned J = 0; J < C; J += BlockSize) {
1577 // Gradually lower the vectorization factor to cover the remainder.
1578 while (J + BlockSize > C)
1579 BlockSize /= 2;
1580
1581 Value *Sum = nullptr;
1582 for (unsigned K = 0; K < M; ++K) {
1583 Value *R = B.extractVector(I: K, J, NumElts: BlockSize, Builder);
1584 Value *LH = Builder.CreateExtractElement(
1585 Vec: A.getVector(i: IsScalarMatrixTransposed ? K : I),
1586 Idx: IsScalarMatrixTransposed ? I : K);
1587 Value *Splat = Builder.CreateVectorSplat(NumElts: BlockSize, V: LH, Name: "splat");
1588 Sum =
1589 createMulAdd(Sum: isSumZero && K == 0 ? nullptr : Sum, A: Splat, B: R,
1590 UseFPOp: IsFP, Builder, AllowContraction: FMF.allowContract(), NumComputeOps);
1591 }
1592 Result.setVector(i: I,
1593 V: insertVector(Col: Result.getVector(i: I), I: J, Block: Sum, Builder));
1594 }
1595 }
1596 }
1597 Result.addNumComputeOps(N: NumComputeOps);
1598 }
1599
1600 /// Ensure that the memory in \p Load does not alias \p Store by potentially
1601 /// copying it to a new location. The new location is returned if a copy was
1602 /// made; otherwise the original location is returned.
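/// Illustrative sketch of the emitted control flow (using the names from the
/// code below):
///   check0:     br (load.begin < store.end), %alias_cont, %no_alias
///   alias_cont: br (store.begin < load.end), %copy, %no_alias
///   copy:       %tmp = alloca; memcpy(%tmp, load ptr); br %no_alias
///   no_alias:   phi [load ptr, %check0], [load ptr, %alias_cont],
///                   [%tmp, %copy]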
1603 Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
1604 CallInst *MatMul) {
1605 MemoryLocation StoreLoc = MemoryLocation::get(SI: Store);
1606 MemoryLocation LoadLoc = MemoryLocation::get(LI: Load);
1607
1608 // If we can statically determine noalias we're good.
1609 if (AA->isNoAlias(LocA: LoadLoc, LocB: StoreLoc))
1610 return Load->getPointerOperand();
1611
1612 // Create code to check if the memory locations of the Load and Store
1613 // overlap and if they do, copy Load's operand to a new buffer.
1614
1615 // First, create new blocks for the 2nd part of the check and the copy.
1616 BasicBlock *Check0 = MatMul->getParent();
1617 // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
1618 // DT. Manually collect dominator tree updates, to avoid unnecessary work,
1619 // as we adjust Check0 and Check1's branches.
1620 SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
1621 for (BasicBlock *Succ : successors(BB: Check0))
1622 DTUpdates.push_back(Elt: {DT->Delete, Check0, Succ});
1623
1624 BasicBlock *Check1 =
1625 SplitBlock(Old: MatMul->getParent(), SplitPt: MatMul, DTU: (DomTreeUpdater *)nullptr, LI,
1626 MSSAU: nullptr, BBName: "alias_cont");
1627 BasicBlock *Copy =
1628 SplitBlock(Old: MatMul->getParent(), SplitPt: MatMul, DTU: (DomTreeUpdater *)nullptr, LI,
1629 MSSAU: nullptr, BBName: "copy");
1630 BasicBlock *Fusion =
1631 SplitBlock(Old: MatMul->getParent(), SplitPt: MatMul, DTU: (DomTreeUpdater *)nullptr, LI,
1632 MSSAU: nullptr, BBName: "no_alias");
1633
1634 // Check if the loaded memory location begins before the end of the store
1635 // location. If the condition holds, they might overlap; otherwise they are
1636 // guaranteed not to overlap.
1637 IRBuilder<> Builder(MatMul);
1638 Check0->getTerminator()->eraseFromParent();
1639 Builder.SetInsertPoint(Check0);
1640 Type *IntPtrTy = Builder.getIntPtrTy(DL: Load->getModule()->getDataLayout());
1641 Value *StoreBegin = Builder.CreatePtrToInt(
1642 V: const_cast<Value *>(StoreLoc.Ptr), DestTy: IntPtrTy, Name: "store.begin");
1643 Value *StoreEnd = Builder.CreateAdd(
1644 LHS: StoreBegin, RHS: ConstantInt::get(Ty: IntPtrTy, V: StoreLoc.Size.getValue()),
1645 Name: "store.end", HasNUW: true, HasNSW: true);
1646 Value *LoadBegin = Builder.CreatePtrToInt(V: const_cast<Value *>(LoadLoc.Ptr),
1647 DestTy: IntPtrTy, Name: "load.begin");
1648 Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: LoadBegin, RHS: StoreEnd), True: Check1,
1649 False: Fusion);
1650
1651 // Check if the store begins before the end of the load location. If the
1652 // condition holds, the locations overlap; otherwise they are guaranteed
1653 // not to overlap.
1654 Check1->getTerminator()->eraseFromParent();
1655 Builder.SetInsertPoint(TheBB: Check1, IP: Check1->begin());
1656 Value *LoadEnd = Builder.CreateAdd(
1657 LHS: LoadBegin, RHS: ConstantInt::get(Ty: IntPtrTy, V: LoadLoc.Size.getValue()),
1658 Name: "load.end", HasNUW: true, HasNSW: true);
1659 Builder.CreateCondBr(Cond: Builder.CreateICmpULT(LHS: StoreBegin, RHS: LoadEnd), True: Copy,
1660 False: Fusion);
1661
1662 // Copy load operand to new alloca.
1663 Builder.SetInsertPoint(TheBB: Copy, IP: Copy->begin());
1664 auto *VT = cast<FixedVectorType>(Val: Load->getType());
1665 // Use an array type for the alloca, to avoid potentially huge alignment
1666 // requirements for large vector types.
1667 auto *ArrayTy = ArrayType::get(ElementType: VT->getElementType(), NumElements: VT->getNumElements());
1668 AllocaInst *Alloca =
1669 Builder.CreateAlloca(Ty: ArrayTy, AddrSpace: Load->getPointerAddressSpace());
1670
1671 Builder.CreateMemCpy(Dst: Alloca, DstAlign: Alloca->getAlign(), Src: Load->getPointerOperand(),
1672 SrcAlign: Load->getAlign(), Size: LoadLoc.Size.getValue());
1673 Builder.SetInsertPoint(TheBB: Fusion, IP: Fusion->begin());
1674 PHINode *PHI = Builder.CreatePHI(Ty: Load->getPointerOperandType(), NumReservedValues: 3);
1675 PHI->addIncoming(V: Load->getPointerOperand(), BB: Check0);
1676 PHI->addIncoming(V: Load->getPointerOperand(), BB: Check1);
1677 PHI->addIncoming(V: Alloca, BB: Copy);
1678
1679 // Adjust DT.
1680 DTUpdates.push_back(Elt: {DT->Insert, Check0, Check1});
1681 DTUpdates.push_back(Elt: {DT->Insert, Check0, Fusion});
1682 DTUpdates.push_back(Elt: {DT->Insert, Check1, Copy});
1683 DTUpdates.push_back(Elt: {DT->Insert, Check1, Fusion});
1684 DT->applyUpdates(Updates: DTUpdates);
1685 return PHI;
1686 }
1687
1688 bool isFusionProfitable(CallInst *MatMul) {
1689 if (ForceFusion)
1690 return true;
1691
1692 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1693 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1694
1695 const unsigned R = LShape.NumRows;
1696 const unsigned C = RShape.NumColumns;
1697 const unsigned M = LShape.NumColumns;
1698 auto *EltType = cast<VectorType>(Val: MatMul->getType())->getElementType();
1699
1700 const unsigned VF = std::max<unsigned>(
1701 a: TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector)
1702 .getFixedValue() /
1703 EltType->getPrimitiveSizeInBits().getFixedValue(),
1704 b: 1U);
1705
1706 // Cost model for tiling
1707 //
1708 // For tiling to be beneficial, we need reuse either along the R or
1709 // the C axis. We vectorize along the R axis so that means at least
1710 // 3 elements.
1711 // TODO: Also consider cost of copying if operands alias.
1712 if (R <= VF && C == 1)
1713 return false;
1714 // Then we need enough elements to exceed the number of vector
1715 // registers we have. Note that this is an oversimplification since
1716 // fusing also requires some extra loads, which may exceed the number of
1717 // reloads it saves.
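// Illustrative example (assumed values, not from the source): with
// R = C = M = 16 and VF = 4 (e.g. 128-bit vectors of f32), each operand needs
// ceil(16 / 4) * 16 = 64 vector registers, so fusion is considered profitable
// on targets with fewer than 128 vector registers.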
1718 unsigned Op0Regs = (R + VF - 1) / VF * M;
1719 unsigned Op1Regs = (M + VF - 1) / VF * C;
1720 return Op0Regs + Op1Regs >
1721 TTI.getNumberOfRegisters(ClassID: TTI.getRegisterClassForType(Vector: true));
1722 }
1723
1724 MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
1725 MatrixTy Res;
1726 auto *ColumType = FixedVectorType::get(ElementType: EltType, NumElts: R);
1727 for (unsigned I = 0; I < C; ++I)
1728 Res.addVector(V: ConstantAggregateZero::get(Ty: ColumType));
1729 return Res;
1730 }
1731
1732 void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
1733 Value *RPtr, ShapeInfo RShape, StoreInst *Store) {
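// Sketch of the structure created below (loop order per TileInfo, shown only
// to illustrate the tiling, not the exact IR):
//   for each TileSize x TileSize tile (Row, Col) of the result:
//     TileResult = 0                                  // via PHI nodes
//     for (K = 0; K < LShape.NumColumns; K += TileSize)
//       TileResult += A[Row.., K..] * B[K.., Col..]   // emitMatrixMultiply
//     store TileResult to Store's pointer at (Row, Col)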
1734 auto *EltType = cast<VectorType>(Val: MatMul->getType())->getElementType();
1735
1736 // Create the main tiling loop nest.
1737 TileInfo TI(LShape.NumRows, RShape.NumColumns, LShape.NumColumns, TileSize);
1738 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1739 Instruction *InsertI = cast<Instruction>(Val: MatMul);
1740 BasicBlock *Start = InsertI->getParent();
1741 BasicBlock *End =
1742 SplitBlock(Old: InsertI->getParent(), SplitPt: InsertI, DT, LI, MSSAU: nullptr, BBName: "continue");
1743 IRBuilder<> Builder(MatMul);
1744 BasicBlock *InnerBody = TI.CreateTiledLoops(Start, End, B&: Builder, DTU, LI&: *LI);
1745
1746 Type *TileVecTy =
1747 FixedVectorType::get(ElementType: MatMul->getType()->getScalarType(), NumElts: TileSize);
1748 MatrixTy TileResult;
1749 // Insert in the inner loop header.
1750 Builder.SetInsertPoint(TI.KLoop.Header->getTerminator());
1751 // Create PHI nodes for the result columns to accumulate across iterations.
1752 SmallVector<PHINode *, 4> ColumnPhis;
1753 for (unsigned I = 0; I < TileSize; I++) {
1754 auto *Phi = Builder.CreatePHI(Ty: TileVecTy, NumReservedValues: 2, Name: "result.vec." + Twine(I));
1755 Phi->addIncoming(V: ConstantAggregateZero::get(Ty: TileVecTy),
1756 BB: TI.RowLoop.Header->getSingleSuccessor());
1757 TileResult.addVector(V: Phi);
1758 ColumnPhis.push_back(Elt: Phi);
1759 }
1760
1761 // Insert in the inner loop body, which computes
1762 // Res += Load(CurrentRow, K) * Load(K, CurrentColumn)
1763 Builder.SetInsertPoint(InnerBody->getTerminator());
1764 // Load tiles of the operands.
1765 MatrixTy A =
1766 loadMatrix(MatrixPtr: LPtr, Align: {}, IsVolatile: false, MatrixShape: LShape, I: TI.RowLoop.Index, J: TI.KLoop.Index,
1767 ResultShape: {TileSize, TileSize}, EltTy: EltType, Builder);
1768 MatrixTy B =
1769 loadMatrix(MatrixPtr: RPtr, Align: {}, IsVolatile: false, MatrixShape: RShape, I: TI.KLoop.Index, J: TI.ColumnLoop.Index,
1770 ResultShape: {TileSize, TileSize}, EltTy: EltType, Builder);
1771 emitMatrixMultiply(Result&: TileResult, A, B, Builder, IsTiled: true, IsScalarMatrixTransposed: false,
1772 FMF: getFastMathFlags(Inst: MatMul));
1773 // Store result after the inner loop is done.
1774 Builder.SetInsertPoint(TI.RowLoop.Latch->getTerminator());
1775 storeMatrix(StoreVal: TileResult, MatrixPtr: Store->getPointerOperand(), MAlign: Store->getAlign(),
1776 IsVolatile: Store->isVolatile(), MatrixShape: {LShape.NumRows, RShape.NumColumns},
1777 I: TI.RowLoop.Index, J: TI.ColumnLoop.Index, EltTy: EltType, Builder);
1778
1779 for (unsigned I = 0; I < TileResult.getNumVectors(); I++)
1780 ColumnPhis[I]->addIncoming(V: TileResult.getVector(i: I), BB: TI.KLoop.Latch);
1781
1782 // Force unrolling of a few iterations of the inner loop, to make sure there
1783 // is enough work per iteration.
1784 // FIXME: The unroller should make this decision directly instead, but
1785 // currently the cost-model is not up to the task.
1786 unsigned InnerLoopUnrollCount = std::min(a: 10u, b: LShape.NumColumns / TileSize);
1787 addStringMetadataToLoop(TheLoop: LI->getLoopFor(BB: TI.KLoop.Header),
1788 MDString: "llvm.loop.unroll.count", V: InnerLoopUnrollCount);
1789 }
1790
1791 void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
1792 StoreInst *Store,
1793 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1794 assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
1795 "Tiling only supported for column-major matrixes at the moment!");
1796 if (!isFusionProfitable(MatMul))
1797 return;
1798
1799 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1800 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1801
1802 const unsigned R = LShape.NumRows;
1803 const unsigned C = RShape.NumColumns;
1804 const unsigned M = LShape.NumColumns;
1805 auto *EltType = cast<VectorType>(Val: MatMul->getType())->getElementType();
1806
1807 Value *APtr = getNonAliasingPointer(Load: LoadOp0, Store, MatMul);
1808 Value *BPtr = getNonAliasingPointer(Load: LoadOp1, Store, MatMul);
1809 Value *CPtr = Store->getPointerOperand();
1810
1811 if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
1812 createTiledLoops(MatMul, LPtr: APtr, LShape, RPtr: BPtr, RShape, Store);
1813 else {
1814 IRBuilder<> Builder(Store);
1815 for (unsigned J = 0; J < C; J += TileSize)
1816 for (unsigned I = 0; I < R; I += TileSize) {
1817 const unsigned TileR = std::min(a: R - I, b: unsigned(TileSize));
1818 const unsigned TileC = std::min(a: C - J, b: unsigned(TileSize));
1819 MatrixTy Res = getZeroMatrix(EltType, R: TileR, C: TileC);
1820
1821 for (unsigned K = 0; K < M; K += TileSize) {
1822 const unsigned TileM = std::min(a: M - K, b: unsigned(TileSize));
1823 MatrixTy A =
1824 loadMatrix(MatrixPtr: APtr, Align: LoadOp0->getAlign(), IsVolatile: LoadOp0->isVolatile(),
1825 MatrixShape: LShape, I: Builder.getInt64(C: I), J: Builder.getInt64(C: K),
1826 ResultShape: {TileR, TileM}, EltTy: EltType, Builder);
1827 MatrixTy B =
1828 loadMatrix(MatrixPtr: BPtr, Align: LoadOp1->getAlign(), IsVolatile: LoadOp1->isVolatile(),
1829 MatrixShape: RShape, I: Builder.getInt64(C: K), J: Builder.getInt64(C: J),
1830 ResultShape: {TileM, TileC}, EltTy: EltType, Builder);
1831 emitMatrixMultiply(Result&: Res, A, B, Builder, IsTiled: true, IsScalarMatrixTransposed: false,
1832 FMF: getFastMathFlags(Inst: MatMul));
1833 }
1834 storeMatrix(StoreVal: Res, MatrixPtr: CPtr, MAlign: Store->getAlign(), IsVolatile: Store->isVolatile(), MatrixShape: {R, M},
1835 I: Builder.getInt64(C: I), J: Builder.getInt64(C: J), EltTy: EltType,
1836 Builder);
1837 }
1838 }
1839
1840 // Mark eliminated instructions as fused and remove them.
1841 FusedInsts.insert(Ptr: Store);
1842 FusedInsts.insert(Ptr: MatMul);
1843 Store->eraseFromParent();
1844 MatMul->eraseFromParent();
1845 if (LoadOp0->hasNUses(N: 0)) {
1846 FusedInsts.insert(Ptr: LoadOp0);
1847 LoadOp0->eraseFromParent();
1848 }
1849 if (LoadOp1 != LoadOp0 && LoadOp1->hasNUses(N: 0)) {
1850 FusedInsts.insert(Ptr: LoadOp1);
1851 LoadOp1->eraseFromParent();
1852 }
1853 }
1854
1855 /// Try to lower matrix multiply chains by fusing operations.
1856 ///
1857 /// Call finalizeLowering on lowered instructions. Instructions that are
1858 /// completely eliminated by fusion are added to \p FusedInsts.
1859 void LowerMatrixMultiplyFused(CallInst *MatMul,
1860 SmallPtrSetImpl<Instruction *> &FusedInsts) {
1861 if (!FuseMatrix || !DT)
1862 return;
1863
1864 assert(AA && LI && "Analyses should be available");
1865
1866 Value *A = MatMul->getArgOperand(i: 0);
1867 Value *B = MatMul->getArgOperand(i: 1);
1868
1869 // We can fold the transpose into the operand that is used to fetch scalars.
1870 Value *T;
1871 if (MatrixLayout == MatrixLayoutTy::ColumnMajor
1872 ? match(B, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))
1873 : match(A, m_Intrinsic<Intrinsic::matrix_transpose>(m_Value(T)))) {
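// The transpose feeds the operand we fetch scalars from (B for column-major,
// A for row-major), so emitMatrixMultiply can read the scalars directly from
// the transpose's input T and the transpose itself never has to be
// materialized.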
1874 IRBuilder<> Builder(MatMul);
1875 auto *EltType = cast<VectorType>(Val: MatMul->getType())->getElementType();
1876 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1877 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1878 const unsigned R = LShape.NumRows;
1879 const unsigned M = LShape.NumColumns;
1880 const unsigned C = RShape.NumColumns;
1881
1882 MatrixTy MA;
1883 MatrixTy MB;
1884
1885 Value *Transpose;
1886 if (MatrixLayout == MatrixLayoutTy::ColumnMajor) {
1887 MA = getMatrix(MatrixVal: A, SI: ShapeInfo(R, M), Builder);
1888 MB = getMatrix(MatrixVal: T, SI: ShapeInfo(C, M), Builder);
1889 Transpose = B;
1890 } else {
1891 MA = getMatrix(MatrixVal: T, SI: ShapeInfo(R, M), Builder);
1892 MB = getMatrix(MatrixVal: B, SI: ShapeInfo(C, M), Builder);
1893 Transpose = A;
1894 }
1895
1896 // Initialize the output
1897 MatrixTy Result(R, C, EltType);
1898
1899 emitMatrixMultiply(Result, A: MA, B: MB, Builder, IsTiled: false, IsScalarMatrixTransposed: true,
1900 FMF: getFastMathFlags(Inst: MatMul));
1901
1902 FusedInsts.insert(Ptr: MatMul);
1903 if (Transpose->hasOneUse()) {
1904 FusedInsts.insert(Ptr: cast<Instruction>(Val: Transpose));
1905 ToRemove.push_back(Elt: cast<Instruction>(Val: Transpose));
1906 // TODO: add a fake entry for the folded instruction so that this is
1907 // included in the expression in the remark.
1908 Inst2ColumnMatrix[Transpose] = MatrixTy(M, C, EltType);
1909 }
1910 finalizeLowering(Inst: MatMul, Matrix: Result, Builder);
1911 return;
1912 }
1913
1914 if (!MatMul->hasOneUse() || MatrixLayout != MatrixLayoutTy::ColumnMajor)
1915 return;
1916
1917 // Lower {ld, ld} -> matmul -> st chains. No need to call finalizeLowering
1918 // since the single store user will be lowered as part of this.
1919 auto *LoadOp0 = dyn_cast<LoadInst>(Val: A);
1920 auto *LoadOp1 = dyn_cast<LoadInst>(Val: B);
1921 auto *Store = dyn_cast<StoreInst>(Val: *MatMul->user_begin());
1922 if (LoadOp0 && LoadOp1 && Store) {
1923 // The store address must dominate the MatMul instruction; otherwise
1924 // we would create invalid IR.
1925 SetVector<Value *> WorkList;
1926 WorkList.insert(X: Store->getOperand(i_nocapture: 1));
1927 SmallVector<Instruction *> ToHoist;
1928 for (unsigned I = 0; I != WorkList.size(); ++I) {
1929 Value *Current = WorkList[I];
1930 auto *CurrI = dyn_cast<Instruction>(Val: Current);
1931 if (!CurrI)
1932 continue;
1933 if (isa<PHINode>(Val: CurrI))
1934 return;
1935 if (DT->dominates(Def: CurrI, User: MatMul))
1936 continue;
1937 if (CurrI->mayHaveSideEffects() || CurrI->mayReadFromMemory())
1938 return;
1939 ToHoist.push_back(Elt: CurrI);
1940 WorkList.insert(Start: CurrI->op_begin(), End: CurrI->op_end());
1941 }
1942
1943 sort(C&: ToHoist, Comp: [this](Instruction *A, Instruction *B) {
1944 return DT->dominates(Def: A, User: B);
1945 });
1946 for (Instruction *I : ToHoist)
1947 I->moveBefore(MovePos: MatMul);
1948
1949 emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
1950 return;
1951 }
1952 }
1953
1954 /// Lowers llvm.matrix.multiply.
1955 void LowerMultiply(CallInst *MatMul) {
1956 IRBuilder<> Builder(MatMul);
1957 auto *EltType = cast<VectorType>(Val: MatMul->getType())->getElementType();
1958 ShapeInfo LShape(MatMul->getArgOperand(i: 2), MatMul->getArgOperand(i: 3));
1959 ShapeInfo RShape(MatMul->getArgOperand(i: 3), MatMul->getArgOperand(i: 4));
1960
1961 const MatrixTy &Lhs = getMatrix(MatrixVal: MatMul->getArgOperand(i: 0), SI: LShape, Builder);
1962 const MatrixTy &Rhs = getMatrix(MatrixVal: MatMul->getArgOperand(i: 1), SI: RShape, Builder);
1963 assert(Lhs.getElementType() == Rhs.getElementType() &&
1964 "Matrix multiply argument element types do not match.");
1965
1966 const unsigned R = LShape.NumRows;
1967 const unsigned C = RShape.NumColumns;
1968 assert(LShape.NumColumns == RShape.NumRows);
1969
1970 // Initialize the output
1971 MatrixTy Result(R, C, EltType);
1972 assert(Lhs.getElementType() == Result.getElementType() &&
1973 "Matrix multiply result element type does not match arguments.");
1974
1975 emitMatrixMultiply(Result, A: Lhs, B: Rhs, Builder, IsTiled: false, IsScalarMatrixTransposed: false,
1976 FMF: getFastMathFlags(Inst: MatMul));
1977 finalizeLowering(Inst: MatMul, Matrix: Result, Builder);
1978 }
1979
1980 /// Lowers llvm.matrix.transpose.
1981 void LowerTranspose(CallInst *Inst) {
1982 MatrixTy Result;
1983 IRBuilder<> Builder(Inst);
1984 Value *InputVal = Inst->getArgOperand(i: 0);
1985 VectorType *VectorTy = cast<VectorType>(Val: InputVal->getType());
1986 ShapeInfo ArgShape(Inst->getArgOperand(i: 1), Inst->getArgOperand(i: 2));
1987 MatrixTy InputMatrix = getMatrix(MatrixVal: InputVal, SI: ArgShape, Builder);
1988
1989 const unsigned NewNumVecs =
1990 InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
1991 const unsigned NewNumElts =
1992 InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
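// E.g. (illustrative) for a column-major 2x3 input with columns [a d], [b e],
// [c f], the 3x2 result has NewNumVecs = 2 columns of NewNumElts = 3
// elements each: [a b c] and [d e f].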
1993
1994 for (unsigned I = 0; I < NewNumVecs; ++I) {
1995 // Build a single result vector. First initialize it.
1996 Value *ResultVector = PoisonValue::get(
1997 T: FixedVectorType::get(ElementType: VectorTy->getElementType(), NumElts: NewNumElts));
1998 // Go through the old elements and insert each into the resulting vector.
1999 for (auto J : enumerate(First: InputMatrix.vectors())) {
2000 Value *Elt = Builder.CreateExtractElement(Vec: J.value(), Idx: I);
2001 // Row and column indices are transposed.
2002 ResultVector =
2003 Builder.CreateInsertElement(Vec: ResultVector, NewElt: Elt, Idx: J.index());
2004 }
2005 Result.addVector(V: ResultVector);
2006 }
2007
2008 // TODO: Improve estimate of operations needed for transposes. Currently we
2009 // just count the insertelement/extractelement instructions, but do not
2010 // account for later simplifications/combines.
2011 finalizeLowering(
2012 Inst,
2013 Matrix: Result.addNumComputeOps(N: 2 * ArgShape.NumRows * ArgShape.NumColumns)
2014 .addNumExposedTransposes(N: 1),
2015 Builder);
2016 }
2017
2018 /// Lower load instructions, if shape information is available.
2019 bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
2020 auto I = ShapeMap.find(Val: Inst);
2021 if (I == ShapeMap.end())
2022 return false;
2023
2024 LowerLoad(Inst, Ptr, Align: Inst->getAlign(),
2025 Stride: Builder.getInt64(C: I->second.getStride()), IsVolatile: Inst->isVolatile(),
2026 Shape: I->second);
2027 return true;
2028 }
2029
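/// Lower store instructions, if shape information is available.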
2030 bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
2031 IRBuilder<> &Builder) {
2032 auto I = ShapeMap.find(Val: StoredVal);
2033 if (I == ShapeMap.end())
2034 return false;
2035
2036 LowerStore(Inst, Matrix: StoredVal, Ptr, A: Inst->getAlign(),
2037 Stride: Builder.getInt64(C: I->second.getStride()), IsVolatile: Inst->isVolatile(),
2038 Shape: I->second);
2039 return true;
2040 }
2041
2042 /// Lower binary operators, if shape information is available.
2043 bool VisitBinaryOperator(BinaryOperator *Inst) {
2044 auto I = ShapeMap.find(Val: Inst);
2045 if (I == ShapeMap.end())
2046 return false;
2047
2048 Value *Lhs = Inst->getOperand(i_nocapture: 0);
2049 Value *Rhs = Inst->getOperand(i_nocapture: 1);
2050
2051 IRBuilder<> Builder(Inst);
2052 ShapeInfo &Shape = I->second;
2053
2054 MatrixTy Result;
2055 MatrixTy A = getMatrix(MatrixVal: Lhs, SI: Shape, Builder);
2056 MatrixTy B = getMatrix(MatrixVal: Rhs, SI: Shape, Builder);
2057 assert(A.isColumnMajor() == B.isColumnMajor() &&
2058 Result.isColumnMajor() == A.isColumnMajor() &&
2059 "operands must agree on matrix layout");
2060
2061 Builder.setFastMathFlags(getFastMathFlags(Inst));
2062
2063 // Helper to perform binary op on vectors.
2064 auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
2065 switch (Inst->getOpcode()) {
2066 case Instruction::Add:
2067 return Builder.CreateAdd(LHS, RHS);
2068 case Instruction::Mul:
2069 return Builder.CreateMul(LHS, RHS);
2070 case Instruction::Sub:
2071 return Builder.CreateSub(LHS, RHS);
2072 case Instruction::FAdd:
2073 return Builder.CreateFAdd(L: LHS, R: RHS);
2074 case Instruction::FMul:
2075 return Builder.CreateFMul(L: LHS, R: RHS);
2076 case Instruction::FSub:
2077 return Builder.CreateFSub(L: LHS, R: RHS);
2078 default:
2079 llvm_unreachable("Unsupported binary operator for matrix");
2080 }
2081 };
2082
2083 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2084 Result.addVector(V: BuildVectorOp(A.getVector(i: I), B.getVector(i: I)));
2085
2086 finalizeLowering(Inst,
2087 Matrix: Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
2088 Result.getNumVectors()),
2089 Builder);
2090 return true;
2091 }
2092
2093 /// Lower unary operators, if shape information is available.
2094 bool VisitUnaryOperator(UnaryOperator *Inst) {
2095 auto I = ShapeMap.find(Val: Inst);
2096 if (I == ShapeMap.end())
2097 return false;
2098
2099 Value *Op = Inst->getOperand(i_nocapture: 0);
2100
2101 IRBuilder<> Builder(Inst);
2102 ShapeInfo &Shape = I->second;
2103
2104 MatrixTy Result;
2105 MatrixTy M = getMatrix(MatrixVal: Op, SI: Shape, Builder);
2106
2107 Builder.setFastMathFlags(getFastMathFlags(Inst));
2108
2109 // Helper to perform unary op on vectors.
2110 auto BuildVectorOp = [&Builder, Inst](Value *Op) {
2111 switch (Inst->getOpcode()) {
2112 case Instruction::FNeg:
2113 return Builder.CreateFNeg(V: Op);
2114 default:
2115 llvm_unreachable("Unsupported unary operator for matrix");
2116 }
2117 };
2118
2119 for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
2120 Result.addVector(V: BuildVectorOp(M.getVector(i: I)));
2121
2122 finalizeLowering(Inst,
2123 Matrix: Result.addNumComputeOps(N: getNumOps(VT: Result.getVectorTy()) *
2124 Result.getNumVectors()),
2125 Builder);
2126 return true;
2127 }
2128
2129 /// Helper to linearize a matrix expression tree into a string. Currently
2130 /// matrix expressions are linearized by starting at an expression leaf and
2131 /// linearizing bottom up.
2132 struct ExprLinearizer {
2133 unsigned LengthToBreak = 100;
2134 std::string Str;
2135 raw_string_ostream Stream;
2136 unsigned LineLength = 0;
2137 const DataLayout &DL;
2138
2139 /// Mapping from instructions to matrixes. It is used to identify
2140 /// matrix instructions.
2141 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2142
2143 /// Mapping from values to the leaves of all expressions that the value is
2144 /// part of.
2145 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
2146
2147 /// Set of matrix expressions in the scope of a given DISubprogram.
2148 const SmallSetVector<Value *, 32> &ExprsInSubprogram;
2149
2150 /// Leaf node of the expression to linearize.
2151 Value *Leaf;
2152
2153 /// Used to keep track of sub-expressions that get reused while linearizing
2154 /// the expression. Re-used sub-expressions are marked as (reused).
2155 SmallPtrSet<Value *, 8> ReusedExprs;
2156
2157 ExprLinearizer(const DataLayout &DL,
2158 const MapVector<Value *, MatrixTy> &Inst2Matrix,
2159 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2160 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2161 Value *Leaf)
2162 : Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
2163 ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
2164
2165 void indent(unsigned N) {
2166 LineLength += N;
2167 for (unsigned i = 0; i < N; i++)
2168 Stream << " ";
2169 }
2170
2171 void lineBreak() {
2172 Stream << "\n";
2173 LineLength = 0;
2174 }
2175
2176 void maybeIndent(unsigned Indent) {
2177 if (LineLength >= LengthToBreak)
2178 lineBreak();
2179
2180 if (LineLength == 0)
2181 indent(N: Indent);
2182 }
2183
2184 void write(StringRef S) {
2185 LineLength += S.size();
2186 Stream << S;
2187 }
2188
2189 Value *getUnderlyingObjectThroughLoads(Value *V) {
2190 if (Value *Ptr = getPointerOperand(V))
2191 return getUnderlyingObjectThroughLoads(V: Ptr);
2192 else if (V->getType()->isPointerTy())
2193 return getUnderlyingObject(V);
2194 return V;
2195 }
2196
2197 /// Returns true if \p V is a matrix value in the given subprogram.
2198 bool isMatrix(Value *V) const { return ExprsInSubprogram.count(key: V); }
2199
2200 /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
2201 /// \p SS.
2202 void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
2203 auto M = Inst2Matrix.find(Key: V);
2204 if (M == Inst2Matrix.end())
2205 SS << "unknown";
2206 else {
2207 SS << M->second.getNumRows();
2208 SS << "x";
2209 SS << M->second.getNumColumns();
2210 }
2211 }
2212
2213 /// Write the called function name. Handles calls to llvm.matrix.*
2214 /// specially: we write the name, followed by the dimensions of the input
2215 /// matrixes, followed by the scalar type name.
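/// E.g. for a 4x2 * 2x4 double multiply this writes (illustrative):
///   multiply.4x2.2x4.double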
2216 void writeFnName(CallInst *CI) {
2217 if (!CI->getCalledFunction())
2218 write(S: "<no called fn>");
2219 else {
2220 StringRef Name = CI->getCalledFunction()->getName();
2221 if (!Name.starts_with(Prefix: "llvm.matrix")) {
2222 write(S: Name);
2223 return;
2224 }
2225 auto *II = cast<IntrinsicInst>(Val: CI);
2226 write(S: Intrinsic::getBaseName(id: II->getIntrinsicID())
2227 .drop_front(N: StringRef("llvm.matrix.").size()));
2228 write(S: ".");
2229 std::string Tmp;
2230 raw_string_ostream SS(Tmp);
2231
2232 switch (II->getIntrinsicID()) {
2233 case Intrinsic::matrix_multiply:
2234 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 0), SS);
2235 SS << ".";
2236 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 1), SS);
2237 SS << "." << *II->getType()->getScalarType();
2238 break;
2239 case Intrinsic::matrix_transpose:
2240 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 0), SS);
2241 SS << "." << *II->getType()->getScalarType();
2242 break;
2243 case Intrinsic::matrix_column_major_load:
2244 prettyPrintMatrixType(V: II, SS);
2245 SS << "." << *II->getType()->getScalarType();
2246 break;
2247 case Intrinsic::matrix_column_major_store:
2248 prettyPrintMatrixType(V: II->getOperand(i_nocapture: 0), SS);
2249 SS << "." << *II->getOperand(i_nocapture: 0)->getType()->getScalarType();
2250 break;
2251 default:
2252 llvm_unreachable("Unhandled case");
2253 }
2254 SS.flush();
2255 write(S: Tmp);
2256 }
2257 }
2258
2259 unsigned getNumShapeArgs(CallInst *CI) const {
2260 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: CI)) {
2261 switch (II->getIntrinsicID()) {
2262 case Intrinsic::matrix_multiply:
2263 return 3;
2264 case Intrinsic::matrix_transpose:
2265 return 2;
2266 case Intrinsic::matrix_column_major_load:
2267 case Intrinsic::matrix_column_major_store:
2268 return 3;
2269 default:
2270 return 0;
2271 }
2272 }
2273 return 0;
2274 }
2275
2276 /// Special printing for values: for pointers, we print whether they refer to
2277 /// a stack address or some other (e.g. function-external) address; for other
2278 /// values we either print the constant or "scalar"/"matrix".
2279 void write(Value *V) {
2280 V = getUnderlyingObjectThroughLoads(V);
2281 if (V->getType()->isPointerTy()) {
2282 if (isa<AllocaInst>(Val: V)) {
2283 Stream << "stack addr";
2284 LineLength += StringRef("stack addr").size();
2285 } else {
2286 Stream << "addr";
2287 LineLength += StringRef("addr").size();
2288 }
2289 if (!V->getName().empty()) {
2290 Stream << " %" << V->getName() << "";
2291 LineLength += V->getName().size() + 2;
2292 }
2293 return;
2294 }
2295
2296 std::string Tmp;
2297 raw_string_ostream TmpStream(Tmp);
2298
2299 if (auto *CI = dyn_cast<ConstantInt>(Val: V))
2300 TmpStream << CI->getValue();
2301 else if (isa<Constant>(Val: V))
2302 TmpStream << "constant";
2303 else {
2304 if (isMatrix(V))
2305 TmpStream << "matrix";
2306 else
2307 TmpStream << "scalar";
2308 }
2309 TmpStream.flush();
2310 Tmp = std::string(StringRef(Tmp).trim());
2311 LineLength += Tmp.size();
2312 Stream << Tmp;
2313 }
2314
2315 /// Linearize expression \p Expr starting at an indentation of \p Indent.
2316 /// Expressions that are re-used multiple times are prefixed with (reused)
2317 /// at the re-used root instruction.
2318 void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
2319 bool ParentShared) {
2320 auto *I = cast<Instruction>(Val: Expr);
2321 maybeIndent(Indent);
2322 SmallVector<Value *, 8> Ops;
2323
2324 // Is Expr shared with other expression leaves?
2325 bool ExprShared = false;
2326
2327 // Deal with shared subtrees. Mark them as shared, if required.
2328 if (!ParentShared) {
2329 auto SI = Shared.find(Val: Expr);
2330 assert(SI != Shared.end() && SI->second.count(Leaf));
2331
2332 for (Value *S : SI->second) {
2333 if (S == Leaf)
2334 continue;
2335 DebugLoc DL = cast<Instruction>(Val: S)->getDebugLoc();
2336 write(S: "shared with remark at line " + std::to_string(val: DL.getLine()) +
2337 " column " + std::to_string(val: DL.getCol()) + " (");
2338 }
2339 ExprShared = SI->second.size() > 1;
2340 }
2341
2342 bool Reused = !ReusedExprs.insert(Ptr: Expr).second;
2343 if (Reused && !ParentReused)
2344 write(S: "(reused) ");
2345
2346 if (auto *CI = dyn_cast<CallInst>(Val: I)) {
2347 writeFnName(CI);
2348
2349 Ops.append(in_start: CI->arg_begin(), in_end: CI->arg_end() - getNumShapeArgs(CI));
2350 } else if (isa<BitCastInst>(Val: Expr)) {
2351 // Special case bitcasts, which are used to materialize matrixes from
2352 // non-matrix ops.
2353 write(S: "matrix");
2354 return;
2355 } else {
2356 Ops.append(in_start: I->value_op_begin(), in_end: I->value_op_end());
2357 write(S: std::string(I->getOpcodeName()));
2358 }
2359
2360 write(S: std::string("("));
2361
2362 unsigned NumOpsToBreak = 1;
2363 if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
2364 NumOpsToBreak = 2;
2365
2366 for (Value *Op : Ops) {
2367 if (Ops.size() > NumOpsToBreak)
2368 lineBreak();
2369
2370 maybeIndent(Indent: Indent + 1);
2371 if (isMatrix(V: Op))
2372 linearizeExpr(Expr: Op, Indent: Indent + 1, ParentReused: Reused, ParentShared: ExprShared);
2373 else
2374 write(V: Op);
2375 if (Op != Ops.back())
2376 write(S: ", ");
2377 }
2378
2379 write(S: ")");
2380 }
2381
2382 const std::string &getResult() {
2383 Stream.flush();
2384 return Str;
2385 }
2386 };
2387
2388 /// Generate remarks for matrix operations in a function. To generate remarks
2389 /// for matrix expressions, the following approach is used:
2390 /// 1. Use the inlined-at debug information to group matrix operations to the
2391 /// DISubprograms they are contained in.
2392 /// 2. Collect leaves of matrix expressions (done in
2393 /// RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
2394 /// mapping. Leaves are lowered matrix instructions without other matrix
2395 /// users (like stores) in the current subprogram.
2396 /// 3. For each leaf, create a remark containing a linearizied version of the
2397 /// matrix expression. The expression is linearized by a recursive
2398 /// bottom-up traversal of the matrix operands, starting at a leaf. Note
2399 /// that multiple leaves can share sub-expressions. Shared subexpressions
2400 /// are explicitly marked as shared().
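/// An emitted remark may look roughly like (illustrative values):
///   remark: test.cpp:35:7: Lowered with 6 stores, 6 loads, 24 compute ops,
///   0 exposed transposes
///   store(multiply.2x6.6x2.double(load(addr %A), load(addr %B)), addr %C)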
2401 struct RemarkGenerator {
2402 const MapVector<Value *, MatrixTy> &Inst2Matrix;
2403 OptimizationRemarkEmitter &ORE;
2404 Function &Func;
2405 const DataLayout &DL;
2406
2407 RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
2408 OptimizationRemarkEmitter &ORE, Function &Func)
2409 : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
2410 DL(Func.getParent()->getDataLayout()) {}
2411
2412 /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
2413 /// instructions in Inst2Matrix returning void or without any users in
2414 /// \p ExprsInSubprogram. Currently that should only include stores.
2415 SmallVector<Value *, 4>
2416 getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
2417 SmallVector<Value *, 4> Leaves;
2418 for (auto *Expr : ExprsInSubprogram)
2419 if (Expr->getType()->isVoidTy() ||
2420 !any_of(Range: Expr->users(), P: [&ExprsInSubprogram](User *U) {
2421 return ExprsInSubprogram.count(key: U);
2422 }))
2423 Leaves.push_back(Elt: Expr);
2424 return Leaves;
2425 }
2426
2427 /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
2428 /// to all visited expressions in \p Shared. Limit the matrix operations to
2429 /// the ones in \p ExprsInSubprogram.
2430 void collectSharedInfo(Value *Leaf, Value *V,
2431 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2432 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
2433
2434 if (!ExprsInSubprogram.count(key: V))
2435 return;
2436
2437 auto I = Shared.insert(KV: {V, {}});
2438 I.first->second.insert(Ptr: Leaf);
2439
2440 for (Value *Op : cast<Instruction>(Val: V)->operand_values())
2441 collectSharedInfo(Leaf, V: Op, ExprsInSubprogram, Shared);
2442 }
2443
2444 /// Calculate the exclusive and shared op counts for the expression
2445 /// starting at \p Root. Expressions used multiple times are counted once.
2446 /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
2447 std::pair<OpInfoTy, OpInfoTy>
2448 sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
2449 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2450 DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
2451 if (!ExprsInSubprogram.count(key: Root))
2452 return {};
2453
2454 // Already counted this expression. Stop.
2455 if (!ReusedExprs.insert(Ptr: Root).second)
2456 return {};
2457
2458 OpInfoTy SharedCount;
2459 OpInfoTy Count;
2460
2461 auto I = Shared.find(Val: Root);
2462 auto CM = Inst2Matrix.find(Key: Root);
2463 if (I->second.size() == 1)
2464 Count = CM->second.getOpInfo();
2465 else
2466 SharedCount = CM->second.getOpInfo();
2467
2468 for (Value *Op : cast<Instruction>(Val: Root)->operand_values()) {
2469 auto C = sumOpInfos(Root: Op, ReusedExprs, ExprsInSubprogram, Shared);
2470 Count += C.first;
2471 SharedCount += C.second;
2472 }
2473 return {Count, SharedCount};
2474 }
2475
2476 void emitRemarks() {
2477 if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
2478 return;
2479
2480 // Map matrix operations to their containing subprograms, by traversing
2481 // the inlinedAt chain. If the function does not have a DISubprogram, we
2482 // only map them to the containing function.
2483 MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
2484 for (const auto &KV : Inst2Matrix) {
2485 if (Func.getSubprogram()) {
2486 auto *I = cast<Instruction>(Val: KV.first);
2487 DILocation *Context = I->getDebugLoc();
2488 while (Context) {
2489 auto I =
2490 Subprog2Exprs.insert(KV: {getSubprogram(Scope: Context->getScope()), {}});
2491 I.first->second.push_back(Elt: KV.first);
2492 Context = DebugLoc(Context).getInlinedAt();
2493 }
2494 } else {
2495 auto I = Subprog2Exprs.insert(KV: {nullptr, {}});
2496 I.first->second.push_back(Elt: KV.first);
2497 }
2498 }
2499 for (auto &KV : Subprog2Exprs) {
2500 SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
2501 KV.second.end());
2502 auto Leaves = getExpressionLeaves(ExprsInSubprogram);
2503
2504 DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
2505 for (Value *Leaf : Leaves)
2506 collectSharedInfo(Leaf, V: Leaf, ExprsInSubprogram, Shared);
2507
2508 // Generate remarks for each leaf.
2509 for (auto *L : Leaves) {
2510
2511 DebugLoc Loc = cast<Instruction>(Val: L)->getDebugLoc();
2512 DILocation *Context = cast<Instruction>(Val: L)->getDebugLoc();
2513 while (Context) {
2514 if (getSubprogram(Scope: Context->getScope()) == KV.first) {
2515 Loc = Context;
2516 break;
2517 }
2518 Context = DebugLoc(Context).getInlinedAt();
2519 }
2520
2521 SmallPtrSet<Value *, 8> ReusedExprs;
2522 OpInfoTy Counts, SharedCounts;
2523 std::tie(args&: Counts, args&: SharedCounts) =
2524 sumOpInfos(Root: L, ReusedExprs, ExprsInSubprogram, Shared);
2525
2526 OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
2527 cast<Instruction>(Val: L)->getParent());
2528
2529 Rem << "Lowered with ";
2530 Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
2531 << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
2532 << ore::NV("NumComputeOps", Counts.NumComputeOps)
2533 << " compute ops, "
2534 << ore::NV("NumExposedTransposes", Counts.NumExposedTransposes)
2535 << " exposed transposes";
2536
2537 if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
2538 SharedCounts.NumComputeOps > 0) {
2539 Rem << ",\nadditionally "
2540 << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
2541 << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
2542 << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
2543 << " compute ops"
2544 << " are shared with other expressions";
2545 }
2546
2547 Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
2548 ORE.emit(OptDiag&: Rem);
2549 }
2550 }
2551 }
2552
2553 std::string
2554 linearize(Value *L,
2555 const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
2556 const SmallSetVector<Value *, 32> &ExprsInSubprogram,
2557 const DataLayout &DL) {
2558 ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
2559 Lin.linearizeExpr(Expr: L, Indent: 0, ParentReused: false, ParentShared: false);
2560 return Lin.getResult();
2561 }
2562 };
2563};
2564} // namespace
2565
2566PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
2567 FunctionAnalysisManager &AM) {
2568 auto &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
2569 OptimizationRemarkEmitter *ORE = nullptr;
2570 AAResults *AA = nullptr;
2571 DominatorTree *DT = nullptr;
2572 LoopInfo *LI = nullptr;
2573
2574 if (!Minimal) {
2575 ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F);
2576 AA = &AM.getResult<AAManager>(IR&: F);
2577 DT = &AM.getResult<DominatorTreeAnalysis>(IR&: F);
2578 LI = &AM.getResult<LoopAnalysis>(IR&: F);
2579 }
2580
2581 LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
2582 if (LMT.Visit()) {
2583 PreservedAnalyses PA;
2584 if (!Minimal) {
2585 PA.preserve<LoopAnalysis>();
2586 PA.preserve<DominatorTreeAnalysis>();
2587 }
2588 return PA;
2589 }
2590 return PreservedAnalyses::all();
2591}
2592
2593void LowerMatrixIntrinsicsPass::printPipeline(
2594 raw_ostream &OS, function_ref<StringRef(StringRef)> MapClassName2PassName) {
2595 static_cast<PassInfoMixin<LowerMatrixIntrinsicsPass> *>(this)->printPipeline(
2596 OS, MapClassName2PassName);
2597 OS << '<';
2598 if (Minimal)
2599 OS << "minimal";
2600 OS << '>';
2601}
2602
