1//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/AMX/Transforms.h"
10
11#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12#include "mlir/Conversion/LLVMCommon/Pattern.h"
13#include "mlir/Dialect/AMX/AMXDialect.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/IR/BuiltinOps.h"
16#include "mlir/IR/PatternMatch.h"
17
18using namespace mlir;
19using namespace mlir::amx;
20
21namespace {
22
23/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
24/// dimension directly translates into the number of rows of the tiles.
25/// The second dimensions needs to be scaled by the number of bytes.
26std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
27 const LLVMTypeConverter &typeConverter,
28 VectorType vType, Location loc) {
29 Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
30 unsigned width = vType.getElementType().getIntOrFloatBitWidth();
31 assert(llvm::isPowerOf2_64(width) && width >= 8);
32 unsigned bytes = width >> 3;
33 auto mattr = rewriter.getI16IntegerAttr(value: vType.getDimSize(0));
34 auto nattr = rewriter.getI16IntegerAttr(value: vType.getDimSize(1) * bytes);
35 return std::make_pair(
36 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
37 rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
38}
39
40/// Verifies if the stride matches proper tile access.
41LogicalResult verifyStride(MemRefType mType) {
42 if (mType.getRank() < 2)
43 return failure();
44 int64_t last = mType.getRank() - 1;
45 int64_t offset;
46 SmallVector<int64_t, 4> strides;
47 if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1)
48 return failure();
49 return success();
50}
51
52/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
53/// shape may "envelop" the actual tile shape, and may be dynamically sized.
54Value getStride(ConversionPatternRewriter &rewriter,
55 const LLVMTypeConverter &typeConverter, MemRefType mType,
56 Value base, Location loc) {
57 assert(mType.getRank() >= 2);
58 int64_t last = mType.getRank() - 1;
59 Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64);
60 unsigned width = mType.getElementType().getIntOrFloatBitWidth();
61 assert(llvm::isPowerOf2_64(width) && width >= 8);
62 unsigned bytes = width >> 3;
63 if (mType.isDynamicDim(last)) {
64 // Dynamic size needs code to compute the stride at runtime.
65 MemRefDescriptor memrefDescriptor(base);
66 auto attr = rewriter.getI64IntegerAttr(bytes);
67 Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68 return rewriter.create<LLVM::MulOp>(
69 loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last));
70 }
71 // Use direct constant for static size.
72 auto attr = rewriter.getI64IntegerAttr(value: mType.getDimSize(last) * bytes);
73 return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
74}
75
76struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
77 using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern;
78 LogicalResult
79 matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
80 ConversionPatternRewriter &rewriter) const override {
81 VectorType vType = op.getVectorType();
82 // Determine m x n tile sizes.
83 std::pair<Value, Value> tsz =
84 getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
85 // Replace operation with intrinsic.
86 Type resType = typeConverter->convertType(vType);
87 rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
88 tsz.second);
89 return success();
90 }
91};
92
93struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
94 using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern;
95
96 LogicalResult
97 matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 MemRefType mType = op.getMemRefType();
100 VectorType vType = op.getVectorType();
101 // Determine m x n tile sizes.
102 std::pair<Value, Value> tsz =
103 getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
104 // Determine stride.
105 if (failed(verifyStride(mType)))
106 return failure();
107 Value stride = getStride(rewriter, *getTypeConverter(), mType,
108 adaptor.getBase(), op.getLoc());
109 // Replace operation with intrinsic.
110 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
111 adaptor.getIndices(), rewriter);
112 Type resType = typeConverter->convertType(vType);
113 rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>(
114 op, resType, tsz.first, tsz.second, ptr, stride);
115 return success();
116 }
117};
118
119struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
120 using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern;
121
122 LogicalResult
123 matchAndRewrite(TileStoreOp op, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter) const override {
125 MemRefType mType = op.getMemRefType();
126 VectorType vType = op.getVectorType();
127 // Determine m x n tile sizes.
128 std::pair<Value, Value> tsz =
129 getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
130 // Determine stride.
131 if (failed(verifyStride(mType)))
132 return failure();
133 Value stride = getStride(rewriter, *getTypeConverter(), mType,
134 adaptor.getBase(), op.getLoc());
135 // Replace operation with intrinsic.
136 Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(),
137 adaptor.getIndices(), rewriter);
138 rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>(
139 op, tsz.first, tsz.second, ptr, stride, adaptor.getVal());
140 return success();
141 }
142};
143
144struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> {
145 using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern;
146 LogicalResult
147 matchAndRewrite(TileMulFOp op, OpAdaptor adaptor,
148 ConversionPatternRewriter &rewriter) const override {
149 VectorType aType = op.getLhsVectorType();
150 VectorType bType = op.getRhsVectorType();
151 VectorType cType = op.getVectorType();
152 // Determine m x n x k tile sizes.
153 std::pair<Value, Value> tsza =
154 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
155 std::pair<Value, Value> tszb =
156 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
157 // Replace operation with intrinsic.
158 Type resType = typeConverter->convertType(cType);
159 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>(
160 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
161 adaptor.getLhs(), adaptor.getRhs());
162 return success();
163 }
164};
165
166struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
167 using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern;
168 LogicalResult
169 matchAndRewrite(TileMulIOp op, OpAdaptor adaptor,
170 ConversionPatternRewriter &rewriter) const override {
171 VectorType aType = op.getLhsVectorType();
172 VectorType bType = op.getRhsVectorType();
173 VectorType cType = op.getVectorType();
174 // Determine m x n x k tile sizes.
175 std::pair<Value, Value> tsza =
176 getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc());
177 std::pair<Value, Value> tszb =
178 getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc());
179 // Replace operation with intrinsic.
180 Type resType = typeConverter->convertType(cType);
181 bool zexta = op.getIsZextLhs();
182 bool zextb = op.getIsZextRhs();
183 if (zexta && zextb)
184 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>(
185 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
186 adaptor.getLhs(), adaptor.getRhs());
187 else if (zexta && !zextb)
188 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>(
189 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
190 adaptor.getLhs(), adaptor.getRhs());
191 else if (!zexta && zextb)
192 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>(
193 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
194 adaptor.getLhs(), adaptor.getRhs());
195 else
196 rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>(
197 op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(),
198 adaptor.getLhs(), adaptor.getRhs());
199 return success();
200 }
201};
202
203} // namespace
204
205void mlir::populateAMXLegalizeForLLVMExportPatterns(
206 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
207 patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
208 TileMulFConversion, TileMulIConversion>(arg&: converter);
209}
210
211void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
212 target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64,
213 x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud,
214 x86_amx_tdpbusd, x86_amx_tdpbuud>();
215 target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
216 TileMulFOp>();
217}
218

// Source: mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp