1 | //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the AMX dialect and its operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/AMX/AMXDialect.h" |
14 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
15 | #include "mlir/IR/Builders.h" |
16 | #include "mlir/IR/OpImplementation.h" |
17 | #include "mlir/IR/TypeUtilities.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" |
22 | |
/// Dialect initialization hook: registers the AMX operations with this dialect.
void amx::AMXDialect::initialize() {
  // The template argument list of addOperations<> is produced by the generated
  // AMX.cpp.inc file. Defining GET_OP_LIST right before the include selects the
  // section of that file used here (presumably the list of generated op
  // classes; confirm against the ODS-generated output).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
29 | |
30 | /// Verify that AMX supports the implied tile shape. |
31 | static LogicalResult verifyTileSize(Operation *op, VectorType tp) { |
32 | const unsigned kMaxRows = 16; |
33 | const unsigned kBitsPerRow = 64 * 8; |
34 | unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); |
35 | if (tp.getDimSize(0) > kMaxRows) |
36 | return op->emitOpError(message: "bad row height: " ) << tp.getDimSize(0); |
37 | if (col > kBitsPerRow || col & 0x1f) |
38 | return op->emitOpError(message: "bad column width: " ) << (col >> 3); |
39 | return success(); |
40 | } |
41 | |
42 | /// Verify that AMX supports the multiplication. |
43 | static LogicalResult verifyMultShape(Operation *op, VectorType atp, |
44 | VectorType btp, VectorType ctp, |
45 | unsigned scale) { |
46 | unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; |
47 | unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; |
48 | unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); |
49 | if (cm != am || cn != bn || ak != bk) |
50 | return op->emitOpError(message: "bad mult shape: " ) |
51 | << cm << " x " << cn << " x " << ak; |
52 | return success(); |
53 | } |
54 | |
55 | LogicalResult amx::TileZeroOp::verify() { |
56 | return verifyTileSize(*this, getVectorType()); |
57 | } |
58 | |
59 | LogicalResult amx::TileLoadOp::verify() { |
60 | unsigned rank = getMemRefType().getRank(); |
61 | if (getIndices().size() != rank) |
62 | return emitOpError("requires " ) << rank << " indices" ; |
63 | return verifyTileSize(*this, getVectorType()); |
64 | } |
65 | |
66 | LogicalResult amx::TileStoreOp::verify() { |
67 | unsigned rank = getMemRefType().getRank(); |
68 | if (getIndices().size() != rank) |
69 | return emitOpError("requires " ) << rank << " indices" ; |
70 | return verifyTileSize(*this, getVectorType()); |
71 | } |
72 | |
73 | LogicalResult amx::TileMulFOp::verify() { |
74 | VectorType aType = getLhsVectorType(); |
75 | VectorType bType = getRhsVectorType(); |
76 | VectorType cType = getVectorType(); |
77 | if (failed(verifyTileSize(*this, aType)) || |
78 | failed(verifyTileSize(*this, bType)) || |
79 | failed(verifyTileSize(*this, cType)) || |
80 | failed(verifyMultShape(*this, aType, bType, cType, 1))) |
81 | return failure(); |
82 | Type ta = aType.getElementType(); |
83 | Type tb = bType.getElementType(); |
84 | Type tc = cType.getElementType(); |
85 | if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) |
86 | return emitOpError("unsupported type combination" ); |
87 | return success(); |
88 | } |
89 | |
90 | LogicalResult amx::TileMulIOp::verify() { |
91 | VectorType aType = getLhsVectorType(); |
92 | VectorType bType = getRhsVectorType(); |
93 | VectorType cType = getVectorType(); |
94 | if (failed(verifyTileSize(*this, aType)) || |
95 | failed(verifyTileSize(*this, bType)) || |
96 | failed(verifyTileSize(*this, cType)) || |
97 | failed(verifyMultShape(*this, aType, bType, cType, 2))) |
98 | return failure(); |
99 | Type ta = aType.getElementType(); |
100 | Type tb = bType.getElementType(); |
101 | Type tc = cType.getElementType(); |
102 | if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) |
103 | return emitOpError("unsupported type combination" ); |
104 | return success(); |
105 | } |
106 | |
107 | #define GET_OP_CLASSES |
108 | #include "mlir/Dialect/AMX/AMX.cpp.inc" |
109 | |