1 | //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the AMX dialect and its operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/AMX/AMXDialect.h" |
14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/DialectImplementation.h" |
19 | #include "mlir/IR/OpImplementation.h" |
20 | #include "mlir/IR/TypeUtilities.h" |
21 | |
22 | #include "llvm/ADT/TypeSwitch.h" |
23 | |
24 | using namespace mlir; |
25 | |
26 | #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc" |
27 | |
28 | #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" |
29 | |
/// Registers the AMX dialect's types and operations. Both registration calls
/// take their template arguments from a generated .inc file that is included
/// directly inside the template argument list.
30 | void amx::AMXDialect::initialize() { |
// The macro is defined immediately before the include; presumably it selects
// the type-list section of the generated file (TableGen convention, confirm).
31 | addTypes< |
32 | #define GET_TYPEDEF_LIST |
33 | #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" |
34 | >(); |
35 | |
// Same pattern for operations, using the GET_OP_LIST macro and AMX.cpp.inc.
36 | addOperations< |
37 | #define GET_OP_LIST |
38 | #include "mlir/Dialect/AMX/AMX.cpp.inc" |
39 | >(); |
40 | } |
41 | |
42 | /// Verify that AMX supports the implied tile shape. |
43 | static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) { |
44 | const unsigned kMaxRows = 16; |
45 | const unsigned kBitsPerRow = 64 * 8; |
46 | unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); |
47 | if (tp.getDimSize(0) > kMaxRows) |
48 | return op->emitOpError(message: "bad row height: " ) << tp.getDimSize(0); |
49 | if (col > kBitsPerRow || col & 0x1f) |
50 | return op->emitOpError(message: "bad column width: " ) << (col >> 3); |
51 | return success(); |
52 | } |
53 | |
54 | /// Verify that AMX supports the multiplication. |
55 | static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, |
56 | amx::TileType btp, amx::TileType ctp, |
57 | unsigned scale) { |
58 | unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; |
59 | unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; |
60 | unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); |
61 | if (cm != am || cn != bn || ak != bk) |
62 | return op->emitOpError(message: "bad mult shape: " ) |
63 | << cm << " x " << cn << " x " << ak; |
64 | return success(); |
65 | } |
66 | |
67 | /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first |
68 | /// dimension directly translates into the number of rows of the tiles. |
69 | /// The second dimensions needs to be scaled by the number of bytes. |
70 | static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType, |
71 | RewriterBase &rewriter) { |
72 | Type llvmInt16Type = rewriter.getIntegerType(16); |
73 | unsigned width = tType.getElementType().getIntOrFloatBitWidth(); |
74 | assert(llvm::isPowerOf2_64(width) && width >= 8); |
75 | unsigned bytes = width >> 3; |
76 | auto mattr = rewriter.getI16IntegerAttr(value: tType.getDimSize(0)); |
77 | auto nattr = rewriter.getI16IntegerAttr(value: tType.getDimSize(1) * bytes); |
78 | return SmallVector<Value>{ |
79 | rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), |
80 | rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)}; |
81 | } |
82 | |
83 | /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer |
84 | /// shape may "envelop" the actual tile shape, and may be dynamically sized. |
85 | static Value getStride(Location loc, MemRefType mType, Value base, |
86 | RewriterBase &rewriter) { |
87 | assert(mType.getRank() >= 2 && "Invalid shape for AMX strides" ); |
88 | int64_t preLast = mType.getRank() - 2; |
89 | Type llvmInt64Type = rewriter.getIntegerType(64); |
90 | unsigned width = mType.getElementType().getIntOrFloatBitWidth(); |
91 | assert(llvm::isPowerOf2_64(width) && width >= 8); |
92 | unsigned bytes = width >> 3; |
93 | auto [strides, offset] = mType.getStridesAndOffset(); |
94 | if (strides[preLast] == ShapedType::kDynamic) { |
95 | // Dynamic stride needs code to compute the stride at runtime. |
96 | MemRefDescriptor memrefDescriptor(base); |
97 | auto attr = rewriter.getI64IntegerAttr(bytes); |
98 | Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); |
99 | return rewriter |
100 | .create<LLVM::MulOp>(loc, llvmInt64Type, scale, |
101 | memrefDescriptor.stride(rewriter, loc, preLast)) |
102 | .getResult(); |
103 | } |
104 | // Use direct constant for static stride. |
105 | auto attr = rewriter.getI64IntegerAttr(value: strides[preLast] * bytes); |
106 | return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr) |
107 | .getResult(); |
108 | } |
109 | |
110 | LogicalResult amx::TileZeroOp::verify() { |
111 | return verifyTileSize(*this, getTileType()); |
112 | } |
113 | |
114 | SmallVector<Value> |
115 | amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands, |
116 | const LLVMTypeConverter &typeConverter, |
117 | RewriterBase &rewriter) { |
118 | return getTileSizes(getLoc(), getTileType(), rewriter); |
119 | } |
120 | |
121 | LogicalResult amx::TileLoadOp::verify() { |
122 | MemRefType memrefTy = getMemRefType(); |
123 | unsigned rank = memrefTy.getRank(); |
124 | if (rank < 2) |
125 | return emitOpError("requires at least 2D memref" ); |
126 | if (getIndices().size() != rank) |
127 | return emitOpError("requires " ) << rank << " indices" ; |
128 | SmallVector<int64_t> strides; |
129 | int64_t offset; |
130 | if (failed(memrefTy.getStridesAndOffset(strides, offset)) || |
131 | strides.back() != 1) |
132 | return emitOpError("requires memref with unit innermost stride" ); |
133 | return verifyTileSize(*this, getTileType()); |
134 | } |
135 | |
136 | SmallVector<Value> |
137 | amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, |
138 | const LLVMTypeConverter &typeConverter, |
139 | RewriterBase &rewriter) { |
140 | auto loc = getLoc(); |
141 | Adaptor adaptor(operands, *this); |
142 | |
143 | SmallVector<Value> intrinsicOperands; |
144 | intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); |
145 | intrinsicOperands.push_back( |
146 | LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), |
147 | adaptor.getBase(), adaptor.getIndices())); |
148 | intrinsicOperands.push_back( |
149 | getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); |
150 | |
151 | return intrinsicOperands; |
152 | } |
153 | |
154 | LogicalResult amx::TileStoreOp::verify() { |
155 | MemRefType memrefTy = getMemRefType(); |
156 | unsigned rank = memrefTy.getRank(); |
157 | if (rank < 2) |
158 | return emitOpError("requires at least 2D memref" ); |
159 | if (getIndices().size() != rank) |
160 | return emitOpError("requires " ) << rank << " indices" ; |
161 | SmallVector<int64_t> strides; |
162 | int64_t offset; |
163 | if (failed(memrefTy.getStridesAndOffset(strides, offset)) || |
164 | strides.back() != 1) |
165 | return emitOpError("requires memref with unit innermost stride" ); |
166 | return verifyTileSize(*this, getTileType()); |
167 | } |
168 | |
169 | SmallVector<Value> |
170 | amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, |
171 | const LLVMTypeConverter &typeConverter, |
172 | RewriterBase &rewriter) { |
173 | auto loc = getLoc(); |
174 | Adaptor adaptor(operands, *this); |
175 | |
176 | SmallVector<Value> intrinsicOperands; |
177 | intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); |
178 | intrinsicOperands.push_back( |
179 | LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), |
180 | adaptor.getBase(), adaptor.getIndices())); |
181 | intrinsicOperands.push_back( |
182 | getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); |
183 | intrinsicOperands.push_back(adaptor.getVal()); |
184 | |
185 | return intrinsicOperands; |
186 | } |
187 | |
188 | LogicalResult amx::TileMulFOp::verify() { |
189 | amx::TileType aType = getLhsTileType(); |
190 | amx::TileType bType = getRhsTileType(); |
191 | amx::TileType cType = getTileType(); |
192 | if (failed(verifyTileSize(*this, aType)) || |
193 | failed(verifyTileSize(*this, bType)) || |
194 | failed(verifyTileSize(*this, cType)) || |
195 | failed(verifyMultShape(*this, aType, bType, cType, 1))) |
196 | return failure(); |
197 | Type ta = aType.getElementType(); |
198 | Type tb = bType.getElementType(); |
199 | Type tc = cType.getElementType(); |
200 | if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32()) |
201 | return emitOpError("unsupported type combination" ); |
202 | return success(); |
203 | } |
204 | |
205 | SmallVector<Value> |
206 | amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands, |
207 | const LLVMTypeConverter &typeConverter, |
208 | RewriterBase &rewriter) { |
209 | auto loc = getLoc(); |
210 | Adaptor adaptor(operands, *this); |
211 | |
212 | amx::TileType aType = getLhsTileType(); |
213 | amx::TileType bType = getRhsTileType(); |
214 | SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter); |
215 | SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter); |
216 | |
217 | SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1], |
218 | tsza[1], adaptor.getAcc(), |
219 | adaptor.getLhs(), adaptor.getRhs()}; |
220 | |
221 | return intrinsicOperands; |
222 | } |
223 | |
224 | LogicalResult amx::TileMulIOp::verify() { |
225 | amx::TileType aType = getLhsTileType(); |
226 | amx::TileType bType = getRhsTileType(); |
227 | amx::TileType cType = getTileType(); |
228 | if (failed(verifyTileSize(*this, aType)) || |
229 | failed(verifyTileSize(*this, bType)) || |
230 | failed(verifyTileSize(*this, cType)) || |
231 | failed(verifyMultShape(*this, aType, bType, cType, 2))) |
232 | return failure(); |
233 | Type ta = aType.getElementType(); |
234 | Type tb = bType.getElementType(); |
235 | Type tc = cType.getElementType(); |
236 | if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) |
237 | return emitOpError("unsupported type combination" ); |
238 | return success(); |
239 | } |
240 | |
241 | SmallVector<Value> |
242 | amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands, |
243 | const LLVMTypeConverter &typeConverter, |
244 | RewriterBase &rewriter) { |
245 | auto loc = getLoc(); |
246 | Adaptor adaptor(operands, *this); |
247 | |
248 | amx::TileType aType = getLhsTileType(); |
249 | amx::TileType bType = getRhsTileType(); |
250 | SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter); |
251 | SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter); |
252 | |
253 | SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1], |
254 | tsza[1], adaptor.getAcc(), |
255 | adaptor.getLhs(), adaptor.getRhs()}; |
256 | |
257 | return intrinsicOperands; |
258 | } |
259 | |
260 | Type amx::TileType::parse(AsmParser &parser) { |
261 | if (parser.parseLess()) |
262 | return nullptr; |
263 | |
264 | SmallVector<int64_t, 2> shape; |
265 | if (parser.parseDimensionList(shape, false, true)) |
266 | return nullptr; |
267 | |
268 | Type elementType; |
269 | if (parser.parseType(elementType)) |
270 | return nullptr; |
271 | |
272 | if (parser.parseGreater()) |
273 | return nullptr; |
274 | |
275 | return TileType::get(shape, elementType); |
276 | } |
277 | |
278 | void amx::TileType::print(AsmPrinter &os) const { |
279 | os << "<" ; |
280 | os.printDimensionList(getShape()); |
281 | os << 'x'; |
282 | os.printType(getElementType()); |
283 | os << '>'; |
284 | } |
285 | |
286 | #define GET_OP_CLASSES |
287 | #include "mlir/Dialect/AMX/AMX.cpp.inc" |
288 | |
289 | #define GET_TYPEDEF_CLASSES |
290 | #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" |
291 | |