//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps ------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utility methods for working with the XeGPU dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <cstdint>
#include <numeric>

using namespace mlir;
/// Convert an ArrayRef<ValueRange> into a SmallVector<Value> by
/// concatenating all ranges.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
  SmallVector<Value> result;
  for (const auto &vals : values)
    llvm::append_range(result, vals);
  return result;
}

FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
  // This only works for a subgroup-level layout, which carries only
  // lane_layout and lane_data and is used to distribute SIMD code into
  // SIMT code.
  if (!layout || !layout.isSgLayout())
    return failure();

  SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
  SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
  auto tdescShape = tdescTy.getShape();
  auto elementType = tdescTy.getElementType();

  // Compute sgSize by multiplying the elements of laneLayout,
  // e.g. for a 2D layout, sgSize = laneLayout[0] * laneLayout[1];
  // for a 1D layout, sgSize = laneLayout[0].
  auto sgSize = std::accumulate(laneLayout.begin(), laneLayout.end(), 1,
                                std::multiplies<int64_t>());

  // Case 1: regular loads/stores.
  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
  if (scatterAttr) {
    auto chunkSize = scatterAttr.getChunkSize().getInt();
    // Verify that the first dimension of the tensor descriptor shape is
    // distributable.
    assert(tdescShape[0] == laneLayout[0] &&
           "tensor descriptor shape is not distributable");
    return VectorType::get({chunkSize}, elementType);
  }

  // Case 2: block loads/stores.
  // Check that the tensor descriptor shape is distributable.
  int64_t tensorSize = 1;
  for (auto [tdescDim, laneDim, laneDataDim] :
       llvm::zip_equal(tdescShape, laneLayout, laneData)) {
    assert((tdescDim % (laneDim * laneDataDim) == 0) &&
           "tensor descriptor shape is not distributable");
    tensorSize *= tdescDim;
  }
  // tensorSize must be adjusted for array_length.
  tensorSize *= tdescTy.getArrayLength();

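  // For example (illustrative): a block tdesc of shape 16x16 with
  // lane_layout = [1, 16], lane_data = [1, 1], and array_length = 2 gives
  // tensorSize = 16 * 16 * 2 = 512 and sgSize = 16, so each lane gets a
  // vector<32xelemTy>.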
  return VectorType::get({tensorSize / sgSize}, elementType);
}

FailureOr<VectorType>
mlir::xegpu::getDistributedVectorType(VectorType originalType,
                                      xegpu::LayoutAttr layout) {
  int64_t rank = originalType.getRank();
  // Distributed vector types are only supported for 1D, 2D, and 3D vectors.
  if (rank < 1 || rank > 3)
    return failure();
  ArrayRef<int64_t> shape = originalType.getShape();
  // arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
  // of the 3D vector.
  int arrayLength = 1;
  if (rank == 3) {
    arrayLength = shape[0];
    shape = shape.drop_front();
  }
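  // For example (illustrative): a vector<2x8x16xf16> is treated as an 8x16
  // block with array_length = 2.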
  auto helperTdescTy = xegpu::TensorDescType::get(
      shape, originalType.getElementType(), arrayLength,
      /*boundary_check=*/true,
      /*memory_space=*/xegpu::MemorySpace::Global, layout);
  return xegpu::getDistributedVectorType(helperTdescTy);
}

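/// Returns the attribute name used to attach a layout to the given operand,
/// e.g. operand #1 yields "layout_operand_1".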
std::string xegpu::getLayoutName(const OpOperand &operand) {
  const StringRef prefix("layout_operand_");
  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
  return llvm::formatv("{0}{1}", prefix, idx).str();
}

std::string xegpu::getLayoutName(const OpResult result) {
  const StringRef prefix = "layout_result_";
  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
}

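/// Retrieves the LayoutAttr associated with `value`: from the TensorDescType
/// itself, from an attribute on the defining op, or, for loop iter_args, from
/// the tied loop init value.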
xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
  if (!value)
    return nullptr;

  if (auto tdescTy =
          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
    return tdescTy.getLayoutAttr();

  if (auto result = dyn_cast<OpResult>(value)) {
    Operation *defOp = result.getDefiningOp();
    assert(defOp && "result must have a defining op");

    // For LoadNdOp, the layout is stored in the tensor descriptor.
    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
      return getLayoutAttr(loadNd.getTensorDesc());

    std::string layoutName = getLayoutName(result);
    if (defOp->hasAttr(layoutName))
      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  }

  if (auto arg = dyn_cast<BlockArgument>(value)) {
    auto parentOp = arg.getOwner()->getParentOp();
    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
      return getLayoutAttr(tiedInit->get());
    }
  }

  return nullptr;
}

xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
  Operation *op = opr.getOwner();
  std::string layoutName = xegpu::getLayoutName(opr);
  if (op->hasAttr(layoutName))
    return op->getAttrOfType<xegpu::LayoutAttr>(layoutName);
  return getLayoutAttr(opr.get());
}

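/// Attaches `layout` to the given OpOperand or OpResult via the naming
/// convention from getLayoutName. An already-present layout attribute is
/// left untouched.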
template <typename T, typename>
void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
  Operation *owner = operandOrResult.getOwner();
  std::string name = xegpu::getLayoutName(operandOrResult);
  if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
    owner->setAttr(name, layout);
}

// Explicit instantiation for OpResult.
template void
xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result,
                                     const mlir::xegpu::LayoutAttr layout);

// Explicit instantiation for OpOperand.
template void
xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand,
                                      const mlir::xegpu::LayoutAttr layout);
void xegpu::setLayoutAttrs(Operation *op,
                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
  op->walk([&](Operation *nestOp) {
    for (OpOperand &opr : nestOp->getOpOperands()) {
      auto layout = getLayoutImpl(opr.get());
      setLayoutAttr(opr, layout);
    }
    for (OpResult result : nestOp->getOpResults()) {
      auto layout = getLayoutImpl(result);
      setLayoutAttr(result, layout);
    }
  });
}

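/// Extracts a set of smaller vectors of `shape` from `value` using
/// vector.extract_strided_slice. If `value` is not a vector, or its shape is
/// not evenly divisible by `shape`, `value` is returned as-is. For example,
/// splitting a vector<16x16xf32> with shape [8, 16] yields two
/// vector<8x16xf32> slices.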
SmallVector<Value>
xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                        Value value, ArrayRef<int64_t> shape) {
  auto vecTy = dyn_cast<VectorType>(value.getType());
  if (!vecTy)
    return {value};

  ArrayRef<int64_t> srcShape = vecTy.getShape();
  if (!computeShapeRatio(srcShape, shape))
    return {value};

  SmallVector<Value> result;
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result.push_back(builder.create<vector::ExtractStridedSliceOp>(
        loc, value, offsets, shape, staticStrides));
  }

  return result;
}

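/// The inverse of the extraction above: assembles `values`, which must all
/// share the same VectorType, into a single vector of `shape` by inserting
/// each tile at its static offset into a zero-initialized vector.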
Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                             ValueRange values,
                                             ArrayRef<int64_t> shape) {
  VectorType inputTy = dyn_cast<VectorType>(values[0].getType());
  assert(llvm::all_of(values.getTypes(),
                      [&](Type type) { return type == inputTy; }) &&
         "values must be of the same VectorType");

  Type elemTy = inputTy.getElementType();
  ArrayRef<int64_t> tileShape = inputTy.getShape();

  VectorType resultTy = VectorType::get(shape, elemTy);
  auto zeroAttr = builder.getZeroAttr(elemTy);
  Value result = builder.create<arith::ConstantOp>(
      loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));

  for (auto [src, offsets] :
       llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
    SmallVector<int64_t> staticStrides(offsets.size(), 1);
    result = builder.create<vector::InsertStridedSliceOp>(
        loc, src, result, offsets, staticStrides);
  }
  return result;
}

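/// Applies an SCF structural type conversion that temporarily routes vector
/// values through RankedTensorType so layout attributes can travel as tensor
/// encodings. It proceeds in three steps:
///   1. Convert VectorType to RankedTensorType across SCF structural ops.
///   2. Propagate each value's LayoutAttr onto its RankedTensorType encoding.
///   3. Convert the RankedTensorTypes back to VectorType using `converter`.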
void xegpu::doSCFStructuralTypeConversionWithTensorType(
    Operation *op, TypeConverter converter) {
  MLIRContext *context = op->getContext();

  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
        .getResult(0);
  };

  { // convert VectorType to RankedTensorType for SCF Structural ops
    TypeConverter converter;
    converter.addConversion([](Type type) -> Type { return type; });
    converter.addConversion([](VectorType type) -> Type {
      return RankedTensorType::get(type.getShape(), type.getElementType());
    });
    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization(materializeCast);

    mlir::ConversionTarget target(*context);
    target.addLegalOp<UnrealizedConversionCastOp>();

    mlir::RewritePatternSet patterns(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }

  { // propagate the layout attribute to RankedTensorType by checking
    // UnrealizedConversionCastOps that cast from VectorType to
    // RankedTensorType
    op->walk([](UnrealizedConversionCastOp castOp) {
      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
        return WalkResult::skip();

      Value input = castOp.getInputs()[0];
      Value result = castOp.getResults()[0];
      auto inputTy = dyn_cast<VectorType>(input.getType());
      auto resultTy = dyn_cast<RankedTensorType>(result.getType());

      // Only look at ops casting from VectorType to RankedTensorType.
      if (!inputTy || !resultTy)
        return WalkResult::skip();

      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
      if (!layout)
        return WalkResult::skip();

      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
      result.setType(newTy);

      // Update the block arguments if the user is a LoopLike op.
      for (OpOperand &use : result.getUses()) {
        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
          arg.setType(newTy);
        }
        // scf.while has two regions; the BlockArguments of the after region
        // are not exposed by LoopLikeOpInterface.
        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
          unsigned idx = use.getOperandNumber();
          BlockArgument arg = whileOp.getAfterArguments()[idx];
          arg.setType(newTy);
        }
      }
      return WalkResult::advance();
    });

    // Use each yield op as an anchor to update the result types of its
    // parent op.
    op->walk([](scf::YieldOp yieldOp) {
      Operation *parentOp = yieldOp->getParentOp();
      for (OpResult r : parentOp->getOpResults()) {
        unsigned idx = r.getResultNumber();
        Type resultTy = r.getType();
        Type yieldTy = yieldOp.getResults()[idx].getType();
        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
          r.setType(yieldTy);
      }
    });
  }

  { // perform the conversion from RankedTensorType to VectorType based on
    // the LayoutAttr

    // Handle the UnrealizedConversionCastOps introduced by the first step.
    // For vector -> RankedTensorType casts, simply forward the inputs.
    // For RankedTensorType -> vector casts, update the inputs with the
    // values from the adaptor.
    class UnrealizedConversionCastOpPattern
        : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
      using OpConversionPattern<
          mlir::UnrealizedConversionCastOp>::OpConversionPattern;

      mlir::LogicalResult
      matchAndRewrite(mlir::UnrealizedConversionCastOp op,
                      OneToNOpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const override {
        auto inputs = op.getOperands();
        auto outputs = op.getOutputs();

        if (inputs.size() != 1 || outputs.size() != 1)
          return failure();

        auto inputTy = inputs[0].getType();
        auto outputTy = outputs[0].getType();

        if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
          rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
          return success();
        }

        if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
          SmallVector<Value> values = flattenValues(adaptor.getInputs());
          auto newOp = rewriter.create<UnrealizedConversionCastOp>(
              op.getLoc(), outputTy, values);
          rewriter.replaceOp(op, newOp);
          return success();
        }
        return failure();
      }
    };

    converter.addSourceMaterialization(materializeCast);
    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
                                           ValueRange inputs, Location loc) {
      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
          .getResults();
    });

    mlir::ConversionTarget target(*context);
    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
        [](UnrealizedConversionCastOp op) {
          auto isTensorTy = [](Type type) {
            return isa<RankedTensorType>(type);
          };
          return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
                 llvm::none_of(op->getResultTypes(), isTensorTy);
        });
    mlir::RewritePatternSet patterns(context);
    patterns.insert<UnrealizedConversionCastOpPattern>(context);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
  }
}
378 | |