//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

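// Checks that the vector type has a rank supported by this lowering (1D or
// 2D); shared by the plain load/store and the transfer lowerings below.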
static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type; the basic vector store and load ops already
  // guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

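// Checks whether a vector transfer op can be lowered to XeGPU block accesses:
// no mask, a memref source that is contiguous in the innermost dimension, a
// projected-permutation map touching only the innermost dimensions, and
// out-of-bounds dimensions only for 2D (block) transfers.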
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

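// Creates an XeGPU nd tensor descriptor for the given source memref and
// offsets. For statically shaped sources the plain builder is used; for
// dynamically shaped ones the shape and stride values are materialized and
// passed explicitly.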
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(loc, descType, src,
                                                    getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

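// Lowers vector.transfer_read into xegpu.create_nd_tdesc + xegpu.load_nd.
// Out-of-bounds reads must use zero padding, and transposed reads are
// expressed through the load's transpose attribute.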
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

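// Lowers vector.transfer_write into xegpu.create_nd_tdesc + xegpu.store_nd.
// Only minor identity permutation maps are supported.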
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

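// Lowers vector.load into xegpu.create_nd_tdesc + xegpu.load_nd.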
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

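// Lowers vector.store into xegpu.create_nd_tdesc + xegpu.store_nd.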
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

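// Lowers a plain 2D row-major vector.contract with additive accumulation into
// xegpu.dpas. VNNI packing is left to a later lowering step.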
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    // TODO: Update shape validation to be target aware.
    auto accShape = accType.getShape();
    int64_t dimN = accShape[1];
    if (dimN != 8 && dimN != 16)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Invalid operand dimensions");

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

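// Pass wrapper that greedily applies the Vector to XeGPU lowering patterns.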
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}

