1//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to generate NVVMIR operations for higher-level
10// GPU operations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
15#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
16#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
18#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
19#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
20#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
21#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/GPU/IR/GPUDialect.h"
25#include "mlir/Dialect/GPU/Transforms/Passes.h"
26#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
27#include "mlir/Dialect/Math/IR/Math.h"
28#include "mlir/Dialect/MemRef/IR/MemRef.h"
29#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
30#include "mlir/Transforms/DialectConversion.h"
31#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32
33#include "../GPUCommon/GPUOpsLowering.h"
34#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
35#include "../GPUCommon/OpToFuncCallLowering.h"
36#include <optional>
37
38namespace mlir {
39#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
40#include "mlir/Conversion/Passes.h.inc"
41} // namespace mlir
42
43using namespace mlir;
44
45namespace {
46
47/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
48static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
49 switch (mode) {
50 case gpu::ShuffleMode::XOR:
51 return NVVM::ShflKind::bfly;
52 case gpu::ShuffleMode::UP:
53 return NVVM::ShflKind::up;
54 case gpu::ShuffleMode::DOWN:
55 return NVVM::ShflKind::down;
56 case gpu::ShuffleMode::IDX:
57 return NVVM::ShflKind::idx;
58 }
59 llvm_unreachable("unknown shuffle mode");
60}
61
62static std::optional<NVVM::ReduxKind>
63convertReduxKind(gpu::AllReduceOperation mode) {
64 switch (mode) {
65 case gpu::AllReduceOperation::ADD:
66 return NVVM::ReduxKind::ADD;
67 case gpu::AllReduceOperation::MUL:
68 return std::nullopt;
69 case gpu::AllReduceOperation::MINSI:
70 return NVVM::ReduxKind::MIN;
71 case gpu::AllReduceOperation::MINUI:
72 return std::nullopt;
73 case gpu::AllReduceOperation::MINNUMF:
74 return NVVM::ReduxKind::MIN;
75 case gpu::AllReduceOperation::MAXSI:
76 return NVVM::ReduxKind::MAX;
77 case gpu::AllReduceOperation::MAXUI:
78 return std::nullopt;
79 case gpu::AllReduceOperation::MAXNUMF:
80 return NVVM::ReduxKind::MAX;
81 case gpu::AllReduceOperation::AND:
82 return NVVM::ReduxKind::AND;
83 case gpu::AllReduceOperation::OR:
84 return NVVM::ReduxKind::OR;
85 case gpu::AllReduceOperation::XOR:
86 return NVVM::ReduxKind::XOR;
87 case gpu::AllReduceOperation::MINIMUMF:
88 case gpu::AllReduceOperation::MAXIMUMF:
89 return std::nullopt;
90 }
91 return std::nullopt;
92}
93
94/// This pass lowers gpu.subgroup_reduce op into to the nvvm.redux op. The op
95/// must be run by the entire subgroup, otherwise it is undefined behaviour.
96struct GPUSubgroupReduceOpLowering
97 : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
98 using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
99 LogicalResult
100
101 matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const override {
103 if (op.getClusterSize())
104 return rewriter.notifyMatchFailure(
105 arg&: op, msg: "lowering for clustered reduce not implemented");
106
107 if (!op.getUniform())
108 return rewriter.notifyMatchFailure(
109 arg&: op, msg: "cannot be lowered to redux as the op must be run "
110 "uniformly (entire subgroup).");
111 if (!op.getValue().getType().isInteger(width: 32))
112 return rewriter.notifyMatchFailure(arg&: op, msg: "unsupported data type");
113
114 std::optional<NVVM::ReduxKind> mode = convertReduxKind(mode: op.getOp());
115 if (!mode.has_value())
116 return rewriter.notifyMatchFailure(
117 arg&: op, msg: "unsupported reduction mode for redux");
118
119 Location loc = op->getLoc();
120 auto int32Type = IntegerType::get(context: rewriter.getContext(), width: 32);
121 Value offset = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int32Type, args: -1);
122
123 auto reduxOp = rewriter.create<NVVM::ReduxOp>(location: loc, args&: int32Type, args: op.getValue(),
124 args&: mode.value(), args&: offset);
125
126 rewriter.replaceOp(op, newValues: reduxOp->getResult(idx: 0));
127 return success();
128 }
129};
130
131struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
132 using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
133
134 /// Lowers a shuffle to the corresponding NVVM op.
135 ///
136 /// Convert the `width` argument into an activeMask (a bitmask which specifies
137 /// which threads participate in the shuffle) and a maskAndClamp (specifying
138 /// the highest lane which participates in the shuffle).
139 ///
140 /// %one = llvm.constant(1 : i32) : i32
141 /// %minus_one = llvm.constant(-1 : i32) : i32
142 /// %thirty_two = llvm.constant(32 : i32) : i32
143 /// %num_lanes = llvm.sub %thirty_two, %width : i32
144 /// %active_mask = llvm.lshr %minus_one, %num_lanes : i32
145 /// %mask_and_clamp = llvm.sub %width, %one : i32
146 /// %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
147 /// %mask_and_clamp : !llvm<"{ float, i1 }">
148 /// %shfl_value = llvm.extractvalue %shfl[0] :
149 /// !llvm<"{ float, i1 }">
150 /// %shfl_pred = llvm.extractvalue %shfl[1] :
151 /// !llvm<"{ float, i1 }">
152 LogicalResult
153 matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
154 ConversionPatternRewriter &rewriter) const override {
155 Location loc = op->getLoc();
156
157 auto valueTy = adaptor.getValue().getType();
158 auto int32Type = IntegerType::get(context: rewriter.getContext(), width: 32);
159 auto predTy = IntegerType::get(context: rewriter.getContext(), width: 1);
160
161 Value one = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int32Type, args: 1);
162 Value minusOne = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int32Type, args: -1);
163 Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(location: loc, args&: int32Type, args: 32);
164 Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
165 location: loc, args&: int32Type, args&: thirtyTwo, args: adaptor.getWidth());
166 // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
167 Value activeMask = rewriter.create<LLVM::LShrOp>(location: loc, args&: int32Type, args&: minusOne,
168 args&: numLeadInactiveLane);
169 Value maskAndClamp;
170 if (op.getMode() == gpu::ShuffleMode::UP) {
171 // Clamp lane: `32 - activeWidth`
172 maskAndClamp = numLeadInactiveLane;
173 } else {
174 // Clamp lane: `activeWidth - 1`
175 maskAndClamp =
176 rewriter.create<LLVM::SubOp>(location: loc, args&: int32Type, args: adaptor.getWidth(), args&: one);
177 }
178
179 bool predIsUsed = !op->getResult(idx: 1).use_empty();
180 UnitAttr returnValueAndIsValidAttr = nullptr;
181 Type resultTy = valueTy;
182 if (predIsUsed) {
183 returnValueAndIsValidAttr = rewriter.getUnitAttr();
184 resultTy = LLVM::LLVMStructType::getLiteral(context: rewriter.getContext(),
185 types: {valueTy, predTy});
186 }
187 Value shfl = rewriter.create<NVVM::ShflOp>(
188 location: loc, args&: resultTy, args&: activeMask, args: adaptor.getValue(), args: adaptor.getOffset(),
189 args&: maskAndClamp, args: convertShflKind(mode: op.getMode()), args&: returnValueAndIsValidAttr);
190 if (predIsUsed) {
191 Value shflValue = rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: shfl, args: 0);
192 Value isActiveSrcLane =
193 rewriter.create<LLVM::ExtractValueOp>(location: loc, args&: shfl, args: 1);
194 rewriter.replaceOp(op, newValues: {shflValue, isActiveSrcLane});
195 } else {
196 rewriter.replaceOp(op, newValues: {shfl, nullptr});
197 }
198 return success();
199 }
200};
201
202struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
203 using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
204
205 LogicalResult
206 matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
207 ConversionPatternRewriter &rewriter) const override {
208 auto loc = op->getLoc();
209 MLIRContext *context = rewriter.getContext();
210 LLVM::ConstantRangeAttr bounds = nullptr;
211 if (std::optional<APInt> upperBound = op.getUpperBound())
212 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
213 /*bitWidth=*/args: 32, /*lower=*/args: 0, args: upperBound->getZExtValue());
214 else
215 bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
216 /*bitWidth=*/args: 32, /*lower=*/args: 0, /*upper=*/args: kWarpSize);
217 Value newOp =
218 rewriter.create<NVVM::LaneIdOp>(location: loc, args: rewriter.getI32Type(), args&: bounds);
219 // Truncate or extend the result depending on the index bitwidth specified
220 // by the LLVMTypeConverter options.
221 const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
222 if (indexBitwidth > 32) {
223 newOp = rewriter.create<LLVM::SExtOp>(
224 location: loc, args: IntegerType::get(context, width: indexBitwidth), args&: newOp);
225 } else if (indexBitwidth < 32) {
226 newOp = rewriter.create<LLVM::TruncOp>(
227 location: loc, args: IntegerType::get(context, width: indexBitwidth), args&: newOp);
228 }
229 rewriter.replaceOp(op, newValues: {newOp});
230 return success();
231 }
232};
233
234/// Lowering of cf.assert into a conditional __assertfail.
235struct AssertOpToAssertfailLowering
236 : public ConvertOpToLLVMPattern<cf::AssertOp> {
237 using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
238
239 LogicalResult
240 matchAndRewrite(cf::AssertOp assertOp, cf::AssertOpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 MLIRContext *ctx = rewriter.getContext();
243 Location loc = assertOp.getLoc();
244 Type i8Type = typeConverter->convertType(t: rewriter.getIntegerType(width: 8));
245 Type i32Type = typeConverter->convertType(t: rewriter.getIntegerType(width: 32));
246 Type i64Type = typeConverter->convertType(t: rewriter.getIntegerType(width: 64));
247 Type ptrType = LLVM::LLVMPointerType::get(context: ctx);
248 Type voidType = LLVM::LLVMVoidType::get(ctx);
249
250 // Find or create __assertfail function declaration.
251 auto moduleOp = assertOp->getParentOfType<gpu::GPUModuleOp>();
252 auto assertfailType = LLVM::LLVMFunctionType::get(
253 result: voidType, arguments: {ptrType, ptrType, i32Type, ptrType, i64Type});
254 LLVM::LLVMFuncOp assertfailDecl = getOrDefineFunction(
255 moduleOp, loc, b&: rewriter, name: "__assertfail", type: assertfailType);
256 assertfailDecl.setPassthroughAttr(
257 ArrayAttr::get(context: ctx, value: StringAttr::get(context: ctx, bytes: "noreturn")));
258
259 // Split blocks and insert conditional branch.
260 // ^before:
261 // ...
262 // cf.cond_br %condition, ^after, ^assert
263 // ^assert:
264 // cf.assert
265 // cf.br ^after
266 // ^after:
267 // ...
268 Block *beforeBlock = assertOp->getBlock();
269 Block *assertBlock =
270 rewriter.splitBlock(block: beforeBlock, before: assertOp->getIterator());
271 Block *afterBlock =
272 rewriter.splitBlock(block: assertBlock, before: ++assertOp->getIterator());
273 rewriter.setInsertionPointToEnd(beforeBlock);
274 rewriter.create<cf::CondBranchOp>(location: loc, args: adaptor.getArg(), args&: afterBlock,
275 args&: assertBlock);
276 rewriter.setInsertionPointToEnd(assertBlock);
277 rewriter.create<cf::BranchOp>(location: loc, args&: afterBlock);
278
279 // Continue cf.assert lowering.
280 rewriter.setInsertionPoint(assertOp);
281
282 // Populate file name, file number and function name from the location of
283 // the AssertOp.
284 StringRef fileName = "(unknown)";
285 StringRef funcName = "(unknown)";
286 int32_t fileLine = 0;
287 while (auto callSiteLoc = dyn_cast<CallSiteLoc>(Val&: loc))
288 loc = callSiteLoc.getCallee();
289 if (auto fileLineColLoc = dyn_cast<FileLineColRange>(Val&: loc)) {
290 fileName = fileLineColLoc.getFilename().strref();
291 fileLine = fileLineColLoc.getStartLine();
292 } else if (auto nameLoc = dyn_cast<NameLoc>(Val&: loc)) {
293 funcName = nameLoc.getName().strref();
294 if (auto fileLineColLoc =
295 dyn_cast<FileLineColRange>(Val: nameLoc.getChildLoc())) {
296 fileName = fileLineColLoc.getFilename().strref();
297 fileLine = fileLineColLoc.getStartLine();
298 }
299 }
300
301 // Create constants.
302 auto getGlobal = [&](LLVM::GlobalOp global) {
303 // Get a pointer to the format string's first element.
304 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
305 location: loc, args: LLVM::LLVMPointerType::get(context: ctx, addressSpace: global.getAddrSpace()),
306 args: global.getSymNameAttr());
307 Value start =
308 rewriter.create<LLVM::GEPOp>(location: loc, args&: ptrType, args: global.getGlobalType(),
309 args&: globalPtr, args: ArrayRef<LLVM::GEPArg>{0, 0});
310 return start;
311 };
312 Value assertMessage = getGlobal(getOrCreateStringConstant(
313 b&: rewriter, loc, moduleOp, llvmI8: i8Type, namePrefix: "assert_message_", str: assertOp.getMsg()));
314 Value assertFile = getGlobal(getOrCreateStringConstant(
315 b&: rewriter, loc, moduleOp, llvmI8: i8Type, namePrefix: "assert_file_", str: fileName));
316 Value assertFunc = getGlobal(getOrCreateStringConstant(
317 b&: rewriter, loc, moduleOp, llvmI8: i8Type, namePrefix: "assert_func_", str: funcName));
318 Value assertLine =
319 rewriter.create<LLVM::ConstantOp>(location: loc, args&: i32Type, args&: fileLine);
320 Value c1 = rewriter.create<LLVM::ConstantOp>(location: loc, args&: i64Type, args: 1);
321
322 // Insert function call to __assertfail.
323 SmallVector<Value> arguments{assertMessage, assertFile, assertLine,
324 assertFunc, c1};
325 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op: assertOp, args&: assertfailDecl,
326 args&: arguments);
327 return success();
328 }
329};
330
331/// Import the GPU Ops to NVVM Patterns.
332#include "GPUToNVVM.cpp.inc"
333
334/// A pass that replaces all occurrences of GPU device operations with their
335/// corresponding NVVM equivalent.
336///
337/// This pass only handles device code and is not meant to be run on GPU host
338/// code.
339struct LowerGpuOpsToNVVMOpsPass final
340 : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
341 using Base::Base;
342
343 void getDependentDialects(DialectRegistry &registry) const override {
344 Base::getDependentDialects(registry);
345 registerConvertToLLVMDependentDialectLoading(registry);
346 }
347
348 void runOnOperation() override {
349 gpu::GPUModuleOp m = getOperation();
350
351 // Request C wrapper emission.
352 for (auto func : m.getOps<func::FuncOp>()) {
353 func->setAttr(name: LLVM::LLVMDialect::getEmitCWrapperAttrName(),
354 value: UnitAttr::get(context: &getContext()));
355 }
356
357 // Customize the bitwidth used for the device side index computations.
358 LowerToLLVMOptions options(
359 m.getContext(),
360 DataLayout(cast<DataLayoutOpInterface>(Val: m.getOperation())));
361 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
362 options.overrideIndexBitwidth(bitwidth: indexBitwidth);
363 options.useBarePtrCallConv = useBarePtrCallConv;
364
365 // Apply in-dialect lowering. In-dialect lowering will replace
366 // ops which need to be lowered further, which is not supported by a
367 // single conversion pass.
368 {
369 RewritePatternSet patterns(m.getContext());
370 populateGpuRewritePatterns(patterns);
371 if (failed(Result: applyPatternsGreedily(op: m, patterns: std::move(patterns))))
372 return signalPassFailure();
373 }
374
375 LLVMTypeConverter converter(m.getContext(), options);
376 configureGpuToNVVMTypeConverter(converter);
377 RewritePatternSet llvmPatterns(m.getContext());
378 LLVMConversionTarget target(getContext());
379
380 // Set higher benefit, so patterns will run before generic LLVM lowering.
381 populateGpuToNVVMConversionPatterns(converter, patterns&: llvmPatterns,
382 /*benefit=*/10);
383
384 llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
385 allowedDialects.end());
386 for (Dialect *dialect : getContext().getLoadedDialects()) {
387 // Skip math patterns as nvvm needs custom math lowering.
388 if (isa<math::MathDialect>(Val: dialect))
389 continue;
390
391 bool allowed = allowedDialectsSet.contains(V: dialect->getNamespace());
392 // Empty `allowedDialectsSet` means all dialects are allowed.
393 if (!allowedDialectsSet.empty() && !allowed)
394 continue;
395
396 auto iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
397 if (!iface) {
398 // Error out if dialect was explicily specified but doesn't implement
399 // conversion interface.
400 if (allowed) {
401 m.emitError()
402 << "dialect does not implement ConvertToLLVMPatternInterface: "
403 << dialect->getNamespace();
404 return signalPassFailure();
405 }
406 continue;
407 }
408
409 iface->populateConvertToLLVMConversionPatterns(target, typeConverter&: converter,
410 patterns&: llvmPatterns);
411 }
412
413 populateGpuWMMAToNVVMConversionPatterns(converter, patterns&: llvmPatterns);
414 if (this->hasRedux)
415 populateGpuSubgroupReduceOpLoweringPattern(converter, patterns&: llvmPatterns);
416 configureGpuToNVVMConversionLegality(target);
417 if (failed(Result: applyPartialConversion(op: m, target, patterns: std::move(llvmPatterns))))
418 signalPassFailure();
419 }
420};
421
422} // namespace
423
424void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
425 target.addIllegalOp<func::FuncOp>();
426 target.addIllegalOp<cf::AssertOp>();
427 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
428 target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
429 target.addIllegalDialect<gpu::GPUDialect>();
430 target.addIllegalOp<LLVM::CopySignOp, LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op,
431 LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
432 LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
433 LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
434 LLVM::SqrtOp>();
435
436 // TODO: Remove once we support replacing non-root ops.
437 target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
438}
439
440void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
441 // NVVM uses alloca in the default address space to represent private
442 // memory allocations, so drop private annotations. NVVM uses address
443 // space 3 for shared memory. NVVM uses the default address space to
444 // represent global memory.
445 populateGpuMemorySpaceAttributeConversions(
446 typeConverter&: converter, mapping: [](gpu::AddressSpace space) -> unsigned {
447 switch (space) {
448 case gpu::AddressSpace::Global:
449 return static_cast<unsigned>(
450 NVVM::NVVMMemorySpace::kGlobalMemorySpace);
451 case gpu::AddressSpace::Workgroup:
452 return static_cast<unsigned>(
453 NVVM::NVVMMemorySpace::kSharedMemorySpace);
454 case gpu::AddressSpace::Private:
455 return 0;
456 }
457 llvm_unreachable("unknown address space enum value");
458 return 0;
459 });
460 // Lowering for MMAMatrixType.
461 converter.addConversion(callback: [&](gpu::MMAMatrixType type) -> Type {
462 return convertMMAToLLVMType(type);
463 });
464}
465
466template <typename OpTy>
467static void populateOpPatterns(const LLVMTypeConverter &converter,
468 RewritePatternSet &patterns,
469 PatternBenefit benefit, StringRef f32Func,
470 StringRef f64Func, StringRef f32ApproxFunc = "",
471 StringRef f16Func = "") {
472 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
473 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
474 f32ApproxFunc, f16Func,
475 /*i32Func=*/"", benefit);
476}
477
478template <typename OpTy>
479static void populateIntOpPatterns(const LLVMTypeConverter &converter,
480 RewritePatternSet &patterns,
481 PatternBenefit benefit, StringRef i32Func) {
482 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
483 patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
484 benefit);
485}
486
487template <typename OpTy>
488static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
489 RewritePatternSet &patterns,
490 PatternBenefit benefit,
491 StringRef f32Func, StringRef f64Func) {
492 patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
493 patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
494 /*i32Func=*/"", benefit);
495}
496
497void mlir::populateGpuSubgroupReduceOpLoweringPattern(
498 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
499 PatternBenefit benefit) {
500 patterns.add<GPUSubgroupReduceOpLowering>(arg: converter, args&: benefit);
501}
502
503void mlir::populateLibDeviceConversionPatterns(
504 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
505 PatternBenefit benefit) {
506 populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, f32Func: "__nv_fmodf",
507 f64Func: "__nv_fmod");
508 populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
509 f32Func: "__nv_fmaxf", f64Func: "__nv_fmax");
510 populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
511 f32Func: "__nv_fminf", f64Func: "__nv_fmin");
512
513 populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, i32Func: "__nv_abs");
514 populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, f32Func: "__nv_fabsf",
515 f64Func: "__nv_fabs");
516 populateOpPatterns<math::AcosOp>(converter, patterns, benefit, f32Func: "__nv_acosf",
517 f64Func: "__nv_acos");
518 populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, f32Func: "__nv_acoshf",
519 f64Func: "__nv_acosh");
520 populateOpPatterns<math::AsinOp>(converter, patterns, benefit, f32Func: "__nv_asinf",
521 f64Func: "__nv_asin");
522 populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, f32Func: "__nv_asinhf",
523 f64Func: "__nv_asinh");
524 populateOpPatterns<math::AtanOp>(converter, patterns, benefit, f32Func: "__nv_atanf",
525 f64Func: "__nv_atan");
526 populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, f32Func: "__nv_atan2f",
527 f64Func: "__nv_atan2");
528 populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, f32Func: "__nv_atanhf",
529 f64Func: "__nv_atanh");
530 populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, f32Func: "__nv_cbrtf",
531 f64Func: "__nv_cbrt");
532 populateOpPatterns<math::CeilOp>(converter, patterns, benefit, f32Func: "__nv_ceilf",
533 f64Func: "__nv_ceil");
534 populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
535 f32Func: "__nv_copysignf", f64Func: "__nv_copysign");
536 populateOpPatterns<math::CosOp>(converter, patterns, benefit, f32Func: "__nv_cosf",
537 f64Func: "__nv_cos", f32ApproxFunc: "__nv_fast_cosf");
538 populateOpPatterns<math::CoshOp>(converter, patterns, benefit, f32Func: "__nv_coshf",
539 f64Func: "__nv_cosh");
540 populateOpPatterns<math::ErfOp>(converter, patterns, benefit, f32Func: "__nv_erff",
541 f64Func: "__nv_erf");
542 populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, f32Func: "__nv_erfcf",
543 f64Func: "__nv_erfc");
544 populateOpPatterns<math::ExpOp>(converter, patterns, benefit, f32Func: "__nv_expf",
545 f64Func: "__nv_exp", f32ApproxFunc: "__nv_fast_expf");
546 populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, f32Func: "__nv_exp2f",
547 f64Func: "__nv_exp2");
548 populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, f32Func: "__nv_expm1f",
549 f64Func: "__nv_expm1");
550 populateOpPatterns<math::FloorOp>(converter, patterns, benefit, f32Func: "__nv_floorf",
551 f64Func: "__nv_floor");
552 populateOpPatterns<math::FmaOp>(converter, patterns, benefit, f32Func: "__nv_fmaf",
553 f64Func: "__nv_fma");
554 // Note: libdevice uses a different name for 32-bit finite checking
555 populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
556 f32Func: "__nv_finitef", f64Func: "__nv_isfinited");
557 populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, f32Func: "__nv_isinff",
558 f64Func: "__nv_isinfd");
559 populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, f32Func: "__nv_isnanf",
560 f64Func: "__nv_isnand");
561 populateOpPatterns<math::LogOp>(converter, patterns, benefit, f32Func: "__nv_logf",
562 f64Func: "__nv_log", f32ApproxFunc: "__nv_fast_logf");
563 populateOpPatterns<math::Log10Op>(converter, patterns, benefit, f32Func: "__nv_log10f",
564 f64Func: "__nv_log10", f32ApproxFunc: "__nv_fast_log10f");
565 populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, f32Func: "__nv_log1pf",
566 f64Func: "__nv_log1p");
567 populateOpPatterns<math::Log2Op>(converter, patterns, benefit, f32Func: "__nv_log2f",
568 f64Func: "__nv_log2", f32ApproxFunc: "__nv_fast_log2f");
569 populateOpPatterns<math::PowFOp>(converter, patterns, benefit, f32Func: "__nv_powf",
570 f64Func: "__nv_pow", f32ApproxFunc: "__nv_fast_powf");
571 populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
572 f32Func: "__nv_powif", f64Func: "__nv_powi");
573 populateOpPatterns<math::RoundOp>(converter, patterns, benefit, f32Func: "__nv_roundf",
574 f64Func: "__nv_round");
575 populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
576 f32Func: "__nv_rintf", f64Func: "__nv_rint");
577 populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, f32Func: "__nv_rsqrtf",
578 f64Func: "__nv_rsqrt");
579 populateOpPatterns<math::SinOp>(converter, patterns, benefit, f32Func: "__nv_sinf",
580 f64Func: "__nv_sin", f32ApproxFunc: "__nv_fast_sinf");
581 populateOpPatterns<math::SinhOp>(converter, patterns, benefit, f32Func: "__nv_sinhf",
582 f64Func: "__nv_sinh");
583 populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, f32Func: "__nv_sqrtf",
584 f64Func: "__nv_sqrt");
585 populateOpPatterns<math::TanOp>(converter, patterns, benefit, f32Func: "__nv_tanf",
586 f64Func: "__nv_tan", f32ApproxFunc: "__nv_fast_tanf");
587 populateOpPatterns<math::TanhOp>(converter, patterns, benefit, f32Func: "__nv_tanhf",
588 f64Func: "__nv_tanh");
589}
590
591void mlir::populateGpuToNVVMConversionPatterns(
592 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
593 PatternBenefit benefit) {
594 using gpu::index_lowering::IndexKind;
595 using gpu::index_lowering::IntrType;
596
597 // TODO: Pass benefit to generated patterns.
598 populateWithGenerated(patterns);
599
600 patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
601 arg: converter, args&: benefit);
602 patterns.add<
603 gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
604 NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
605 arg: converter, args: IndexKind::Block, args: IntrType::Id, args&: benefit);
606 patterns.add<
607 gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
608 NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
609 arg: converter, args: IndexKind::Block, args: IntrType::Dim, args&: benefit);
610 patterns.add<
611 gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
612 NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
613 arg: converter, args: IndexKind::Other, args: IntrType::Id, args&: benefit);
614 patterns.add<gpu::index_lowering::OpLowering<
615 gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
616 NVVM::ClusterDimZOp>>(arg: converter, args: IndexKind::Other, args: IntrType::Dim,
617 args&: benefit);
618 patterns.add<gpu::index_lowering::OpLowering<
619 gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
620 NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
621 arg: converter, args: IndexKind::Other, args: IntrType::Id, args&: benefit);
622 patterns.add<gpu::index_lowering::OpLowering<
623 gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
624 NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
625 arg: converter, args: IndexKind::Other, args: IntrType::Dim, args&: benefit);
626 patterns.add<gpu::index_lowering::OpLowering<
627 gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
628 arg: converter, args: IndexKind::Grid, args: IntrType::Id, args&: benefit);
629 patterns.add<gpu::index_lowering::OpLowering<
630 gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
631 arg: converter, args: IndexKind::Grid, args: IntrType::Dim, args&: benefit);
632 patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
633 arg: converter, args&: benefit);
634
635 patterns.add<GPUDynamicSharedMemoryOpLowering>(
636 arg: converter, args: NVVM::kSharedMemoryAlignmentBit, args&: benefit);
637
638 // Explicitly drop memory space when lowering private memory
639 // attributions since NVVM models it as `alloca`s in the default
640 // memory space and does not support `alloca`s with addrspace(5).
641 patterns.add<GPUFuncOpLowering>(
642 arg: converter,
643 args: GPUFuncOpLoweringOptions{
644 /*allocaAddrSpace=*/0,
645 /*workgroupAddrSpace=*/
646 static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
647 .kernelAttributeName: StringAttr::get(context: &converter.getContext(),
648 bytes: NVVM::NVVMDialect::getKernelFuncAttrName()),
649 .kernelBlockSizeAttributeName: StringAttr::get(context: &converter.getContext(),
650 bytes: NVVM::NVVMDialect::getMaxntidAttrName())},
651 args&: benefit);
652
653 populateLibDeviceConversionPatterns(converter, patterns, benefit);
654}
655
656//===----------------------------------------------------------------------===//
657// NVVMTargetAttr convert to LLVM attr interface
658//===----------------------------------------------------------------------===//
659
660namespace {
/// External model attaching the ConvertToLLVMAttrInterface to
/// NVVM::NVVMTargetAttr, so the generic convert-to-llvm machinery can pick up
/// the GPU-to-NVVM lowering configuration from the target attribute.
struct NVVMTargetConvertToLLVMAttrInterface
    : public ConvertToLLVMAttrInterface::ExternalModel<
          NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> {
  /// Configure GPU to NVVM: legality, type conversions, and patterns.
  void populateConvertToLLVMConversionPatterns(
      Attribute attr, ConversionTarget &target,
      LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const;
};
669} // namespace
670
671void NVVMTargetConvertToLLVMAttrInterface::
672 populateConvertToLLVMConversionPatterns(Attribute attr,
673 ConversionTarget &target,
674 LLVMTypeConverter &typeConverter,
675 RewritePatternSet &patterns) const {
676 configureGpuToNVVMConversionLegality(target);
677 configureGpuToNVVMTypeConverter(converter&: typeConverter);
678 populateGpuToNVVMConversionPatterns(converter: typeConverter, patterns);
679}
680
681void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry &registry) {
682 registry.addExtension(extensionFn: +[](MLIRContext *ctx, NVVMDialect *dialect) {
683 NVVMTargetAttr::attachInterface<NVVMTargetConvertToLLVMAttrInterface>(context&: *ctx);
684 });
685}
686

source code of mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp