//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type, as the basic vector store and load ops
  // already guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

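// Common preconditions for lowering vector transfer reads and writes:
// unmasked transfers from a memref source that is contiguous in the
// innermost dimension, a projected permutation map touching only the
// innermost dimensions, and out-of-bounds handling restricted to 2D
// (block) accesses.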
static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Perform common data transfer checks.
  VectorType vecTy = xferOp.getVectorType();
  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
    return failure();

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

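// Creates an xegpu::CreateNdDescOp describing the accessed tile of `src`.
// For statically shaped sources the offsets are sufficient; for dynamically
// shaped sources the shape and strides are additionally materialized as SSA
// values (see the dynamic branch below).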
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(rewriter.create<memref::DimOp>(loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
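    // The dynamic strides are materialized as a running product of the
    // trailing dimension sizes; note that this product equals the actual
    // stride only for sources that are contiguous in row-major order.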
    SmallVector<Value> dynStrides;
    Value accStride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          rewriter.create<arith::MulIOp>(loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = rewriter.create<xegpu::CreateNdDescOp>(
        loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

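// Lowers vector.transfer_read into an XeGPU nd-descriptor creation followed
// by an xegpu.load_nd. Illustrative IR (shapes and attributes abbreviated,
// not verbatim output of this pattern):
//
//   %v = vector.transfer_read %src[%i, %j], %pad ...
//       : memref<?x?xf32>, vector<8x16xf32>
// becomes roughly:
//   %desc = xegpu.create_nd_tdesc %src[%i, %j] ...
//       : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v = xegpu.load_nd %desc ... -> vector<8x16xf32>
//
// Transposed reads are expressed through the load's transpose attribute and
// are restricted to element types of at least 32 bits.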
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    VectorType vecTy = readOp.getVectorType();
    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, transposeAttr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

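// Lowers vector.transfer_write into an XeGPU nd-descriptor creation followed
// by an xegpu.store_nd. Only minor-identity permutation maps are supported,
// so no transposition happens on the store path.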
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    VectorType vecTy = writeOp.getVectorType();
    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        rewriter.create<xegpu::StoreNdOp>(loc, writeOp.getVector(), ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

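// Lowers vector.load into xegpu.load_nd through an nd tensor descriptor.
// Boundary (out-of-bounds) checking is enabled only for 2D loads, matching
// XeGPU block instructions.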
struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

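// Lowers vector.store into xegpu.store_nd through an nd tensor descriptor,
// mirroring LoadLowering above.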
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

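// Lowers vector.contract into xegpu.dpas. Only a plain row-major matmul with
// 2D lhs/rhs operands, a 2D accumulator, and an additive combining kind is
// accepted; VNNI packing for DPAS is applied as a separate lowering step.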
struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = rewriter.create<xegpu::DpasOp>(
        loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}