//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm-spv"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

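/// Look up the SPIR-V builtin function `name` in the enclosing symbol table,
/// declaring it with the `spir_func` calling convention and the requested
/// attributes if no declaration exists yet.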
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = b.create<LLVM::LLVMFuncOp>(
        symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // no externally observable effects
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef,
          /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }
  return func;
}

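/// Create a call to the SPIR-V builtin `func`, copying the callee's calling
/// convention and function attributes onto the call site.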
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}

namespace {
//===----------------------------------------------------------------------===//
// Barriers
//===----------------------------------------------------------------------===//

/// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
/// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
/// ```
/// // gpu.barrier
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
/// ```
struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringLiteral funcName = "_Z7barrierj";

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type flagTy = rewriter.getI32Type();
    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
                              /*isMemNone=*/false, /*isConvergent=*/true);

    // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
    // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
    constexpr int64_t localMemFenceFlag = 1;
    Location loc = op->getLoc();
    Value flag =
        rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// SPIR-V Builtins
//===----------------------------------------------------------------------===//

/// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin
/// with a constant argument for the `dimension` attribute. Return type will
/// depend on the index width option:
/// ```
/// // %thread_id_y = gpu.thread_id y
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
/// ```
struct LaunchConfigConversion : ConvertToLLVMPattern {
  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
                         MLIRContext *context,
                         const LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
        funcName(funcName) {}

  virtual gpu::Dimension getDimension(Operation *op) const = 0;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type dimTy = rewriter.getI32Type();
    Type indexTy = getTypeConverter()->getIndexType();
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
                                                  indexTy, /*isMemNone=*/true,
                                                  /*isConvergent=*/false);

    Location loc = op->getLoc();
    gpu::Dimension dim = getDimension(op);
    Value dimVal = rewriter.create<LLVM::ConstantOp>(
        loc, dimTy, static_cast<int64_t>(dim));
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
    return success();
  }

  StringRef funcName;
};

template <typename SourceOp>
struct LaunchConfigOpConversion final : LaunchConfigConversion {
  static StringRef getFuncName();

  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
                               &typeConverter.getContext(), typeConverter,
                               benefit) {}

  gpu::Dimension getDimension(Operation *op) const final {
    return cast<SourceOp>(op).getDimension();
  }
};

template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}

//===----------------------------------------------------------------------===//
// Shuffles
//===----------------------------------------------------------------------===//

/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
/// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a
/// `true` constant for the `valid` result type. Conversion will only take
/// place if `width` is constant and equal to the `subgroup` pass option:
/// ```
/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
///     : (f64, i32) -> f64
/// ```
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

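  /// Itanium-mangled parameter suffixes for the supported value types; the
  /// trailing `j` mangles the builtin's `unsigned int` offset argument.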
  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case<Float16Type>([](auto) { return "Dhj"; })
        .Case<Float32Type>([](auto) { return "fj"; })
        .Case<Float64Type>([](auto) { return "dj"; })
        .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          }
          return std::nullopt;
        })
        .Default([](auto) { return std::nullopt; });
  }

  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
                                                Type type) {
    StringRef baseName = getBaseName(mode);
    std::optional<StringRef> typeMangling = getTypeMangling(type);
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{}{}{}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Get the subgroup size from the target or return a default.
  static std::optional<int> getSubgroupSize(Operation *op) {
    auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!parentFunc)
      return std::nullopt;
    return parentFunc.getIntelReqdSubGroupSize();
  }

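  /// Check that the shuffle `width` is a compile-time constant equal to the
  /// required subgroup size of the enclosing function.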
  static bool hasValidWidth(gpu::ShuffleOp op) {
    llvm::APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }

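  /// The shuffle builtins have no `bf16` or `i1` overloads, so shuffle `bf16`
  /// values as their `i16` bit pattern and `i1` values zero-extended to `i8`.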
  static Value bitcastOrExtBeforeShuffle(Value oldVal, Location loc,
                                         ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(oldVal.getType())
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, rewriter.getI16Type(),
                                                  oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::ZExtOp>(loc, rewriter.getI8Type(),
                                                 oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

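  /// Undo the conversion performed by `bitcastOrExtBeforeShuffle`, bitcasting
  /// back to `bf16` or truncating back to `i1` after the call.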
  static Value bitcastOrTruncAfterShuffle(Value oldVal, Type newTy,
                                          Location loc,
                                          ConversionPatternRewriter &rewriter) {
    return TypeSwitch<Type, Value>(newTy)
        .Case([&](BFloat16Type) {
          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
        })
        .Case([&](IntegerType intTy) -> Value {
          if (intTy.getWidth() == 1)
            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
          return oldVal;
        })
        .Default(oldVal);
  }

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    Location loc = op->getLoc();
    Value inValue =
        bitcastOrExtBeforeShuffle(adaptor.getValue(), loc, rewriter);
    std::optional<std::string> funcName =
        getFuncName(op.getMode(), inValue.getType());
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = inValue.getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    std::array<Value, 2> args{inValue, adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value resultOrConversion =
        bitcastOrTruncAfterShuffle(result, op.getType(0), loc, rewriter);

    Value trueVal =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {resultOrConversion, trueVal});
    return success();
  }
};

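/// Type converter assigning the OpenCL global (`CrossWorkgroup`) address
/// space to memref types that carry no memory space attribute.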
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Attach global addr space attribute to memrefs with no addr space attr
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};

//===----------------------------------------------------------------------===//
// Subgroup query ops.
//===----------------------------------------------------------------------===//

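/// Replace subgroup query ops (`gpu.subgroup_id`, `gpu.lane_id`,
/// `gpu.num_subgroups` and `gpu.subgroup_size`) with an `llvm.call` to the
/// corresponding OpenCL builtin, zero-extending the `i32` result to the index
/// type when the index width is 64 bits:
/// ```
/// // %sg_id = gpu.subgroup_id : index
/// %0 = llvm.call spir_funccc @_Z16get_sub_group_id() : () -> i32
/// %1 = llvm.zext %0 : i32 to i64
/// ```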
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
        return failure();
      }
      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Pass.
//===----------------------------------------------------------------------===//

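/// Pass converting GPU operations to calls to SPIR-V builtin functions in the
/// LLVM dialect. Memrefs without an explicit memory space are first rewritten
/// to use the OpenCL global address space.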
struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  using Base::Base;

  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Force OpenCL address spaces when they are not present
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });

      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Patterns.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace {
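/// Map a `gpu` dialect address space to the corresponding OpenCL address
/// space by going through the SPIR-V storage class.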
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}
} // namespace

void populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
          LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}
} // namespace mlir