1//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the AMX dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/AMX/AMXDialect.h"
14#include "mlir/Conversion/LLVMCommon/Pattern.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/DialectImplementation.h"
19#include "mlir/IR/OpImplementation.h"
20#include "mlir/IR/TypeUtilities.h"
21
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace mlir;
25
26#include "mlir/Dialect/AMX/AMXInterfaces.cpp.inc"
27
28#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
29
30void amx::AMXDialect::initialize() {
31 addTypes<
32#define GET_TYPEDEF_LIST
33#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
34 >();
35
36 addOperations<
37#define GET_OP_LIST
38#include "mlir/Dialect/AMX/AMX.cpp.inc"
39 >();
40}
41
42/// Verify that AMX supports the implied tile shape.
43static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
44 const unsigned kMaxRows = 16;
45 const unsigned kBitsPerRow = 64 * 8;
46 unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
47 if (tp.getDimSize(0) > kMaxRows)
48 return op->emitOpError(message: "bad row height: ") << tp.getDimSize(0);
49 if (col > kBitsPerRow || col & 0x1f)
50 return op->emitOpError(message: "bad column width: ") << (col >> 3);
51 return success();
52}
53
54/// Verify that AMX supports the multiplication.
55static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
56 amx::TileType btp, amx::TileType ctp,
57 unsigned scale) {
58 unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
59 unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
60 unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
61 if (cm != am || cn != bn || ak != bk)
62 return op->emitOpError(message: "bad mult shape: ")
63 << cm << " x " << cn << " x " << ak;
64 return success();
65}
66
67/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
68/// dimension directly translates into the number of rows of the tiles.
69/// The second dimensions needs to be scaled by the number of bytes.
70static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
71 RewriterBase &rewriter) {
72 Type llvmInt16Type = rewriter.getIntegerType(16);
73 unsigned width = tType.getElementType().getIntOrFloatBitWidth();
74 assert(llvm::isPowerOf2_64(width) && width >= 8);
75 unsigned bytes = width >> 3;
76 auto mattr = rewriter.getI16IntegerAttr(value: tType.getDimSize(0));
77 auto nattr = rewriter.getI16IntegerAttr(value: tType.getDimSize(1) * bytes);
78 return SmallVector<Value>{
79 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
80 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
81}
82
83/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
84/// shape may "envelop" the actual tile shape, and may be dynamically sized.
85static Value getStride(Location loc, MemRefType mType, Value base,
86 RewriterBase &rewriter) {
87 assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
88 int64_t preLast = mType.getRank() - 2;
89 Type llvmInt64Type = rewriter.getIntegerType(64);
90 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
91 assert(llvm::isPowerOf2_64(width) && width >= 8);
92 unsigned bytes = width >> 3;
93 auto [strides, offset] = mType.getStridesAndOffset();
94 if (strides[preLast] == ShapedType::kDynamic) {
95 // Dynamic stride needs code to compute the stride at runtime.
96 MemRefDescriptor memrefDescriptor(base);
97 auto attr = rewriter.getI64IntegerAttr(bytes);
98 Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
99 return rewriter
100 .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
101 memrefDescriptor.stride(rewriter, loc, preLast))
102 .getResult();
103 }
104 // Use direct constant for static stride.
105 auto attr = rewriter.getI64IntegerAttr(value: strides[preLast] * bytes);
106 return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
107 .getResult();
108}
109
110LogicalResult amx::TileZeroOp::verify() {
111 return verifyTileSize(*this, getTileType());
112}
113
114SmallVector<Value>
115amx::TileZeroOp::getIntrinsicOperands(ArrayRef<Value> operands,
116 const LLVMTypeConverter &typeConverter,
117 RewriterBase &rewriter) {
118 return getTileSizes(getLoc(), getTileType(), rewriter);
119}
120
121LogicalResult amx::TileLoadOp::verify() {
122 MemRefType memrefTy = getMemRefType();
123 unsigned rank = memrefTy.getRank();
124 if (rank < 2)
125 return emitOpError("requires at least 2D memref");
126 if (getIndices().size() != rank)
127 return emitOpError("requires ") << rank << " indices";
128 SmallVector<int64_t> strides;
129 int64_t offset;
130 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
131 strides.back() != 1)
132 return emitOpError("requires memref with unit innermost stride");
133 return verifyTileSize(*this, getTileType());
134}
135
136SmallVector<Value>
137amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
138 const LLVMTypeConverter &typeConverter,
139 RewriterBase &rewriter) {
140 auto loc = getLoc();
141 Adaptor adaptor(operands, *this);
142
143 SmallVector<Value> intrinsicOperands;
144 intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
145 intrinsicOperands.push_back(
146 LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
147 adaptor.getBase(), adaptor.getIndices()));
148 intrinsicOperands.push_back(
149 getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
150
151 return intrinsicOperands;
152}
153
154LogicalResult amx::TileStoreOp::verify() {
155 MemRefType memrefTy = getMemRefType();
156 unsigned rank = memrefTy.getRank();
157 if (rank < 2)
158 return emitOpError("requires at least 2D memref");
159 if (getIndices().size() != rank)
160 return emitOpError("requires ") << rank << " indices";
161 SmallVector<int64_t> strides;
162 int64_t offset;
163 if (failed(memrefTy.getStridesAndOffset(strides, offset)) ||
164 strides.back() != 1)
165 return emitOpError("requires memref with unit innermost stride");
166 return verifyTileSize(*this, getTileType());
167}
168
169SmallVector<Value>
170amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
171 const LLVMTypeConverter &typeConverter,
172 RewriterBase &rewriter) {
173 auto loc = getLoc();
174 Adaptor adaptor(operands, *this);
175
176 SmallVector<Value> intrinsicOperands;
177 intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
178 intrinsicOperands.push_back(
179 LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
180 adaptor.getBase(), adaptor.getIndices()));
181 intrinsicOperands.push_back(
182 getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
183 intrinsicOperands.push_back(adaptor.getVal());
184
185 return intrinsicOperands;
186}
187
188LogicalResult amx::TileMulFOp::verify() {
189 amx::TileType aType = getLhsTileType();
190 amx::TileType bType = getRhsTileType();
191 amx::TileType cType = getTileType();
192 if (failed(verifyTileSize(*this, aType)) ||
193 failed(verifyTileSize(*this, bType)) ||
194 failed(verifyTileSize(*this, cType)) ||
195 failed(verifyMultShape(*this, aType, bType, cType, 1)))
196 return failure();
197 Type ta = aType.getElementType();
198 Type tb = bType.getElementType();
199 Type tc = cType.getElementType();
200 if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
201 return emitOpError("unsupported type combination");
202 return success();
203}
204
205SmallVector<Value>
206amx::TileMulFOp::getIntrinsicOperands(ArrayRef<Value> operands,
207 const LLVMTypeConverter &typeConverter,
208 RewriterBase &rewriter) {
209 auto loc = getLoc();
210 Adaptor adaptor(operands, *this);
211
212 amx::TileType aType = getLhsTileType();
213 amx::TileType bType = getRhsTileType();
214 SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
215 SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
216
217 SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
218 tsza[1], adaptor.getAcc(),
219 adaptor.getLhs(), adaptor.getRhs()};
220
221 return intrinsicOperands;
222}
223
224LogicalResult amx::TileMulIOp::verify() {
225 amx::TileType aType = getLhsTileType();
226 amx::TileType bType = getRhsTileType();
227 amx::TileType cType = getTileType();
228 if (failed(verifyTileSize(*this, aType)) ||
229 failed(verifyTileSize(*this, bType)) ||
230 failed(verifyTileSize(*this, cType)) ||
231 failed(verifyMultShape(*this, aType, bType, cType, 2)))
232 return failure();
233 Type ta = aType.getElementType();
234 Type tb = bType.getElementType();
235 Type tc = cType.getElementType();
236 if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
237 return emitOpError("unsupported type combination");
238 return success();
239}
240
241SmallVector<Value>
242amx::TileMulIOp::getIntrinsicOperands(ArrayRef<Value> operands,
243 const LLVMTypeConverter &typeConverter,
244 RewriterBase &rewriter) {
245 auto loc = getLoc();
246 Adaptor adaptor(operands, *this);
247
248 amx::TileType aType = getLhsTileType();
249 amx::TileType bType = getRhsTileType();
250 SmallVector<Value> tsza = getTileSizes(loc, aType, rewriter);
251 SmallVector<Value> tszb = getTileSizes(loc, bType, rewriter);
252
253 SmallVector<Value> intrinsicOperands = {tsza[0], tszb[1],
254 tsza[1], adaptor.getAcc(),
255 adaptor.getLhs(), adaptor.getRhs()};
256
257 return intrinsicOperands;
258}
259
260Type amx::TileType::parse(AsmParser &parser) {
261 if (parser.parseLess())
262 return nullptr;
263
264 SmallVector<int64_t, 2> shape;
265 if (parser.parseDimensionList(shape, false, true))
266 return nullptr;
267
268 Type elementType;
269 if (parser.parseType(elementType))
270 return nullptr;
271
272 if (parser.parseGreater())
273 return nullptr;
274
275 return TileType::get(shape, elementType);
276}
277
278void amx::TileType::print(AsmPrinter &os) const {
279 os << "<";
280 os.printDimensionList(getShape());
281 os << 'x';
282 os.printType(getElementType());
283 os << '>';
284}
285
286#define GET_OP_CLASSES
287#include "mlir/Dialect/AMX/AMX.cpp.inc"
288
289#define GET_TYPEDEF_CLASSES
290#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
291

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/AMX/IR/AMXDialect.cpp