//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the XeGPU dialect.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
14#include "mlir/Dialect/SCF/Transforms/Patterns.h"
15#include "mlir/Dialect/Utils/IndexingUtils.h"
16#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
17#include "mlir/IR/Builders.h"
18#include "mlir/IR/Operation.h"
19#include "mlir/IR/ValueRange.h"
20#include "mlir/Interfaces/LoopLikeInterface.h"
21#include "mlir/Transforms/DialectConversion.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/FormatVariadic.h"
24#include <cstdint>
25#include <numeric>
26
27using namespace mlir;
28
29/// convert ArrayRef<ValueRange> into SmallVector<Value>
30SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
31 SmallVector<Value> result;
32 for (const auto &vals : values)
33 llvm::append_range(C&: result, R: vals);
34 return result;
35}
36
37FailureOr<VectorType>
38mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
39 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(Val: tdescTy.getLayout());
40 // It only works for subgroup level layout, which only has lane_layout
41 // and lane_data, and is to distribute a SIMD code into SIMT code.
42 if (!layout || !layout.isSgLayout())
43 return failure();
44
45 SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
46 SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
47 auto tdescShape = tdescTy.getShape();
48 auto elementType = tdescTy.getElementType();
49
50 // compute sgSize by multiply elements of laneLayout
51 // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
52 // e.g. for 1D layout, sgSize = laneLayout[0]
53 auto sgSize = std::accumulate(first: laneLayout.begin(), last: laneLayout.end(), init: 1,
54 binary_op: std::multiplies<int64_t>());
55
56 // Case 1: regular loads/stores
57 auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
58 if (scatterAttr) {
59 auto chunkSize = scatterAttr.getChunkSize().getInt();
60 // Verify if the first dimension of the tensor descriptor shape is
61 // distributable.
62 assert(tdescShape[0] == laneLayout[0] &&
63 "tensor descriptor shape is not distributable");
64 return VectorType::get(shape: {chunkSize}, elementType);
65 }
66
67 // Case 2: block loads/stores
68 // Check if the tensor descriptor shape is distributable.
69 int64_t tensorSize = 1;
70 for (auto [tdescDim, laneDim, laneDataDim] :
71 llvm::zip_equal(t&: tdescShape, u&: laneLayout, args&: laneData)) {
72 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
73 "tensor descriptor shape is not distributable");
74 tensorSize *= tdescDim;
75 }
76 // tensorSize must be adjusted for array_length.
77 tensorSize *= tdescTy.getArrayLength();
78
79 return VectorType::get(shape: {tensorSize / sgSize}, elementType);
80}
81
82FailureOr<VectorType>
83mlir::xegpu::getDistributedVectorType(VectorType originalType,
84 xegpu::LayoutAttr layout) {
85 int64_t rank = originalType.getRank();
86 // Distributed vector type is only supported for 1D, 2D and 3D vectors.
87 if (rank < 1 || rank > 3)
88 return failure();
89 ArrayRef<int64_t> shape = originalType.getShape();
90 // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
91 // of the 3D vector.
92 int arrayLength = 1;
93 if (rank == 3) {
94 arrayLength = shape[0];
95 shape = shape.drop_front();
96 }
97 auto helperTdescTy = xegpu::TensorDescType::get(
98 shape, elementType: originalType.getElementType(), array_length: arrayLength,
99 /*boundary_check=*/true,
100 /*memory_space=*/xegpu::MemorySpace::Global, layout);
101 return xegpu::getDistributedVectorType(tdescTy: helperTdescTy);
102}
103
104std::string xegpu::getLayoutName(const OpOperand &operand) {
105 const StringRef prefix("layout_operand_");
106 unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
107 return llvm::formatv(Fmt: "{0}{1}", Vals: prefix, Vals&: idx).str();
108}
109
110std::string xegpu::getLayoutName(const OpResult result) {
111 const StringRef prefix = "layout_result_";
112 return llvm::formatv(Fmt: "{0}{1}", Vals: prefix, Vals: result.getResultNumber()).str();
113}
114
115xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
116 if (!value)
117 return nullptr;
118
119 if (auto tdescTy =
120 dyn_cast_if_present<xegpu::TensorDescType>(Val: value.getType()))
121 return tdescTy.getLayoutAttr();
122
123 if (auto result = dyn_cast<OpResult>(Val: value)) {
124 Operation *defOp = result.getDefiningOp();
125 assert(defOp && "result must have a defining op");
126
127 // for LoadNdOp, the layout is stored in the tensor descriptor
128 if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(Val: defOp))
129 return getLayoutAttr(value: loadNd.getTensorDesc());
130
131 std::string layoutName = getLayoutName(result);
132 if (defOp->hasAttr(name: layoutName))
133 return defOp->getAttrOfType<xegpu::LayoutAttr>(name: layoutName);
134 }
135
136 if (auto arg = dyn_cast<BlockArgument>(Val: value)) {
137 auto parentOp = arg.getOwner()->getParentOp();
138 if (auto loop = dyn_cast<LoopLikeOpInterface>(Val: parentOp)) {
139 OpOperand *tiedInit = loop.getTiedLoopInit(bbArg: arg);
140 return getLayoutAttr(value: tiedInit->get());
141 }
142 }
143
144 return nullptr;
145}
146
147xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
148 Operation *op = opr.getOwner();
149 std::string layoutName = xegpu::getLayoutName(operand: opr);
150 if (op->hasAttr(name: layoutName))
151 return op->getAttrOfType<xegpu::LayoutAttr>(name: layoutName);
152 return getLayoutAttr(value: opr.get());
153}
154
155template <typename T, typename>
156void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
157 Operation *owner = operandOrResult.getOwner();
158 std::string name = xegpu::getLayoutName(operandOrResult);
159 if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
160 owner->setAttr(name, value: layout);
161}
162
163// Explicit instantiation for OpResult
164template void
165xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
166 const mlir::xegpu::LayoutAttr layout);
167
168// Explicit instantiation for OpOperand
169template void
170xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
171 const mlir::xegpu::LayoutAttr layout);
172
173void xegpu::setLayoutAttrs(Operation *op,
174 function_ref<LayoutAttr(Value)> getLayoutImpl) {
175 op->walk(callback: [&](Operation *nestOp) {
176 for (OpOperand &opr : nestOp->getOpOperands()) {
177 auto layout = getLayoutImpl(opr.get());
178 setLayoutAttr(operandOrResult: opr, layout);
179 }
180 for (OpResult result : nestOp->getOpResults()) {
181 auto layout = getLayoutImpl(result);
182 setLayoutAttr(operandOrResult: result, layout);
183 }
184 });
185}
186
187SmallVector<Value>
188xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
189 Value value, ArrayRef<int64_t> shape) {
190 auto vecTy = dyn_cast<VectorType>(Val: value.getType());
191 if (!vecTy)
192 return {value};
193
194 ArrayRef<int64_t> srcShape = vecTy.getShape();
195 if (!computeShapeRatio(shape: srcShape, subShape: shape))
196 return {value};
197
198 SmallVector<Value> result;
199 for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
200 SmallVector<int64_t> staticStrides(offsets.size(), 1);
201 result.push_back(Elt: builder.create<vector::ExtractStridedSliceOp>(
202 location: loc, args&: value, args&: offsets, args&: shape, args&: staticStrides));
203 }
204
205 return result;
206}
207
208Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
209 ValueRange values,
210 ArrayRef<int64_t> shape) {
211 VectorType inputTy = dyn_cast<VectorType>(Val: values[0].getType());
212 assert(llvm::all_of(values.getTypes(),
213 [&](Type type) { return type == inputTy; }) &&
214 "values must be of the same VectorType");
215
216 Type elemTy = inputTy.getElementType();
217 ArrayRef<int64_t> tileShape = inputTy.getShape();
218
219 VectorType resultTy = VectorType::get(shape, elementType: elemTy);
220 auto zeroAttr = builder.getZeroAttr(type: elemTy);
221 Value result = builder.create<arith::ConstantOp>(
222 location: loc, args&: resultTy, args: DenseElementsAttr::get(type: resultTy, values: zeroAttr));
223
224 for (auto [src, offsets] :
225 llvm::zip_equal(t&: values, u: StaticTileOffsetRange(shape, tileShape))) {
226 SmallVector<int64_t> staticStrides(offsets.size(), 1);
227 result = builder.create<vector::InsertStridedSliceOp>(
228 location: loc, args&: src, args&: result, args&: offsets, args&: staticStrides);
229 }
230 return result;
231}
232
233void xegpu::doSCFStructuralTypeConversionWithTensorType(
234 Operation *op, TypeConverter converter) {
235 MLIRContext *context = op->getContext();
236
237 auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
238 Location loc) -> Value {
239 return builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs)
240 .getResult(i: 0);
241 };
242
243 { // convert VectorType to RankedTensorType for SCF Structural ops
244 TypeConverter converter;
245 converter.addConversion(callback: [](Type type) -> Type { return type; });
246 converter.addConversion(callback: [](VectorType type) -> Type {
247 return RankedTensorType::get(shape: type.getShape(), elementType: type.getElementType());
248 });
249 converter.addSourceMaterialization(callback&: materializeCast);
250 converter.addTargetMaterialization(callback&: materializeCast);
251
252 mlir::ConversionTarget target(*context);
253 target.addLegalOp<UnrealizedConversionCastOp>();
254
255 mlir::RewritePatternSet patterns(context);
256 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns,
257 target);
258 (void)mlir::applyPartialConversion(op, target, patterns: std::move(patterns));
259 }
260
261 { // propagate the layout attribute to RankedTensorType by checking
262 // BuiltInUnrealizedCastOps
263 // for VectorType to RankedTensorType cast.
264 op->walk(callback: [](UnrealizedConversionCastOp castOp) {
265 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
266 return WalkResult::skip();
267
268 Value input = castOp.getInputs()[0];
269 Value result = castOp.getResults()[0];
270 auto inputTy = dyn_cast<VectorType>(Val: input.getType());
271 auto resultTy = dyn_cast<RankedTensorType>(Val: result.getType());
272
273 // Only look at ops casting from VectorType to RankedTensorType
274 if (!inputTy || !resultTy)
275 return WalkResult::skip();
276
277 xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value: input);
278 if (!layout)
279 return WalkResult::skip();
280
281 RankedTensorType newTy = resultTy.cloneWithEncoding(encoding: layout);
282 result.setType(newTy);
283
284 // update the arguments if user is a LoopLike op.
285 for (OpOperand &use : result.getUses()) {
286 if (auto loop = dyn_cast<LoopLikeOpInterface>(Val: use.getOwner())) {
287 BlockArgument arg = loop.getTiedLoopRegionIterArg(opOperand: &use);
288 arg.setType(newTy);
289 }
290 // whileOp has two regions, the BlockArgument of the after region
291 // is not exposed by LoopLikeOpInterface
292 if (auto whileOp = dyn_cast<scf::WhileOp>(Val: use.getOwner())) {
293 unsigned idx = use.getOperandNumber();
294 BlockArgument arg = whileOp.getAfterArguments()[idx];
295 arg.setType(newTy);
296 }
297 }
298 return WalkResult::advance();
299 });
300
301 // using yieldOp as anchor to update the result type of its ParentOp
302 op->walk(callback: [](scf::YieldOp yieldOp) {
303 Operation *parentOp = yieldOp->getParentOp();
304 for (OpResult r : parentOp->getOpResults()) {
305 unsigned idx = r.getResultNumber();
306 Type resultTy = r.getType();
307 Type yieldTy = yieldOp.getResults()[idx].getType();
308 if (isa<RankedTensorType>(Val: resultTy) && yieldTy != resultTy)
309 r.setType(yieldTy);
310 }
311 });
312 }
313
314 { // perform the conversion from RankedTensorType to VectorType based on the
315 // LayoutAttr
316
317 // Handle the UnrealizedConversionCastOp introduced by the first step.
318 // For vector->RankedTensorType, it will simply forward the inputs.
319 // For RankedTensorType->vector, it will update the inputs with the
320 // one from the adaptor.
321 class UnrealizedConversionCastOpPattern
322 : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
323 using OpConversionPattern<
324 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
325
326 mlir::LogicalResult
327 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
328 OneToNOpAdaptor adaptor,
329 ConversionPatternRewriter &rewriter) const override {
330 auto inputs = op.getOperands();
331 auto outputs = op.getOutputs();
332
333 if (inputs.size() != 1 || outputs.size() != 1)
334 return failure();
335
336 auto inputTy = inputs[0].getType();
337 auto outputTy = outputs[0].getType();
338
339 if (isa<VectorType>(Val: inputTy) && isa<RankedTensorType>(Val: outputTy)) {
340 rewriter.replaceOpWithMultiple(op, newValues: adaptor.getInputs());
341 return success();
342 }
343
344 if (isa<RankedTensorType>(Val: inputTy) && isa<VectorType>(Val: outputTy)) {
345 SmallVector<Value> values = xegpu::flattenValues(values: adaptor.getInputs());
346 auto newOp = rewriter.create<UnrealizedConversionCastOp>(
347 location: op.getLoc(), args&: outputTy, args&: values);
348 rewriter.replaceOp(op, newOp);
349 return success();
350 }
351 return failure();
352 }
353 };
354
355 converter.addSourceMaterialization(callback&: materializeCast);
356 converter.addTargetMaterialization(callback: [&](OpBuilder &builder, TypeRange type,
357 ValueRange inputs, Location loc) {
358 return builder.create<UnrealizedConversionCastOp>(location: loc, args&: type, args&: inputs)
359 .getResults();
360 });
361
362 mlir::ConversionTarget target(*context);
363 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
364 callback: [](UnrealizedConversionCastOp op) {
365 auto isTensorTy = [](Type type) {
366 return isa<RankedTensorType>(Val: type);
367 };
368 return llvm::none_of(Range: op->getOperandTypes(), P: isTensorTy) &&
369 llvm::none_of(Range: op->getResultTypes(), P: isTensorTy);
370 });
371 mlir::RewritePatternSet patterns(context);
372 patterns.insert<UnrealizedConversionCastOpPattern>(arg&: context);
373 scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter: converter, patterns,
374 target);
375 (void)mlir::applyPartialConversion(op, target, patterns: std::move(patterns));
376 }
377}
378

// Source: mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp