1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
10
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
12#include "mlir/Conversion/LLVMCommon/Pattern.h"
13#include "mlir/Dialect/GPU/IR/GPUDialect.h"
14#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Support/LLVM.h"
19#include "llvm/Support/FormatVariadic.h"
20
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/Types.h"
23
24#include "llvm/ADT/TypeSwitch.h"
25
26namespace mlir {
27#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
28#include "mlir/Conversion/Passes.h.inc"
29} // namespace mlir
30
31using namespace mlir;
32using namespace xevm;
33
34namespace {
35
36struct LLVMFuncAttributeOptions {
37 bool isConvergent = false;
38 bool isNoUnwind = false;
39 bool isWillReturn = false;
40 LLVM::MemoryEffectsAttr memEffectsAttr{};
41};
42static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
43 .isConvergent: false, .isNoUnwind: true, .isWillReturn: false, .memEffectsAttr: {}};
44static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
45 .isConvergent: false, .isNoUnwind: true, .isWillReturn: true, .memEffectsAttr: {}};
46static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
47 .isConvergent: true, .isNoUnwind: true, .isWillReturn: true, .memEffectsAttr: {}};
48
49std::string getTypeMangling(Type ty, bool isUnsigned = false) {
50 return TypeSwitch<Type, std::string>(ty)
51 .Case(caseFn: [isUnsigned](VectorType ty) -> std::string {
52 return "Dv" + std::to_string(val: ty.getNumElements()) + "_" +
53 getTypeMangling(ty: ty.getElementType(), isUnsigned);
54 })
55 .Case(caseFn: [](Float16Type) -> std::string { return "Dh"; })
56 .Case(caseFn: [](Float32Type) -> std::string { return "f"; })
57 .Case(caseFn: [](Float64Type) -> std::string { return "d"; })
58 .Case(caseFn: [isUnsigned](IntegerType ty) -> std::string {
59 switch (ty.getWidth()) {
60 case 8:
61 return isUnsigned ? "h" : "c";
62 case 16:
63 return isUnsigned ? "t" : "s";
64 case 32:
65 return isUnsigned ? "j" : "i";
66 case 64:
67 return isUnsigned ? "m" : "l";
68 default:
69 llvm_unreachable("unhandled integer type");
70 }
71 })
72 .Default(defaultFn: [](Type) -> std::string {
73 llvm_unreachable("unhandled type for mangling");
74 });
75}
76
77std::string mangle(StringRef baseName, ArrayRef<Type> types,
78 ArrayRef<bool> isUnsigned = {}) {
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
81 std::string s;
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os << "_Z" << baseName.size() << baseName;
85 for (auto [idx, type] : llvm::enumerate(First&: types)) {
86 auto it = substitutions.find(Val: type);
87 if (it != substitutions.end()) {
88 os << "S";
89 // First substitution is `S_`, second is `S0_`, and so on.
90 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 os << firstIdx - 1;
92 os << "_";
93 } else {
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(ty: type, isUnsigned: isUnsigned.empty() ? false : isUnsigned[idx]);
97 }
98 }
99 return os.str();
100}
101
102template <bool isLoad, typename OpType>
103int32_t getL1CacheControl(OpType op) {
104 int32_t control = 0;
105 if constexpr (isLoad) {
106 switch (*op.getCacheControl()) {
107 case LoadCacheControl::L1UC_L2UC_L3UC:
108 case LoadCacheControl::L1UC_L2UC_L3C:
109 case LoadCacheControl::L1UC_L2C_L3UC:
110 case LoadCacheControl::L1UC_L2C_L3C:
111 control = 1;
112 break;
113 case LoadCacheControl::L1C_L2UC_L3UC:
114 case LoadCacheControl::L1C_L2UC_L3C:
115 case LoadCacheControl::L1C_L2C_L3UC:
116 case LoadCacheControl::L1C_L2C_L3C:
117 control = 2;
118 break;
119 case LoadCacheControl::L1S_L2UC_L3UC:
120 case LoadCacheControl::L1S_L2UC_L3C:
121 case LoadCacheControl::L1S_L2C_L3UC:
122 case LoadCacheControl::L1S_L2C_L3C:
123 control = 3;
124 break;
125 case LoadCacheControl::INVALIDATE_READ:
126 control = 4;
127 break;
128 }
129 } else {
130 switch (*op.getCacheControl()) {
131 case StoreCacheControl::L1UC_L2UC_L3UC:
132 case StoreCacheControl::L1UC_L2UC_L3WB:
133 case StoreCacheControl::L1UC_L2WB_L3UC:
134 case StoreCacheControl::L1UC_L2WB_L3WB:
135 control = 1;
136 break;
137 case StoreCacheControl::L1WT_L2UC_L3UC:
138 case StoreCacheControl::L1WT_L2UC_L3WB:
139 case StoreCacheControl::L1WT_L2WB_L3UC:
140 case StoreCacheControl::L1WT_L2WB_L3WB:
141 control = 2;
142 break;
143 case StoreCacheControl::L1S_L2UC_L3UC:
144 case StoreCacheControl::L1S_L2UC_L3WB:
145 case StoreCacheControl::L1S_L2WB_L3UC:
146 case StoreCacheControl::L1S_L2WB_L3WB:
147 control = 3;
148 break;
149 case StoreCacheControl::L1WB_L2UC_L3UC:
150 case StoreCacheControl::L1WB_L2WB_L3UC:
151 case StoreCacheControl::L1WB_L2UC_L3WB:
152 control = 4;
153 break;
154 }
155 }
156 return control;
157}
158
159template <bool isLoad, typename OpType>
160int32_t getL3CacheControl(OpType op) {
161 int32_t control = 0;
162 if constexpr (isLoad) {
163 switch (*op.getCacheControl()) {
164 case LoadCacheControl::L1UC_L2UC_L3UC:
165 case LoadCacheControl::L1UC_L2C_L3UC:
166 case LoadCacheControl::L1C_L2UC_L3UC:
167 case LoadCacheControl::L1C_L2C_L3UC:
168 case LoadCacheControl::L1S_L2UC_L3UC:
169 case LoadCacheControl::L1S_L2C_L3UC:
170 control = 1;
171 break;
172 case LoadCacheControl::L1UC_L2UC_L3C:
173 case LoadCacheControl::L1UC_L2C_L3C:
174 case LoadCacheControl::L1C_L2UC_L3C:
175 case LoadCacheControl::L1C_L2C_L3C:
176 case LoadCacheControl::L1S_L2UC_L3C:
177 case LoadCacheControl::L1S_L2C_L3C:
178 control = 2;
179 break;
180 case LoadCacheControl::INVALIDATE_READ:
181 control = 4;
182 break;
183 }
184 } else {
185 switch (*op.getCacheControl()) {
186 case StoreCacheControl::L1UC_L2UC_L3UC:
187 case StoreCacheControl::L1UC_L2WB_L3UC:
188 case StoreCacheControl::L1WT_L2UC_L3UC:
189 case StoreCacheControl::L1WT_L2WB_L3UC:
190 case StoreCacheControl::L1S_L2UC_L3UC:
191 case StoreCacheControl::L1S_L2WB_L3UC:
192 case StoreCacheControl::L1WB_L2UC_L3UC:
193 case StoreCacheControl::L1WB_L2WB_L3UC:
194 control = 1;
195 break;
196 case StoreCacheControl::L1UC_L2UC_L3WB:
197 case StoreCacheControl::L1UC_L2WB_L3WB:
198 case StoreCacheControl::L1WT_L2UC_L3WB:
199 case StoreCacheControl::L1WT_L2WB_L3WB:
200 case StoreCacheControl::L1S_L2UC_L3WB:
201 case StoreCacheControl::L1S_L2WB_L3WB:
202 case StoreCacheControl::L1WB_L2UC_L3WB:
203 control = 2;
204 break;
205 }
206 }
207 return control;
208}
209
210template <bool isLoad, typename OpType>
211static std::optional<ArrayAttr>
212getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
213 if (!op.getCacheControl())
214 return {};
215 constexpr int32_t decorationCacheControlArity{4};
216 constexpr int32_t loadCacheControlKey{6442};
217 constexpr int32_t storeCacheControlKey{6443};
218 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
219 SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
220 controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
221 SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
222 controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
223 auto arrayAttrL1 = rewriter.getI32ArrayAttr(values: decorationsL1);
224 auto arrayAttrL3 = rewriter.getI32ArrayAttr(values: decorationsL3);
225
226 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
227 return rewriter.getArrayAttr(value: combinedAttrs);
228}
229
230static LLVM::CallOp createDeviceFunctionCall(
231 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
232 ArrayRef<Type> argTypes, ArrayRef<Value> args,
233 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
234 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
235 auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
236 assert(moduleOp && "Expecting module");
237 Location loc = op->getLoc();
238
239 auto funcOpRes =
240 LLVM::lookupOrCreateFn(b&: rewriter, moduleOp, name: funcName, paramTypes: argTypes, resultType: retType);
241 assert(!failed(funcOpRes));
242 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
243 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
244 funcOp.setConvergent(funcAttributeOptions.isConvergent);
245 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
246 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
247
248 if (funcAttributeOptions.memEffectsAttr)
249 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
250
251 for (auto [idx, attrName] : paramAttrs)
252 funcOp.setArgAttr(index: idx, name: attrName, value: rewriter.getUnitAttr());
253
254 auto callOp = rewriter.create<LLVM::CallOp>(location: loc, args&: funcOp, args);
255 callOp->setAttrs(funcOp->getAttrs());
256
257 return callOp;
258}
259
260class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
261 using OpConversionPattern::OpConversionPattern;
262 LogicalResult
263 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
264 ConversionPatternRewriter &rewriter) const override {
265 if (!op.getC()) {
266 return rewriter.notifyMatchFailure(arg&: op, msg: "OCL requires C operand");
267 }
268 auto precisionA = op.getTypes().getA();
269 auto precisionB = op.getTypes().getB();
270 auto precisionC = op.getTypes().getC();
271 auto precisionD = op.getTypes().getD();
272 if (precisionC != precisionD) {
273 return rewriter.notifyMatchFailure(arg&: op, msg: "type of C and D need to match");
274 }
275 if (precisionC != xevm::ElemType::S32 &&
276 precisionC != xevm::ElemType::F32 &&
277 precisionC != xevm::ElemType::F16 &&
278 precisionC != xevm::ElemType::BF16) {
279 return rewriter.notifyMatchFailure(
280 arg&: op, msg: "type of C and D must be S32, F32, F16 or BF16");
281 }
282 if (precisionA == xevm::ElemType::S32 ||
283 precisionA == xevm::ElemType::F32) {
284 return rewriter.notifyMatchFailure(arg&: op, msg: "type of A cannot be S32 or F32");
285 }
286 if (precisionB == xevm::ElemType::S32 ||
287 precisionB == xevm::ElemType::F32) {
288 return rewriter.notifyMatchFailure(arg&: op, msg: "type of B cannot be S32 or F32");
289 }
290 constexpr uint32_t bitWidthPackedA{16};
291 constexpr uint32_t bitWidthPackedB{32};
292 auto loc = op.getLoc();
293
294 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
295 VectorType origTy = cast<VectorType>(Val: val.getType());
296 const uint32_t vecBitSize =
297 origTy.getNumElements() *
298 origTy.getElementType().getIntOrFloatBitWidth();
299 VectorType newTy = VectorType::get(
300 shape: vecBitSize / packedType.getIntOrFloatBitWidth(), elementType: packedType);
301 if (origTy != newTy)
302 val = rewriter.create<LLVM::BitcastOp>(location: loc, args&: newTy, args&: val);
303 return val;
304 };
305
306 Value a = op.getA();
307 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
308 ? cast<Type>(Val: rewriter.getF32Type())
309 : rewriter.getIntegerType(width: bitWidthPackedA);
310 a = castIfNeeded(a, packedAType);
311
312 Value b = op.getB();
313 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
314 ? cast<Type>(Val: rewriter.getF32Type())
315 : rewriter.getIntegerType(width: bitWidthPackedB);
316 b = castIfNeeded(b, packedBType);
317
318 Value c = op.getC();
319 VectorType cOrigTy = cast<VectorType>(Val: c.getType());
320 VectorType resOrigTy = cast<VectorType>(Val: op->getResultTypes()[0]);
321 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
322 // OCL builtins encode bfloat16 as int16
323 VectorType cTy =
324 cOrigTy.getElementType().isBF16()
325 ? VectorType::get(shape: cOrigTy.getShape(), elementType: rewriter.getIntegerType(width: 16))
326 : cOrigTy;
327 VectorType resTy = cTy;
328 if (cOrigTy != cTy)
329 c = rewriter.create<LLVM::BitcastOp>(location: loc, args&: cTy, args&: c);
330
331 constexpr int32_t systolicDepth{8};
332 std::string fnName =
333 llvm::formatv(Fmt: "intel_sub_group_{0}_{1}_matrix_mad_k{2}",
334 Vals: stringifyElemType(op.getTypes().getA()).str(),
335 Vals: stringifyElemType(op.getTypes().getB()).str(),
336 Vals: systolicDepth *
337 getNumOperandsPerDword(pTy: op.getTypes().getA()))
338 .str();
339 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
340 fnName = mangle(baseName: fnName, types: argTypes);
341 SmallVector<Value> args{a, b, c};
342
343 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
344 /*other=*/args: LLVM::ModRefInfo::NoModRef,
345 /*argMem=*/args: LLVM::ModRefInfo::NoModRef,
346 /*inaccessibleMem=*/args: LLVM::ModRefInfo::NoModRef);
347 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
348 funcAttrs.memEffectsAttr = memAttr;
349 Value result =
350 createDeviceFunctionCall(rewriter, funcName: fnName, retType: resTy, argTypes, args, paramAttrs: {},
351 funcAttributeOptions: funcAttrs, op: op.getOperation())
352 ->getResult(idx: 0);
353
354 if (resOrigTy != resTy)
355 result = rewriter.create<LLVM::BitcastOp>(location: loc, args&: resOrigTy, args&: result);
356
357 rewriter.replaceOp(op, newValues: result);
358 return success();
359 }
360
361private:
362 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
363 switch (pTy) {
364 case xevm::ElemType::TF32:
365 return 1;
366 case xevm::ElemType::BF16:
367 case xevm::ElemType::F16:
368 return 2;
369 case xevm::ElemType::U8:
370 case xevm::ElemType::S8:
371 return 4;
372 default:
373 llvm_unreachable("unsupported xevm::ElemType");
374 }
375 }
376};
377
378class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
379 using OpConversionPattern::OpConversionPattern;
380 LogicalResult
381 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
382 ConversionPatternRewriter &rewriter) const override {
383 auto loc = op.getLoc();
384 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
385 Value one =
386 rewriter.create<LLVM::ConstantOp>(location: loc, args: rewriter.getI64Type(), args: 1);
387 SmallVector<Value> args{op.getPtr(), one};
388 SmallVector<Type> argTypes;
389 for (auto arg : args)
390 argTypes.push_back(Elt: arg.getType());
391 auto funcAttr = noUnwindAttrs;
392 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
393 /*other=*/args: LLVM::ModRefInfo::NoModRef,
394 /*argMem=*/args: LLVM::ModRefInfo::Ref,
395 /*inaccessibleMem=*/args: LLVM::ModRefInfo::NoModRef);
396 funcAttr.memEffectsAttr = memAttr;
397
398 LLVM::CallOp call = createDeviceFunctionCall(
399 rewriter, funcName: fnName, retType: LLVM::LLVMVoidType::get(ctx: rewriter.getContext()),
400 argTypes, args, paramAttrs: {}, funcAttributeOptions: funcAttr, op: op.getOperation());
401 if (std::optional<ArrayAttr> optCacheControls =
402 getCacheControlMetadata<true>(rewriter, op))
403 call->setAttr(name: XeVMDialect::getCacheControlsAttrName(), value: *optCacheControls);
404 rewriter.eraseOp(op);
405 return success();
406 }
407};
408
409class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
410 using OpConversionPattern::OpConversionPattern;
411 LogicalResult
412 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
413 ConversionPatternRewriter &rewriter) const override {
414 auto loc = op.getLoc();
415 const std::string fnName{"atomic_work_item_fence"};
416 int memScope, addrSpace;
417 switch (op.getAddrspace()) {
418 case xevm::AddrSpace::SHARED:
419 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
420 break;
421 case xevm::AddrSpace::GLOBAL:
422 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
423 break;
424 default:
425 // GENERIC is not supported in OpenCL
426 return rewriter.notifyMatchFailure(
427 arg&: op, msg: "Fence only supports global and shared address spaces.");
428 }
429 switch (op.getScope()) {
430 case xevm::MemScope::WORKGROUP:
431 memScope = 1;
432 break;
433 case xevm::MemScope::DEVICE:
434 memScope = 2;
435 break;
436 default:
437 // CLUSTER and SYSTEM are not supported in OpenCL
438 return rewriter.notifyMatchFailure(
439 arg&: op, msg: "Fence only supports workgroup and device memory scopes.");
440 }
441 Type i32Type = rewriter.getI32Type();
442 Value acqRel = rewriter.create<LLVM::ConstantOp>(location: loc, args&: i32Type, args: 4);
443 Value memScopeConst =
444 rewriter.create<LLVM::ConstantOp>(location: loc, args&: i32Type, args&: memScope);
445 Value addrSpaceConst =
446 rewriter.create<LLVM::ConstantOp>(location: loc, args&: i32Type, args&: addrSpace);
447 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
448 SmallVector<Type> argTypes{3, i32Type};
449 createDeviceFunctionCall(rewriter, funcName: mangle(baseName: fnName, types: argTypes),
450 retType: LLVM::LLVMVoidType::get(ctx: rewriter.getContext()),
451 argTypes, args, paramAttrs: {}, funcAttributeOptions: noUnwindAttrs,
452 op: op.getOperation());
453 rewriter.eraseOp(op);
454 return success();
455 }
456};
457template <typename OpType>
458class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
459 using OpConversionPattern<OpType>::OpConversionPattern;
460 LogicalResult
461 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
462 ConversionPatternRewriter &rewriter) const override {
463 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
464 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
465
466 auto loc = op.getLoc();
467 VectorType vecType;
468 bool packReg = false;
469 bool transpose = false;
470 if constexpr (isLoad) {
471 vecType = op.getRes().getType();
472 packReg = op.getPackRegister();
473 transpose = op.getTranspose();
474 } else if constexpr (!isPrefetch) {
475 vecType = op.getStoredVal().getType();
476 }
477
478 auto i32Type = rewriter.getI32Type();
479 Value byteCoord =
480 rewriter.create<LLVM::UndefOp>(loc, VectorType::get(shape: 2, elementType: i32Type));
481 Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
482 Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
483 byteCoord = rewriter.create<LLVM::InsertElementOp>(
484 loc, VectorType::get(shape: 2, elementType: i32Type), byteCoord, op.getX(), zero);
485 byteCoord = rewriter.create<LLVM::InsertElementOp>(
486 loc, VectorType::get(shape: 2, elementType: i32Type), byteCoord, op.getY(), one);
487 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
488 op.getBasePitch(), byteCoord};
489 SmallVector<Type> retTypes;
490 Value spvLoadDstPtr;
491 std::string funcName{"intel_sub_group_2d_block_"};
492 std::string bitWidthId;
493 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
494 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
495 if constexpr (isPrefetch) { // Prefetch
496 funcName += "prefetch";
497 paramAttrs = {std::make_pair(x: 0, y: LLVM::LLVMDialect::getNonNullAttrName())};
498 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
499 /*other=*/args: LLVM::ModRefInfo::NoModRef,
500 /*argMem=*/args: LLVM::ModRefInfo::Ref,
501 /*inaccessibleMem=*/args: LLVM::ModRefInfo::NoModRef);
502 funcAttr = noUnwindAttrs;
503 funcAttr.memEffectsAttr = memAttr;
504 } else {
505 auto vecElemType = vecType.getElementType();
506 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
507 Value numElems = rewriter.create<LLVM::ConstantOp>(
508 loc, i32Type, vecType.getNumElements());
509 auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
510 loc, LLVM::LLVMPointerType::get(context: rewriter.getContext()), vecElemType,
511 numElems);
512 args.push_back(Elt: dstOrSrcPtr);
513 if constexpr (isLoad) { // Load
514 funcName += "read";
515 bitWidthId = getTypeMangling(ty: vecElemType, /*isUnsigned=*/true);
516 if (packReg)
517 funcName += "_transform";
518 else if (transpose)
519 funcName += "_transpose";
520 spvLoadDstPtr = dstOrSrcPtr;
521 retTypes.push_back(Elt: vecType);
522 paramAttrs = {
523 std::make_pair(x: 0, y: LLVM::LLVMDialect::getNonNullAttrName()),
524 std::make_pair(x: 0, y: LLVM::LLVMDialect::getReadonlyAttrName()),
525 std::make_pair(x: 5, y: LLVM::LLVMDialect::getNonNullAttrName()),
526 std::make_pair(x: 5, y: LLVM::LLVMDialect::getWriteOnlyAttrName()),
527 };
528 } else { // Store
529 funcName += "write";
530 bitWidthId = (vecElemBitWidth == 32)
531 ? "j"
532 : ((vecElemBitWidth == 16) ? "t" : "h");
533 rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
534 paramAttrs = {
535 std::make_pair(x: 0, y: LLVM::LLVMDialect::getNonNullAttrName()),
536 std::make_pair(x: 0, y: LLVM::LLVMDialect::getWriteOnlyAttrName()),
537 std::make_pair(x: 5, y: LLVM::LLVMDialect::getNonNullAttrName()),
538 std::make_pair(x: 5, y: LLVM::LLVMDialect::getReadonlyAttrName()),
539 };
540 }
541 }
542
543 funcName =
544 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
545 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
546 .str();
547 std::string prefetchCode("");
548 if (!isPrefetch)
549 prefetchCode += "P";
550 funcName = llvm::formatv(Fmt: "_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", Vals: funcName.size(),
551 Vals&: funcName, Vals&: prefetchCode, Vals&: bitWidthId)
552 .str();
553 SmallVector<Type> argTypes;
554 for (auto arg : args) {
555 argTypes.push_back(Elt: arg.getType());
556 }
557 LLVM::CallOp call = createDeviceFunctionCall(
558 rewriter, funcName, LLVM::LLVMVoidType::get(ctx: rewriter.getContext()),
559 argTypes, args, paramAttrs, funcAttr, op.getOperation());
560 if (std::optional<ArrayAttr> optCacheControls =
561 getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
562 call->setAttr(name: XeVMDialect::getCacheControlsAttrName(), value: *optCacheControls);
563 }
564 if constexpr (isLoad)
565 rewriter.replaceOp(
566 op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
567 else
568 rewriter.eraseOp(op);
569 return success();
570 }
571};
572
573//===----------------------------------------------------------------------===//
574// Pass Definition
575//===----------------------------------------------------------------------===//
576
577struct ConvertXeVMToLLVMPass
578 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
579 using Base::Base;
580
581 void getDependentDialects(DialectRegistry &registry) const override {
582 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
583 }
584
585 void runOnOperation() override {
586 ConversionTarget target(getContext());
587 target.addLegalDialect<LLVM::LLVMDialect>();
588 target.addIllegalDialect<XeVMDialect>();
589 RewritePatternSet patterns(&getContext());
590 populateXeVMToLLVMConversionPatterns(patterns);
591 if (failed(Result: applyPartialConversion(op: getOperation(), target,
592 patterns: std::move(patterns))))
593 signalPassFailure();
594 }
595};
596} // namespace
597
598//===----------------------------------------------------------------------===//
599// ConvertToLLVMPatternInterface implementation
600//===----------------------------------------------------------------------===//
601
602namespace {
603/// Implement the interface to convert XeVM to LLVM.
604struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
605 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
606 void loadDependentDialects(MLIRContext *context) const final {
607 context->loadDialect<LLVM::LLVMDialect>();
608 }
609
610 /// Hook for derived dialect interface to provide conversion patterns
611 /// and mark dialect legal for the conversion target.
612 void populateConvertToLLVMConversionPatterns(
613 ConversionTarget &target, LLVMTypeConverter &typeConverter,
614 RewritePatternSet &patterns) const final {
615 populateXeVMToLLVMConversionPatterns(patterns);
616 }
617};
618} // namespace
619
620//===----------------------------------------------------------------------===//
621// Pattern Population
622//===----------------------------------------------------------------------===//
623
624void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
625 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
626 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
627 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
628 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
629 arg: patterns.getContext());
630}
631
632void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
633 registry.addExtension(extensionFn: +[](MLIRContext *ctx, XeVMDialect *dialect) {
634 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
635 });
636}
637

// Source: mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp