1//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMX dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMX/AMXDialect.h"
14#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/TypeUtilities.h"
18
19using namespace mlir;
20
21#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
22
/// Register every AMX operation with this dialect.
///
/// The operation list is not written out by hand. It is expanded from the
/// GET_OP_LIST section of "mlir/Dialect/AMX/AMX.cpp.inc" (presumably the
/// TableGen-generated op definitions), so new ops only need to be declared
/// there.
void amx::AMXDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
29
30/// Verify that AMX supports the implied tile shape.
31static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
32 const unsigned kMaxRows = 16;
33 const unsigned kBitsPerRow = 64 * 8;
34 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
35 if (tp.getDimSize(0) > kMaxRows)
36 return op->emitOpError(message: "bad row height: ") << tp.getDimSize(0);
37 if (col > kBitsPerRow || col & 0x1f)
38 return op->emitOpError(message: "bad column width: ") << (col >> 3);
39 return success();
40}
41
42/// Verify that AMX supports the multiplication.
43static LogicalResult verifyMultShape(Operation *op, VectorType atp,
44 VectorType btp, VectorType ctp,
45 unsigned scale) {
46 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
47 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
48 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
49 if (cm != am || cn != bn || ak != bk)
50 return op->emitOpError(message: "bad mult shape: ")
51 << cm << " x " << cn << " x " << ak;
52 return success();
53}
54
55LogicalResult amx::TileZeroOp::verify() {
56 return verifyTileSize(*this, getVectorType());
57}
58
59LogicalResult amx::TileLoadOp::verify() {
60 unsigned rank = getMemRefType().getRank();
61 if (getIndices().size() != rank)
62 return emitOpError("requires ") << rank << " indices";
63 return verifyTileSize(*this, getVectorType());
64}
65
66LogicalResult amx::TileStoreOp::verify() {
67 unsigned rank = getMemRefType().getRank();
68 if (getIndices().size() != rank)
69 return emitOpError("requires ") << rank << " indices";
70 return verifyTileSize(*this, getVectorType());
71}
72
73LogicalResult amx::TileMulFOp::verify() {
74 VectorType aType = getLhsVectorType();
75 VectorType bType = getRhsVectorType();
76 VectorType cType = getVectorType();
77 if (failed(verifyTileSize(*this, aType)) ||
78 failed(verifyTileSize(*this, bType)) ||
79 failed(verifyTileSize(*this, cType)) ||
80 failed(verifyMultShape(*this, aType, bType, cType, 1)))
81 return failure();
82 Type ta = aType.getElementType();
83 Type tb = bType.getElementType();
84 Type tc = cType.getElementType();
85 if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
86 return emitOpError("unsupported type combination");
87 return success();
88}
89
90LogicalResult amx::TileMulIOp::verify() {
91 VectorType aType = getLhsVectorType();
92 VectorType bType = getRhsVectorType();
93 VectorType cType = getVectorType();
94 if (failed(verifyTileSize(*this, aType)) ||
95 failed(verifyTileSize(*this, bType)) ||
96 failed(verifyTileSize(*this, cType)) ||
97 failed(verifyMultShape(*this, aType, bType, cType, 2)))
98 return failure();
99 Type ta = aType.getElementType();
100 Type tb = bType.getElementType();
101 Type tc = cType.getElementType();
102 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
103 return emitOpError("unsupported type combination");
104 return success();
105}
106
107#define GET_OP_CLASSES
108#include "mlir/Dialect/AMX/AMX.cpp.inc"
109

source code of mlir/lib/Dialect/AMX/IR/AMXDialect.cpp