//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Returns true if the value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // already guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

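// Checks that a vector transfer op can be expressed as an XeGPU block
// load/store: the op must be unmasked, read from or write to a memref that is
// contiguous in the innermost dimension, and its permutation map may only
// access the innermost source dimensions.
//
// An illustrative example of an op accepted by these checks (shapes and value
// names are hypothetical):
//
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<?x?xf32>, vector<8x16xf32>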
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

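// Creates an xegpu.create_nd_tdesc op describing the accessed memref region.
// For statically shaped sources only the offsets are passed; for dynamic
// shapes, the source's dimensions and strides are materialized explicitly and
// passed alongside the offsets.
//
// A rough sketch of the expected output for a static source (the exact XeGPU
// assembly format may differ; shapes are illustrative):
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>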
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, the source's shape and strides have to
    // be provided explicitly.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // The last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

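// Lowers vector.transfer_read into an XeGPU block load.
//
// A minimal sketch of the rewrite, assuming an in-bounds, non-transposed read
// (the exact XeGPU assembly format may differ; shapes are illustrative):
//
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<8x16xf32>, vector<8x16xf32>
//   --->
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
//
// Transposed reads additionally swap the descriptor shape and attach a
// transpose attribute to the load.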
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

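// Lowers vector.transfer_write into an XeGPU block store.
//
// A minimal sketch of the rewrite, assuming an in-bounds write with an
// identity permutation map (the exact XeGPU assembly format may differ;
// shapes are illustrative):
//
//   vector.transfer_write %v, %dst[%i, %j] {in_bounds = [true, true]}
//       : vector<8x16xf32>, memref<8x16xf32>
//   --->
//   %desc = xegpu.create_nd_tdesc %dst[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %v, %desc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>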
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

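// Lowers vector.load into an XeGPU block load, e.g. (a rough, illustrative
// sketch; the exact XeGPU assembly format may differ):
//
//   %v = vector.load %src[%i, %j] : memref<8x16xf32>, vector<8x16xf32>
//   --->
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>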
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

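// Lowers vector.store into an XeGPU block store, e.g. (a rough, illustrative
// sketch; the exact XeGPU assembly format may differ):
//
//   vector.store %v, %dst[%i, %j] : memref<8x16xf32>, vector<8x16xf32>
//   --->
//   %desc = xegpu.create_nd_tdesc %dst[%i, %j]
//       : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//   xegpu.store_nd %v, %desc : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>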
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

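// Lowers a 2D row-major vector.contract with an additive accumulator into
// xegpu.dpas.
//
// A minimal sketch of the rewrite, assuming matmul indexing maps and iterator
// types on the contraction (shapes and element types are illustrative; the
// exact XeGPU assembly format may differ):
//
//   %d = vector.contract <matmul traits> %a, %b, %c
//       : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
//   --->
//   %d = xegpu.dpas %a, %b, %c
//       : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
//       -> vector<8x16xf32>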
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    // TODO: Update shape validation to be target aware.
    auto accShape = accType.getShape();
    int64_t dimN = accShape[1];
    if (dimN != 8 && dimN != 16)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Invalid operand dimensions");

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}