| 1 | //===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the AMX dialect and its operations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Dialect/AMX/AMXDialect.h" |
| 14 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 16 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 17 | #include "mlir/IR/Builders.h" |
| 18 | #include "mlir/IR/DialectImplementation.h" |
| 19 | #include "mlir/IR/OpImplementation.h" |
| 20 | #include "mlir/IR/TypeUtilities.h" |
| 21 | |
| 22 | #include "llvm/ADT/TypeSwitch.h" |
| 23 | |
| 24 | using namespace mlir; |
| 25 | |
| 26 | #include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc" |
| 27 | |
| 28 | #include "mlir/Dialect/AMX/AMXDialect.cpp.inc" |
| 29 | |
| 30 | void amx::AMXDialect::initialize() { |
| 31 | addTypes< |
| 32 | #define GET_TYPEDEF_LIST |
| 33 | #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" |
| 34 | >(); |
| 35 | |
| 36 | addOperations< |
| 37 | #define GET_OP_LIST |
| 38 | #include "mlir/Dialect/AMX/AMX.cpp.inc" |
| 39 | >(); |
| 40 | } |
| 41 | |
| 42 | /// Verify that AMX supports the implied tile shape. |
| 43 | static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) { |
| 44 | const unsigned kMaxRows = 16; |
| 45 | const unsigned kBitsPerRow = 64 * 8; |
| 46 | unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); |
| 47 | if (tp.getDimSize(0) > kMaxRows) |
| 48 | return op->emitOpError(message: "bad row height: " ) << tp.getDimSize(0); |
| 49 | if (col > kBitsPerRow || col & 0x1f) |
| 50 | return op->emitOpError(message: "bad column width: " ) << (col >> 3); |
| 51 | return success(); |
| 52 | } |
| 53 | |
| 54 | /// Verify that AMX supports the multiplication. |
| 55 | static LogicalResult verifyMultShape(Operation *op, amx::TileType atp, |
| 56 | amx::TileType btp, amx::TileType ctp, |
| 57 | unsigned scale) { |
| 58 | unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; |
| 59 | unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; |
| 60 | unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); |
| 61 | if (cm != am || cn != bn || ak != bk) |
| 62 | return op->emitOpError(message: "bad mult shape: " ) |
| 63 | << cm << " x " << cn << " x " << ak; |
| 64 | return success(); |
| 65 | } |
| 66 | |
| 67 | /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first |
| 68 | /// dimension directly translates into the number of rows of the tiles. |
| 69 | /// The second dimensions needs to be scaled by the number of bytes. |
| 70 | static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType, |
| 71 | RewriterBase &rewriter) { |
| 72 | Type llvmInt16Type = rewriter.getIntegerType(16); |
| 73 | unsigned width = tType.getElementType().getIntOrFloatBitWidth(); |
| 74 | assert(llvm::isPowerOf2_64(width) && width >= 8); |
| 75 | unsigned bytes = width >> 3; |
| 76 | auto mattr = rewriter.getI16IntegerAttr(value: tType.getDimSize(0)); |
| 77 | auto nattr = rewriter.getI16IntegerAttr(value: tType.getDimSize(1) * bytes); |
| 78 | return SmallVector<Value>{ |
| 79 | rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), |
| 80 | rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)}; |
| 81 | } |
| 82 | |
| 83 | /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer |
| 84 | /// shape may "envelop" the actual tile shape, and may be dynamically sized. |
| 85 | static Value getStride(Location loc, MemRefType mType, Value base, |
| 86 | RewriterBase &rewriter) { |
| 87 | assert(mType.getRank() >= 2 && "Invalid shape for AMX strides" ); |
| 88 | int64_t preLast = mType.getRank() - 2; |
| 89 | Type llvmInt64Type = rewriter.getIntegerType(64); |
| 90 | unsigned width = mType.getElementType().getIntOrFloatBitWidth(); |
| 91 | assert(llvm::isPowerOf2_64(width) && width >= 8); |
| 92 | unsigned bytes = width >> 3; |
| 93 | auto [strides, offset] = mType.getStridesAndOffset(); |
| 94 | if (strides[preLast] == ShapedType::kDynamic) { |
| 95 | // Dynamic stride needs code to compute the stride at runtime. |
| 96 | MemRefDescriptor memrefDescriptor(base); |
| 97 | auto attr = rewriter.getI64IntegerAttr(bytes); |
| 98 | Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); |
| 99 | return rewriter |
| 100 | .create<LLVM::MulOp>(loc, llvmInt64Type, scale, |
| 101 | memrefDescriptor.stride(rewriter, loc, preLast)) |
| 102 | .getResult(); |
| 103 | } |
| 104 | // Use direct constant for static stride. |
| 105 | auto attr = rewriter.getI64IntegerAttr(value: strides[preLast] * bytes); |
| 106 | return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr) |
| 107 | .getResult(); |
| 108 | } |
| 109 | |
| 110 | LogicalResult amx::TileZeroOp::verify() { |
| 111 | return verifyTileSize(*this, getTileType()); |
| 112 | } |
| 113 | |
| 114 | SmallVector<Value> |
| 115 | amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands, |
| 116 | const LLVMTypeConverter &typeConverter, |
| 117 | RewriterBase &rewriter) { |
| 118 | return getTileSizes(getLoc(), getTileType(), rewriter); |
| 119 | } |
| 120 | |
| 121 | LogicalResult amx::TileLoadOp::verify() { |
| 122 | MemRefType memrefTy = getMemRefType(); |
| 123 | unsigned rank = memrefTy.getRank(); |
| 124 | if (rank < 2) |
| 125 | return emitOpError("requires at least 2D memref" ); |
| 126 | if (getIndices().size() != rank) |
| 127 | return emitOpError("requires " ) << rank << " indices" ; |
| 128 | SmallVector<int64_t> strides; |
| 129 | int64_t offset; |
| 130 | if (failed(memrefTy.getStridesAndOffset(strides, offset)) || |
| 131 | strides.back() != 1) |
| 132 | return emitOpError("requires memref with unit innermost stride" ); |
| 133 | return verifyTileSize(*this, getTileType()); |
| 134 | } |
| 135 | |
| 136 | SmallVector<Value> |
| 137 | amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands, |
| 138 | const LLVMTypeConverter &typeConverter, |
| 139 | RewriterBase &rewriter) { |
| 140 | auto loc = getLoc(); |
| 141 | Adaptor adaptor(operands, *this); |
| 142 | |
| 143 | SmallVector<Value> intrinsicOperands; |
| 144 | intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); |
| 145 | intrinsicOperands.push_back( |
| 146 | LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), |
| 147 | adaptor.getBase(), adaptor.getIndices())); |
| 148 | intrinsicOperands.push_back( |
| 149 | getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); |
| 150 | |
| 151 | return intrinsicOperands; |
| 152 | } |
| 153 | |
| 154 | LogicalResult amx::TileStoreOp::verify() { |
| 155 | MemRefType memrefTy = getMemRefType(); |
| 156 | unsigned rank = memrefTy.getRank(); |
| 157 | if (rank < 2) |
| 158 | return emitOpError("requires at least 2D memref" ); |
| 159 | if (getIndices().size() != rank) |
| 160 | return emitOpError("requires " ) << rank << " indices" ; |
| 161 | SmallVector<int64_t> strides; |
| 162 | int64_t offset; |
| 163 | if (failed(memrefTy.getStridesAndOffset(strides, offset)) || |
| 164 | strides.back() != 1) |
| 165 | return emitOpError("requires memref with unit innermost stride" ); |
| 166 | return verifyTileSize(*this, getTileType()); |
| 167 | } |
| 168 | |
| 169 | SmallVector<Value> |
| 170 | amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands, |
| 171 | const LLVMTypeConverter &typeConverter, |
| 172 | RewriterBase &rewriter) { |
| 173 | auto loc = getLoc(); |
| 174 | Adaptor adaptor(operands, *this); |
| 175 | |
| 176 | SmallVector<Value> intrinsicOperands; |
| 177 | intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter)); |
| 178 | intrinsicOperands.push_back( |
| 179 | LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(), |
| 180 | adaptor.getBase(), adaptor.getIndices())); |
| 181 | intrinsicOperands.push_back( |
| 182 | getStride(loc, getMemRefType(), adaptor.getBase(), rewriter)); |
| 183 | intrinsicOperands.push_back(adaptor.getVal()); |
| 184 | |
| 185 | return intrinsicOperands; |
| 186 | } |
| 187 | |
| 188 | LogicalResult amx::TileMulFOp::verify() { |
| 189 | amx::TileType aType = getLhsTileType(); |
| 190 | amx::TileType bType = getRhsTileType(); |
| 191 | amx::TileType cType = getTileType(); |
| 192 | if (failed(verifyTileSize(*this, aType)) || |
| 193 | failed(verifyTileSize(*this, bType)) || |
| 194 | failed(verifyTileSize(*this, cType)) || |
| 195 | failed(verifyMultShape(*this, aType, bType, cType, 1))) |
| 196 | return failure(); |
| 197 | Type ta = aType.getElementType(); |
| 198 | Type tb = bType.getElementType(); |
| 199 | Type tc = cType.getElementType(); |
| 200 | if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32()) |
| 201 | return emitOpError("unsupported type combination" ); |
| 202 | return success(); |
| 203 | } |
| 204 | |
| 205 | SmallVector<Value> |
| 206 | amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands, |
| 207 | const LLVMTypeConverter &typeConverter, |
| 208 | RewriterBase &rewriter) { |
| 209 | auto loc = getLoc(); |
| 210 | Adaptor adaptor(operands, *this); |
| 211 | |
| 212 | amx::TileType aType = getLhsTileType(); |
| 213 | amx::TileType bType = getRhsTileType(); |
| 214 | SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter); |
| 215 | SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter); |
| 216 | |
| 217 | SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1], |
| 218 | tsza[1], adaptor.getAcc(), |
| 219 | adaptor.getLhs(), adaptor.getRhs()}; |
| 220 | |
| 221 | return intrinsicOperands; |
| 222 | } |
| 223 | |
| 224 | LogicalResult amx::TileMulIOp::verify() { |
| 225 | amx::TileType aType = getLhsTileType(); |
| 226 | amx::TileType bType = getRhsTileType(); |
| 227 | amx::TileType cType = getTileType(); |
| 228 | if (failed(verifyTileSize(*this, aType)) || |
| 229 | failed(verifyTileSize(*this, bType)) || |
| 230 | failed(verifyTileSize(*this, cType)) || |
| 231 | failed(verifyMultShape(*this, aType, bType, cType, 2))) |
| 232 | return failure(); |
| 233 | Type ta = aType.getElementType(); |
| 234 | Type tb = bType.getElementType(); |
| 235 | Type tc = cType.getElementType(); |
| 236 | if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) |
| 237 | return emitOpError("unsupported type combination" ); |
| 238 | return success(); |
| 239 | } |
| 240 | |
| 241 | SmallVector<Value> |
| 242 | amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands, |
| 243 | const LLVMTypeConverter &typeConverter, |
| 244 | RewriterBase &rewriter) { |
| 245 | auto loc = getLoc(); |
| 246 | Adaptor adaptor(operands, *this); |
| 247 | |
| 248 | amx::TileType aType = getLhsTileType(); |
| 249 | amx::TileType bType = getRhsTileType(); |
| 250 | SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter); |
| 251 | SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter); |
| 252 | |
| 253 | SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1], |
| 254 | tsza[1], adaptor.getAcc(), |
| 255 | adaptor.getLhs(), adaptor.getRhs()}; |
| 256 | |
| 257 | return intrinsicOperands; |
| 258 | } |
| 259 | |
| 260 | Type amx::TileType::parse(AsmParser &parser) { |
| 261 | if (parser.parseLess()) |
| 262 | return nullptr; |
| 263 | |
| 264 | SmallVector<int64_t, 2> shape; |
| 265 | if (parser.parseDimensionList(shape, false, true)) |
| 266 | return nullptr; |
| 267 | |
| 268 | Type elementType; |
| 269 | if (parser.parseType(elementType)) |
| 270 | return nullptr; |
| 271 | |
| 272 | if (parser.parseGreater()) |
| 273 | return nullptr; |
| 274 | |
| 275 | return TileType::get(shape, elementType); |
| 276 | } |
| 277 | |
| 278 | void amx::TileType::print(AsmPrinter &os) const { |
| 279 | os << "<" ; |
| 280 | os.printDimensionList(getShape()); |
| 281 | os << 'x'; |
| 282 | os.printType(getElementType()); |
| 283 | os << '>'; |
| 284 | } |
| 285 | |
| 286 | #define GET_OP_CLASSES |
| 287 | #include "mlir/Dialect/AMX/AMX.cpp.inc" |
| 288 | |
| 289 | #define GET_TYPEDEF_CLASSES |
| 290 | #include "mlir/Dialect/AMX/AMXTypes.cpp.inc" |
| 291 | |