//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the XeGPU dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>

using namespace mlir;

/// Convert an ArrayRef<ValueRange> into a SmallVector<Value>.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
  // This only works for subgroup-level layouts, which have only lane_layout
  // and lane_data, and are used to distribute SIMD code into SIMT code.
  if (!layout || !layout.isSgLayout())
    return failure();

  SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
  SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
  auto tdescShape = tdescTy.getShape();
  auto elementType = tdescTy.getElementType();

  // Compute sgSize by multiplying the elements of laneLayout,
  // e.g., for a 2D layout, sgSize = laneLayout[0] * laneLayout[1];
  // for a 1D layout, sgSize = laneLayout[0].
  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
                                std::multiplies<int64_t>());

  // Case 1: regular loads/stores.
  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
  if (scatterAttr) {
    auto chunkSize = scatterAttr.getChunkSize().getInt();
    // Verify that the first dimension of the tensor descriptor shape is
    // distributable.
    assert(tdescShape[0] == laneLayout[0] &&
           "tensor descriptor shape is not distributable");
    return VectorType::get({chunkSize}, elementType);
  }

  // Case 2: block loads/stores.
  // Check that the tensor descriptor shape is distributable.
  int64_t tensorSize = 1;
  for (auto [tdescDim, laneDim, laneDataDim] :
       llvm::zip_equal(tdescShape, laneLayout, laneData)) {
    assert((tdescDim % (laneDim * laneDataDim) == 0) &&
           "tensor descriptor shape is not distributable");
    tensorSize *= tdescDim;
  }
  // tensorSize must be adjusted for array_length.
  tensorSize *= tdescTy.getArrayLength();

  return VectorType::get({tensorSize / sgSize}, elementType);
}
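
// A hedged illustration of the block case above (values are hypothetical):
// for a 16x16xf16 tensor descriptor with lane_layout = [1, 16],
// lane_data = [1, 1] and array_length = 1, sgSize = 1 * 16 = 16 and
// tensorSize = 16 * 16 = 256, so each lane receives vector<16xf16>
// (256 / 16 elements).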

FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(VectorType originalType,
                                      xegpu::LayoutAttr layout) {
  int64_t rank = originalType.getRank();
  // Distributed vector types are only supported for 1D, 2D and 3D vectors.
  if (rank < 1 || rank > 3)
    return failure();
  ArrayRef<int64_t> shape = originalType.getShape();
  // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
  // of a 3D vector.
  int arrayLength = 1;
  if (rank == 3) {
    arrayLength = shape[0];
    shape = shape.drop_front();
  }
  auto helperTdescTy = xegpu::TensorDescType::get(
      shape, originalType.getElementType(), arrayLength,
      /*boundary_check=*/true,
      /*memory_space=*/xegpu::MemorySpace::Global, layout);
  return xegpu::getDistributedVectorType(helperTdescTy);
}
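
// For illustration (values are hypothetical): a vector<2x16x16xf16> is
// treated as a 16x16 block with array_length = 2; with lane_layout = [1, 16]
// and lane_data = [1, 1] it distributes to vector<32xf16> per lane
// (2 * 16 * 16 / 16 elements).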

std::string xegpu::getLayoutName(const OpOperand &operand) {
  const StringRef prefix("layout_operand_");
  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
  return llvm::formatv("{0}{1}", prefix, idx).str();
}

std::string xegpu::getLayoutName(const OpResult result) {
  const StringRef prefix = "layout_result_";
  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}
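
// E.g., operand #2 maps to the attribute name "layout_operand_2" and
// result #0 maps to "layout_result_0".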

xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
  if (!value)
    return nullptr;

  if (auto tdescTy =
          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
    return tdescTy.getLayoutAttr();

  if (auto result = dyn_cast<OpResult>(value)) {
    Operation *defOp = result.getDefiningOp();
    assert(defOp && "result must have a defining op");

    // For LoadNdOp, the layout is stored in the tensor descriptor.
    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
      return getLayoutAttr(loadNd.getTensorDesc());

    std::string layoutName = getLayoutName(result);
    if (defOp->hasAttr(layoutName))
      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  }

  if (auto arg = dyn_cast<BlockArgument>(value)) {
    auto parentOp = arg.getOwner()->getParentOp();
    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
      return getLayoutAttr(tiedInit->get());
    }
  }

  return nullptr;
}
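
// A hedged illustration of the lookup order above: for
//   %v = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16, #layout> -> ...
// the layout comes from the TensorDescType of %desc; other results fall back
// to a "layout_result_N" attribute on the defining op, and loop iter_args
// are resolved through their tied init operands.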

xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
  Operation *op = opr.getOwner();
  std::string layoutName = xegpu::getLayoutName(opr);
  if (op->hasAttr(layoutName))
    return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  return getLayoutAttr(opr.get());
}

template <typename T, typename>
void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
  Operation *owner = operandOrResult.getOwner();
  std::string name = xegpu::getLayoutName(operandOrResult);
  if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
    owner->setAttr(name, layout);
}

// Explicit instantiation for OpResult.
template void
xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
                                     const mlir::xegpu::LayoutAttr layout);

// Explicit instantiation for OpOperand.
template void
xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
                                      const mlir::xegpu::LayoutAttr layout);

void xegpu::setLayoutAttrs(Operation *op,
                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
  op->walk([&](Operation *nestOp) {
    for (OpOperand &opr : nestOp->getOpOperands()) {
      auto layout = getLayoutImpl(opr.get());
      setLayoutAttr(opr, layout);
    }
    for (OpResult result : nestOp->getOpResults()) {
      auto layout = getLayoutImpl(result);
      setLayoutAttr(result, layout);
    }
  });
}
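
// A hedged usage sketch (funcOp and analysis are illustrative): annotate all
// values in a function with layouts computed by a propagation analysis:
//   xegpu::setLayoutAttrs(funcOp,
//                         [&](Value v) { return analysis.lookup(v); });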

SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                        Value value, ArrayRef<int64_t> shape) {
  auto vecTy = dyn_cast<VectorType>(value.getType());
  if (!vecTy)
    return {value};

  ArrayRef<int64_t> srcShape = vecTy.getShape();
  if (!computeShapeRatio(srcShape, shape))
    return {value};

  SmallVector<Value> result;
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result.push_back(builder.create<vector::ExtractStridedSliceOp>(
        loc, value, offsets, shape, staticStrides));
  }

  return result;
}
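
// For illustration: extracting [2, 8] tiles from a vector<4x8xf32> yields two
// vector<2x8xf32> slices, taken at offsets [0, 0] and [2, 0].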

Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                             ValueRange values,
                                             ArrayRef<int64_t> shape) {
  VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
  assert(llvm::all_of(values.getTypes(),
                      [&](Type type) { return type == inputTy; }) &&
         "values must be of the same VectorType");

  Type elemTy = inputTy.getElementType();
  ArrayRef<int64_t> tileShape = inputTy.getShape();

  VectorType resultTy = VectorType::get(shape, elemTy);
  auto zeroAttr = builder.getZeroAttr(elemTy);
  Value result = builder.create<arith::ConstantOp>(
      loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));

  for (auto [src, offsets] :
       llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result = builder.create<vector::InsertStridedSliceOp>(
        loc, src, result, offsets, staticStrides);
  }
  return result;
}
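
// The inverse of the extraction above: assembling two vector<2x8xf32> values
// with shape [4, 8] inserts them into a zero-initialized vector<4x8xf32> at
// offsets [0, 0] and [2, 0].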

void xegpu::doSCFStructuralTypeConversionWithTensorType(
    Operation *op, TypeConverter converter) {
  MLIRContext *context = op->getContext();

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
        .getResult(0);
  };

  { // Step 1: convert VectorType to RankedTensorType for SCF structural ops.
    TypeConverter converter;
    converter.addConversion([](Type type) -> Type { return type; });
    converter.addConversion([](VectorType type) -> Type {
      return RankedTensorType::get(type.getShape(), type.getElementType());
    });
    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization(materializeCast);

    mlir::ConversionTarget target(*context);
    target.addLegalOp<UnrealizedConversionCastOp>();

    mlir::RewritePatternSet patterns(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }
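
  // A hedged illustration of step 1 (types are hypothetical): an iter_arg
  //   scf.for ... iter_args(%arg = %init : vector<8x16xf32>)
  // becomes
  //   scf.for ... iter_args(%arg = %cast : tensor<8x16xf32>)
  // with unrealized_conversion_cast ops bridging vector and tensor values.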

  { // Step 2: propagate the layout attribute to the RankedTensorTypes by
    // inspecting the UnrealizedConversionCastOps that cast from VectorType
    // to RankedTensorType.
    op->walk([](UnrealizedConversionCastOp castOp) {
      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
        return WalkResult::skip();

      Value input = castOp.getInputs()[0];
      Value result = castOp.getResults()[0];
      auto inputTy = dyn_cast<VectorType>(input.getType());
      auto resultTy = dyn_cast<RankedTensorType>(result.getType());

      // Only look at ops casting from VectorType to RankedTensorType.
      if (!inputTy || !resultTy)
        return WalkResult::skip();

      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
      if (!layout)
        return WalkResult::skip();

      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
      result.setType(newTy);

      // Update the block arguments if a user is a LoopLike op.
      for (OpOperand &use : result.getUses()) {
        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
          arg.setType(newTy);
        }
        // scf.while has two regions; the BlockArguments of the after region
        // are not exposed by LoopLikeOpInterface.
        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
          unsigned idx = use.getOperandNumber();
          BlockArgument arg = whileOp.getAfterArguments()[idx];
          arg.setType(newTy);
        }
      }
      return WalkResult::advance();
    });

    // Use each yieldOp as an anchor to update the result types of its
    // parent op.
    op->walk([](scf::YieldOp yieldOp) {
      Operation *parentOp = yieldOp->getParentOp();
      for (OpResult r : parentOp->getOpResults()) {
        unsigned idx = r.getResultNumber();
        Type resultTy = r.getType();
        Type yieldTy = yieldOp.getResults()[idx].getType();
        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
          r.setType(yieldTy);
      }
    });
  }
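
  // A hedged illustration of step 2: given
  //   %t = builtin.unrealized_conversion_cast %v
  //       : vector<8x16xf32> to tensor<8x16xf32>
  // where %v carries a layout, the result type is rewritten to
  //   tensor<8x16xf32, #xegpu.layout<...>>
  // and propagated into loop region arguments and loop results.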

  { // Step 3: perform the conversion from RankedTensorType to VectorType
    // based on the LayoutAttr encoding.

    // Handle the UnrealizedConversionCastOps introduced by step 1.
    // For VectorType -> RankedTensorType casts, it simply forwards the
    // inputs. For RankedTensorType -> VectorType casts, it updates the
    // inputs with the values from the adaptor.
    class UnrealizedConversionCastOpPattern
        : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
      using OpConversionPattern<
          mlir::UnrealizedConversionCastOp>::OpConversionPattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::UnrealizedConversionCastOp op,
                      OneToNOpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        auto inputs = op.getOperands();
        auto outputs = op.getOutputs();

        if (inputs.size() != 1 || outputs.size() != 1)
          return failure();

        auto inputTy = inputs[0].getType();
        auto outputTy = outputs[0].getType();

        if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
          rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
          return success();
        }

        if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
          SmallVector<Value> values = flattenValues(adaptor.getInputs());
          auto newOp = rewriter.create<UnrealizedConversionCastOp>(
              op.getLoc(), outputTy, values);
          rewriter.replaceOp(op, newOp);
          return success();
        }
        return failure();
      }
    };

    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
                                           ValueRange inputs, Location loc) {
      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
          .getResults();
    });

    mlir::ConversionTarget target(*context);
    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
        [](UnrealizedConversionCastOp op) {
          auto isTensorTy = [](Type type) {
            return isa<RankedTensorType>(type);
          };
          return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
                 llvm::none_of(op->getResultTypes(), isTensorTy);
        });
    mlir::RewritePatternSet patterns(context);
    patterns.insert<UnrealizedConversionCastOpPattern>(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }
}
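
// A hedged usage sketch (funcOp and the converter body are illustrative): a
// distribution pass could call
//   TypeConverter converter;
//   // ... map tensor<..., #layout> back to the distributed VectorType ...
//   xegpu::doSCFStructuralTypeConversionWithTensorType(funcOp, converter);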