//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

/// Returns the NVVM redux kind that implements the given reduction operation,
/// or std::nullopt if the operation is not supported by redux.
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op
/// must be run by the entire subgroup, otherwise it is undefined behaviour.
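///
/// A minimal sketch of the rewrite for a uniform i32 add reduction (the exact
/// printed form of the ops may differ):
///
///   %r = gpu.subgroup_reduce add %value uniform : (i32) -> i32
/// becomes roughly
///   %mask = llvm.mlir.constant(-1 : i32) : i32
///   %r = nvvm.redux.sync add %value, %mask : i32 -> i32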
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
                                                  mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Converts the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///     %one = llvm.constant(1 : i32) : i32
  ///     %minus_one = llvm.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

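/// Lowers gpu.lane_id to the NVVM laneid intrinsic and adjusts the result to
/// the index bitwidth configured on the type converter. A minimal sketch of
/// the output for a 64-bit index (exact printed form may differ):
///
///   %id32 = nvvm.read.ptx.sreg.laneid : i32
///   %id = llvm.sext %id32 : i32 to i64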
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
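///
/// The pass is usually driven from a larger pipeline; assuming the registered
/// pass name and option spelling, it can also be invoked directly, e.g.:
///
///   mlir-opt --convert-gpu-to-nvvm='index-bitwidth=32' input.mlir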
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    // NVVM uses alloca in the default address space to represent private
    // memory allocations, so drop private annotations. NVVM uses address
    // space 3 for shared memory. NVVM uses the default address space to
    // represent global memory.
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
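    // For example, a memref in the workgroup space such as
    //   memref<4xf32, #gpu.address_space<workgroup>>
    // is converted so that its underlying LLVM pointer carries addrspace(3),
    // while private attributions end up in the default address space.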
    // Lowering for MMAMatrixType.
    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
      return convertMMAToLLVMType(type);
    });
    RewritePatternSet llvmPatterns(m.getContext());

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    LLVMConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}

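/// Registers, for a single op type, a pattern that scalarizes vector operands
/// and a pattern that rewrites the scalar op into a call to the matching
/// libdevice function (f32 and f64 variants). A rough sketch of the effect,
/// with the exact form depending on the converter options:
///
///   %y = math.exp %x : f32
/// becomes
///   %y = llvm.call @__nv_expf(%x) : (f32) -> f32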
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}

void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
  patterns.add<
      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter, /*allocaAddrSpace=*/0,
      /*workgroupAddrSpace=*/
      static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
      StringAttr::get(&converter.getContext(),
                      NVVM::NVVMDialect::getKernelFuncAttrName()),
      StringAttr::get(&converter.getContext(),
                      NVVM::NVVMDialect::getMaxntidAttrName()));

  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
                                   "__nv_pow");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}