//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers gpu.subgroup_reduce to the nvvm.redux op. The op
/// must be run by the entire subgroup, otherwise it is undefined behaviour.
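///
/// As a rough illustration (syntax abbreviated, not taken from a test), a
/// uniform i32 add reduction
///
///   %r = gpu.subgroup_reduce add %x uniform : (i32) -> i32
///
/// becomes an nvvm.redux.sync with a full-warp membership mask (-1):
///
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %r = nvvm.redux.sync add %x, %mask : i32 -> i32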
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getClusterSize())
      return rewriter.notifyMatchFailure(
          op, "lowering for clustered reduce not implemented");

    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
                                                  mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///   %one = llvm.constant(1 : i32) : i32
  ///   %minus_one = llvm.constant(-1 : i32) : i32
  ///   %thirty_two = llvm.constant(32 : i32) : i32
  ///   %num_lanes = llvm.sub %thirty_two, %width : i32
  ///   %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///   %mask_and_clamp = llvm.sub %width, %one : i32
  ///   %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///       %mask_and_clamp : !llvm<"{ float, i1 }">
  ///   %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///   %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`.
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`.
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

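/// Lowers gpu.lane_id to the NVVM lane-id intrinsic wrapper, annotated with a
/// [0, 32) (or user-provided) value range, then truncates/extends the i32
/// result to the index bitwidth chosen by the type converter. Illustrative
/// only; exact op spellings abbreviated:
///
///   %0 = gpu.lane_id
///   // becomes roughly
///   %1 = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
///   %2 = llvm.sext %1 : i32 to i64   // when the index bitwidth is 64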
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    LLVM::ConstantRangeAttr bounds = nullptr;
    if (std::optional<APInt> upperBound = op.getUpperBound())
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
    else
      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
    Value newOp =
        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Lowering of cf.assert into a conditional __assertfail.
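///
/// The assert is rewritten into a conditional branch around a call to the
/// device-side `__assertfail` routine. As a rough sketch (not verbatim
/// output), `cf.assert %cond, "msg"` becomes:
///
///     cf.cond_br %cond, ^continue, ^assert
///   ^assert:
///     llvm.call @__assertfail(%msg_ptr, %file_ptr, %line, %func_ptr, %c1)
///     cf.br ^continue
///   ^continue:
///     ...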
struct AssertOpToAssertfailLowering
    : public ConvertOpToLLVMPattern<cf::AssertOp> {
  using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = rewriter.getContext();
    Location loc = assertOp.getLoc();
    Type i8Type = typeConverter->convertType(rewriter.getIntegerType(8));
    Type i32Type = typeConverter->convertType(rewriter.getIntegerType(32));
    Type i64Type = typeConverter->convertType(rewriter.getIntegerType(64));
    Type ptrType = LLVM::LLVMPointerType::get(ctx);
    Type voidType = LLVM::LLVMVoidType::get(ctx);

    // Find or create __assertfail function declaration.
    auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
    auto assertfailType = LLVM::LLVMFunctionType::get(
        voidType, {ptrType, ptrType, i32Type, ptrType, i64Type});
    LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
        moduleOp, loc, rewriter, "__assertfail", assertfailType);
    assertfailDecl.setPassthroughAttr(
        ArrayAttr::get(ctx, StringAttr::get(ctx, "noreturn")));

    // Split blocks and insert conditional branch.
    // ^before:
    //   ...
    //   cf.cond_br %condition, ^after, ^assert
    // ^assert:
    //   cf.assert
    //   cf.br ^after
    // ^after:
    //   ...
    Block *beforeBlock = assertOp->getBlock();
    Block *assertBlock =
        rewriter.splitBlock(beforeBlock, assertOp->getIterator());
    Block *afterBlock =
        rewriter.splitBlock(assertBlock, ++assertOp->getIterator());
    rewriter.setInsertionPointToEnd(beforeBlock);
    rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), afterBlock,
                                      assertBlock);
    rewriter.setInsertionPointToEnd(assertBlock);
    rewriter.create<cf::BranchOp>(loc, afterBlock);

    // Continue cf.assert lowering.
    rewriter.setInsertionPoint(assertOp);

    // Populate file name, line number and function name from the location of
    // the AssertOp.
    StringRef fileName = "(unknown)";
    StringRef funcName = "(unknown)";
    int32_t fileLine = 0;
    while (auto callSiteLoc = dyn_cast<CallSiteLoc>(loc))
      loc = callSiteLoc.getCallee();
    if (auto fileLineColLoc = dyn_cast<FileLineColRange>(loc)) {
      fileName = fileLineColLoc.getFilename().strref();
      fileLine = fileLineColLoc.getStartLine();
    } else if (auto nameLoc = dyn_cast<NameLoc>(loc)) {
      funcName = nameLoc.getName().strref();
      if (auto fileLineColLoc =
              dyn_cast<FileLineColRange>(nameLoc.getChildLoc())) {
        fileName = fileLineColLoc.getFilename().strref();
        fileLine = fileLineColLoc.getStartLine();
      }
    }

    // Create constants.
    auto getGlobal = [&](LLVM::GlobalOp global) {
      // Get a pointer to the format string's first element.
      Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
          loc, LLVM::LLVMPointerType::get(ctx, global.getAddrSpace()),
          global.getSymNameAttr());
      Value start =
          rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
                                       globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
      return start;
    };
    Value assertMessage = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_message_", assertOp.getMsg()));
    Value assertFile = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_file_", fileName));
    Value assertFunc = getGlobal(getOrCreateStringConstant(
        rewriter, loc, moduleOp, i8Type, "assert_func_", funcName));
    Value assertLine =
        rewriter.create<LLVM::ConstantOp>(loc, i32Type, fileLine);
    Value c1 = rewriter.create<LLVM::ConstantOp>(loc, i64Type, 1);

    // Insert function call to __assertfail.
    SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
                                 assertFunc, c1};
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(assertOp, assertfailDecl,
                                              arguments);
    return success();
  }
};


/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
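///
/// The pass runs on gpu.module ops and is driven through the generated pass
/// registration (see GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS above), e.g. via
/// `mlir-opt --convert-gpu-to-nvvm` (flag name as registered upstream;
/// mentioned here only for orientation).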
struct LowerGpuOpsToNVVMOpsPass final
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
    registerConvertToLLVMDependentDialectLoading(registry);
  }

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    configureGpuToNVVMTypeConverter(converter);
    RewritePatternSet llvmPatterns(m.getContext());
    LLVMConversionTarget target(getContext());

    // Set a higher benefit so these patterns run before generic LLVM lowering.
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns,
                                        /*benefit=*/10);

    llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
                                                      allowedDialects.end());
    for (Dialect *dialect : getContext().getLoadedDialects()) {
      // Skip math patterns as nvvm needs custom math lowering.
      if (isa<math::MathDialect>(dialect))
        continue;

      bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
      // Empty `allowedDialectsSet` means all dialects are allowed.
      if (!allowedDialectsSet.empty() && !allowed)
        continue;

      auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
      if (!iface) {
        // Error out if the dialect was explicitly specified but does not
        // implement the conversion interface.
        if (allowed) {
          m.emitError()
              << "dialect does not implement ConvertToLLVMPatternInterface: "
              << dialect->getNamespace();
          return signalPassFailure();
        }
        continue;
      }

      iface->populateConvertToLLVMConversionPatterns(target, converter,
                                                     llvmPatterns);
    }

    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addIllegalOp<cf::AssertOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
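  // Keep these LLVM math ops illegal so that math/arith ops are not folded
  // into LLVM intrinsics the NVPTX path cannot handle; the expectation (see
  // populateLibDeviceConversionPatterns below) is that they are lowered to
  // libdevice `__nv_*` calls instead.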
  target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
                      LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FMAOp,
                      LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op,
                      LLVM::PowOp, LLVM::RoundEvenOp, LLVM::RoundOp,
                      LLVM::SinOp, LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
}

void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
  // NVVM uses alloca in the default address space to represent private
  // memory allocations, so drop private annotations. NVVM uses address
  // space 3 for shared memory. NVVM uses the default address space to
  // represent global memory.
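  // For example (illustrative, not verbatim converter output): a
  // `memref<16xf32, #gpu.address_space<workgroup>>` argument ends up as an
  // `!llvm.ptr<3>`-based descriptor, while a private attribution maps to the
  // default address space 0.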
  populateGpuMemorySpaceAttributeConversions(
      converter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return 0;
      });
  // Lowering for MMAMatrixType.
  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
    return convertMMAToLLVMType(type);
  });
}

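/// Registers, for `OpTy`, a vector-scalarization pattern plus a pattern that
/// rewrites the scalar op into a call to the named libdevice function,
/// selected by operand type (f32/f64/approximate-f32/f16). For example, the
/// call below (taken from populateLibDeviceConversionPatterns) makes
/// `math.exp` on f32 lower to a call to `__nv_expf`:
///
///   populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
///                                   "__nv_exp", "__nv_fast_expf");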
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
                               RewritePatternSet &patterns,
                               PatternBenefit benefit, StringRef f32Func,
                               StringRef f64Func, StringRef f32ApproxFunc = "",
                               StringRef f16Func = "") {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
                                           f32ApproxFunc, f16Func,
                                           /*i32Func=*/"", benefit);
}

template <typename OpTy>
static void populateIntOpPatterns(const LLVMTypeConverter &converter,
                                  RewritePatternSet &patterns,
                                  PatternBenefit benefit, StringRef i32Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
                                           benefit);
}

template <typename OpTy>
static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
                                       RewritePatternSet &patterns,
                                       PatternBenefit benefit,
                                       StringRef f32Func, StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
                                           /*i32Func=*/"", benefit);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}

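/// Maps math (and a few arith) ops onto their libdevice `__nv_*` equivalents.
/// Where a third, `__nv_fast_*` name is supplied, OpToFuncCallLowering can use
/// it for the f32 case when the op's fast-math flags permit approximate
/// functions (see OpToFuncCallLowering for the exact condition).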
void mlir::populateLibDeviceConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
                                       "__nv_fmaxf", "__nv_fmax");
  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
                                       "__nv_fminf", "__nv_fmin");

  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
                                   "__nv_acos");
  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
                                    "__nv_acosh");
  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
                                   "__nv_asin");
  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
                                    "__nv_asinh");
  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
                                    "__nv_atanh");
  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
                                       "__nv_copysignf", "__nv_copysign");
  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
                                  "__nv_cos", "__nv_fast_cosf");
  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
                                   "__nv_cosh");
  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
                                  "__nv_erf");
  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
                                   "__nv_erfc");
  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
                                  "__nv_exp", "__nv_fast_expf");
  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
                                  "__nv_fma");
  // Note: libdevice uses a different name for 32-bit finite checking.
  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
                                       "__nv_finitef", "__nv_isfinited");
  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
                                    "__nv_isinfd");
  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
                                    "__nv_isnand");
  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
                                  "__nv_log", "__nv_fast_logf");
  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
                                    "__nv_log10", "__nv_fast_log10f");
  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
                                   "__nv_log2", "__nv_fast_log2f");
  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
                                   "__nv_pow", "__nv_fast_powf");
  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
                                            "__nv_powif", "__nv_powi");
  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
                                    "__nv_round");
  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
                                        "__nv_rintf", "__nv_rint");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
                                  "__nv_sin", "__nv_fast_sinf");
  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
                                   "__nv_sinh");
  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
                                  "__nv_tan", "__nv_fast_tanf");
  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                   "__nv_tanh");
}

void mlir::populateGpuToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    PatternBenefit benefit) {
  using gpu::index_lowering::IndexKind;
  using gpu::index_lowering::IntrType;

  // TODO: Pass benefit to generated patterns.
  populateWithGenerated(patterns);

  patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
      converter, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
      converter, IndexKind::Block, IntrType::Id, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
      converter, IndexKind::Block, IntrType::Dim, benefit);
  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim,
                            benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
      converter, IndexKind::Other, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
      converter, IndexKind::Other, IntrType::Dim, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
      converter, IndexKind::Grid, IntrType::Id, benefit);
  patterns.add<gpu::index_lowering::OpLowering<
      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
      converter, IndexKind::Grid, IntrType::Dim, benefit);
  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
      converter, benefit);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit, benefit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter,
      GPUFuncOpLoweringOptions{
          /*allocaAddrSpace=*/0,
          /*workgroupAddrSpace=*/
          static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getKernelFuncAttrName()),
          StringAttr::get(&converter.getContext(),
                          NVVM::NVVMDialect::getMaxntidAttrName())},
      benefit);

  populateLibDeviceConversionPatterns(converter, patterns, benefit);
}

//===----------------------------------------------------------------------===//
// NVVMTargetAttr convert to LLVM attr interface
//===----------------------------------------------------------------------===//

namespace {
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
} // namespace

void NVVMTargetConvertToLLVMAttrInterface::
    populateConvertToLLVMConversionPatterns(Attribute attr,
                                            ConversionTarget &target,
                                            LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns) const {
  configureGpuToNVVMConversionLegality(target);
  configureGpuToNVVMTypeConverter(typeConverter);
  populateGpuToNVVMConversionPatterns(typeConverter, patterns);
}

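/// Attaches the ConvertToLLVMAttrInterface external model above to
/// NVVMTargetAttr, so the generic convert-to-llvm infrastructure picks up the
/// GPU-to-NVVM legality, type conversions, and patterns for NVVM-targeted
/// modules. A caller would typically do something like the following before
/// creating the context (sketch only; the registry and context names are
/// placeholders):
///
///   DialectRegistry registry;
///   registry.insert<NVVM::NVVMDialect>();
///   NVVM::registerConvertGpuToNVVMInterface(registry);
///   MLIRContext context(registry);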
void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) {
    NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(*ctx);
  });
}