1 | //===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/AMX/Transforms.h" |
10 | |
11 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
12 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
13 | #include "mlir/Dialect/AMX/AMXDialect.h" |
14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
15 | #include "mlir/IR/BuiltinOps.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::amx; |
20 | |
21 | namespace { |
22 | |
23 | /// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first |
24 | /// dimension directly translates into the number of rows of the tiles. |
25 | /// The second dimensions needs to be scaled by the number of bytes. |
26 | std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter, |
27 | const LLVMTypeConverter &typeConverter, |
28 | VectorType vType, Location loc) { |
29 | Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); |
30 | unsigned width = vType.getElementType().getIntOrFloatBitWidth(); |
31 | assert(llvm::isPowerOf2_64(width) && width >= 8); |
32 | unsigned bytes = width >> 3; |
33 | auto mattr = rewriter.getI16IntegerAttr(value: vType.getDimSize(0)); |
34 | auto nattr = rewriter.getI16IntegerAttr(value: vType.getDimSize(1) * bytes); |
35 | return std::make_pair( |
36 | rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr), |
37 | rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)); |
38 | } |
39 | |
40 | /// Verifies if the stride matches proper tile access. |
41 | LogicalResult verifyStride(MemRefType mType) { |
42 | if (mType.getRank() < 2) |
43 | return failure(); |
44 | int64_t last = mType.getRank() - 1; |
45 | int64_t offset; |
46 | SmallVector<int64_t, 4> strides; |
47 | if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1) |
48 | return failure(); |
49 | return success(); |
50 | } |
51 | |
52 | /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer |
53 | /// shape may "envelop" the actual tile shape, and may be dynamically sized. |
54 | Value getStride(ConversionPatternRewriter &rewriter, |
55 | const LLVMTypeConverter &typeConverter, MemRefType mType, |
56 | Value base, Location loc) { |
57 | assert(mType.getRank() >= 2); |
58 | int64_t last = mType.getRank() - 1; |
59 | Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); |
60 | unsigned width = mType.getElementType().getIntOrFloatBitWidth(); |
61 | assert(llvm::isPowerOf2_64(width) && width >= 8); |
62 | unsigned bytes = width >> 3; |
63 | if (mType.isDynamicDim(last)) { |
64 | // Dynamic size needs code to compute the stride at runtime. |
65 | MemRefDescriptor memrefDescriptor(base); |
66 | auto attr = rewriter.getI64IntegerAttr(bytes); |
67 | Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); |
68 | return rewriter.create<LLVM::MulOp>( |
69 | loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last)); |
70 | } |
71 | // Use direct constant for static size. |
72 | auto attr = rewriter.getI64IntegerAttr(value: mType.getDimSize(last) * bytes); |
73 | return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr); |
74 | } |
75 | |
76 | struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> { |
77 | using ConvertOpToLLVMPattern<TileZeroOp>::ConvertOpToLLVMPattern; |
78 | LogicalResult |
79 | matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, |
80 | ConversionPatternRewriter &rewriter) const override { |
81 | VectorType vType = op.getVectorType(); |
82 | // Determine m x n tile sizes. |
83 | std::pair<Value, Value> tsz = |
84 | getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); |
85 | // Replace operation with intrinsic. |
86 | Type resType = typeConverter->convertType(vType); |
87 | rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first, |
88 | tsz.second); |
89 | return success(); |
90 | } |
91 | }; |
92 | |
93 | struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> { |
94 | using ConvertOpToLLVMPattern<TileLoadOp>::ConvertOpToLLVMPattern; |
95 | |
96 | LogicalResult |
97 | matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, |
98 | ConversionPatternRewriter &rewriter) const override { |
99 | MemRefType mType = op.getMemRefType(); |
100 | VectorType vType = op.getVectorType(); |
101 | // Determine m x n tile sizes. |
102 | std::pair<Value, Value> tsz = |
103 | getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); |
104 | // Determine stride. |
105 | if (failed(verifyStride(mType))) |
106 | return failure(); |
107 | Value stride = getStride(rewriter, *getTypeConverter(), mType, |
108 | adaptor.getBase(), op.getLoc()); |
109 | // Replace operation with intrinsic. |
110 | Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), |
111 | adaptor.getIndices(), rewriter); |
112 | Type resType = typeConverter->convertType(vType); |
113 | rewriter.replaceOpWithNewOp<amx::x86_amx_tileloadd64>( |
114 | op, resType, tsz.first, tsz.second, ptr, stride); |
115 | return success(); |
116 | } |
117 | }; |
118 | |
119 | struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> { |
120 | using ConvertOpToLLVMPattern<TileStoreOp>::ConvertOpToLLVMPattern; |
121 | |
122 | LogicalResult |
123 | matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, |
124 | ConversionPatternRewriter &rewriter) const override { |
125 | MemRefType mType = op.getMemRefType(); |
126 | VectorType vType = op.getVectorType(); |
127 | // Determine m x n tile sizes. |
128 | std::pair<Value, Value> tsz = |
129 | getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); |
130 | // Determine stride. |
131 | if (failed(verifyStride(mType))) |
132 | return failure(); |
133 | Value stride = getStride(rewriter, *getTypeConverter(), mType, |
134 | adaptor.getBase(), op.getLoc()); |
135 | // Replace operation with intrinsic. |
136 | Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.getBase(), |
137 | adaptor.getIndices(), rewriter); |
138 | rewriter.replaceOpWithNewOp<amx::x86_amx_tilestored64>( |
139 | op, tsz.first, tsz.second, ptr, stride, adaptor.getVal()); |
140 | return success(); |
141 | } |
142 | }; |
143 | |
144 | struct TileMulFConversion : public ConvertOpToLLVMPattern<TileMulFOp> { |
145 | using ConvertOpToLLVMPattern<TileMulFOp>::ConvertOpToLLVMPattern; |
146 | LogicalResult |
147 | matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, |
148 | ConversionPatternRewriter &rewriter) const override { |
149 | VectorType aType = op.getLhsVectorType(); |
150 | VectorType bType = op.getRhsVectorType(); |
151 | VectorType cType = op.getVectorType(); |
152 | // Determine m x n x k tile sizes. |
153 | std::pair<Value, Value> tsza = |
154 | getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); |
155 | std::pair<Value, Value> tszb = |
156 | getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); |
157 | // Replace operation with intrinsic. |
158 | Type resType = typeConverter->convertType(cType); |
159 | rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbf16ps>( |
160 | op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), |
161 | adaptor.getLhs(), adaptor.getRhs()); |
162 | return success(); |
163 | } |
164 | }; |
165 | |
166 | struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> { |
167 | using ConvertOpToLLVMPattern<TileMulIOp>::ConvertOpToLLVMPattern; |
168 | LogicalResult |
169 | matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, |
170 | ConversionPatternRewriter &rewriter) const override { |
171 | VectorType aType = op.getLhsVectorType(); |
172 | VectorType bType = op.getRhsVectorType(); |
173 | VectorType cType = op.getVectorType(); |
174 | // Determine m x n x k tile sizes. |
175 | std::pair<Value, Value> tsza = |
176 | getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); |
177 | std::pair<Value, Value> tszb = |
178 | getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); |
179 | // Replace operation with intrinsic. |
180 | Type resType = typeConverter->convertType(cType); |
181 | bool zexta = op.getIsZextLhs(); |
182 | bool zextb = op.getIsZextRhs(); |
183 | if (zexta && zextb) |
184 | rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbuud>( |
185 | op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), |
186 | adaptor.getLhs(), adaptor.getRhs()); |
187 | else if (zexta && !zextb) |
188 | rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbusd>( |
189 | op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), |
190 | adaptor.getLhs(), adaptor.getRhs()); |
191 | else if (!zexta && zextb) |
192 | rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbsud>( |
193 | op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), |
194 | adaptor.getLhs(), adaptor.getRhs()); |
195 | else |
196 | rewriter.replaceOpWithNewOp<amx::x86_amx_tdpbssd>( |
197 | op, resType, tsza.first, tszb.second, tsza.second, adaptor.getAcc(), |
198 | adaptor.getLhs(), adaptor.getRhs()); |
199 | return success(); |
200 | } |
201 | }; |
202 | |
203 | } // namespace |
204 | |
205 | void mlir::populateAMXLegalizeForLLVMExportPatterns( |
206 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
207 | patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion, |
208 | TileMulFConversion, TileMulIConversion>(arg&: converter); |
209 | } |
210 | |
211 | void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { |
212 | target.addLegalOp<x86_amx_tilezero, x86_amx_tileloadd64, x86_amx_tilestored64, |
213 | x86_amx_tdpbf16ps, x86_amx_tdpbssd, x86_amx_tdpbsud, |
214 | x86_amx_tdpbusd, x86_amx_tdpbuud>(); |
215 | target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp, |
216 | TileMulFOp>(); |
217 | } |
218 | |