//===- LowerGpuOpsToNVVMOps.cpp - MLIR GPU to NVVM lowering passes --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate NVVMIR operations for higher-level
// GPU operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"

#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTONVVMOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
  switch (mode) {
  case gpu::ShuffleMode::XOR:
    return NVVM::ShflKind::bfly;
  case gpu::ShuffleMode::UP:
    return NVVM::ShflKind::up;
  case gpu::ShuffleMode::DOWN:
    return NVVM::ShflKind::down;
  case gpu::ShuffleMode::IDX:
    return NVVM::ShflKind::idx;
  }
  llvm_unreachable("unknown shuffle mode");
}

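/// Maps a gpu reduction operation kind to the equivalent NVVM redux kind.
/// Returns std::nullopt for kinds the NVVM redux op cannot express
/// (multiplication, unsigned min/max, and the MINIMUMF/MAXIMUMF variants).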
static std::optional<NVVM::ReduxKind>
convertReduxKind(gpu::AllReduceOperation mode) {
  switch (mode) {
  case gpu::AllReduceOperation::ADD:
    return NVVM::ReduxKind::ADD;
  case gpu::AllReduceOperation::MUL:
    return std::nullopt;
  case gpu::AllReduceOperation::MINSI:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MINUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MINNUMF:
    return NVVM::ReduxKind::MIN;
  case gpu::AllReduceOperation::MAXSI:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::MAXUI:
    return std::nullopt;
  case gpu::AllReduceOperation::MAXNUMF:
    return NVVM::ReduxKind::MAX;
  case gpu::AllReduceOperation::AND:
    return NVVM::ReduxKind::AND;
  case gpu::AllReduceOperation::OR:
    return NVVM::ReduxKind::OR;
  case gpu::AllReduceOperation::XOR:
    return NVVM::ReduxKind::XOR;
  case gpu::AllReduceOperation::MINIMUMF:
  case gpu::AllReduceOperation::MAXIMUMF:
    return std::nullopt;
  }
  return std::nullopt;
}

/// This pattern lowers the gpu.subgroup_reduce op to the nvvm.redux op. The op
/// must be run by the entire subgroup, otherwise it is undefined behaviour.
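///
/// As a rough sketch (illustrative only; the exact printed assembly may
/// differ), a uniform 32-bit integer reduction is rewritten into a full-warp
/// membermask constant plus a redux op:
///
///     %mask = llvm.mlir.constant(-1 : i32) : i32
///     %res = nvvm.redux.sync add %value, %mask : i32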
struct GPUSubgroupReduceOpLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getUniform())
      return rewriter.notifyMatchFailure(
          op, "cannot be lowered to redux as the op must be run "
              "uniformly (entire subgroup).");
    if (!op.getValue().getType().isInteger(32))
      return rewriter.notifyMatchFailure(op, "unsupported data type");

    std::optional<NVVM::ReduxKind> mode = convertReduxKind(op.getOp());
    if (!mode.has_value())
      return rewriter.notifyMatchFailure(
          op, "unsupported reduction mode for redux");

    Location loc = op->getLoc();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);

    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type, op.getValue(),
                                                  mode.value(), offset);

    rewriter.replaceOp(op, reduxOp->getResult(0));
    return success();
  }
};

struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;

  /// Lowers a shuffle to the corresponding NVVM op.
  ///
  /// Convert the `width` argument into an activeMask (a bitmask which
  /// specifies which threads participate in the shuffle) and a maskAndClamp
  /// (specifying the highest lane which participates in the shuffle).
  ///
  ///     %one = llvm.constant(1 : i32) : i32
  ///     %minus_one = llvm.constant(-1 : i32) : i32
  ///     %thirty_two = llvm.constant(32 : i32) : i32
  ///     %num_lanes = llvm.sub %thirty_two, %width : i32
  ///     %active_mask = llvm.lshr %minus_one, %num_lanes : i32
  ///     %mask_and_clamp = llvm.sub %width, %one : i32
  ///     %shfl = nvvm.shfl.sync.bfly %active_mask, %value, %offset,
  ///         %mask_and_clamp : !llvm<"{ float, i1 }">
  ///     %shfl_value = llvm.extractvalue %shfl[0] : !llvm<"{ float, i1 }">
  ///     %shfl_pred = llvm.extractvalue %shfl[1] : !llvm<"{ float, i1 }">
  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    auto valueTy = adaptor.getValue().getType();
    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
    auto predTy = IntegerType::get(rewriter.getContext(), 1);

    Value one = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 1);
    Value minusOne = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
    Value thirtyTwo = rewriter.create<LLVM::ConstantOp>(loc, int32Type, 32);
    Value numLeadInactiveLane = rewriter.create<LLVM::SubOp>(
        loc, int32Type, thirtyTwo, adaptor.getWidth());
    // Bit mask of active lanes: `(-1) >> (32 - activeWidth)`.
    Value activeMask = rewriter.create<LLVM::LShrOp>(loc, int32Type, minusOne,
                                                     numLeadInactiveLane);
    Value maskAndClamp;
    if (op.getMode() == gpu::ShuffleMode::UP) {
      // Clamp lane: `32 - activeWidth`
      maskAndClamp = numLeadInactiveLane;
    } else {
      // Clamp lane: `activeWidth - 1`
      maskAndClamp =
          rewriter.create<LLVM::SubOp>(loc, int32Type, adaptor.getWidth(), one);
    }

    bool predIsUsed = !op->getResult(1).use_empty();
    UnitAttr returnValueAndIsValidAttr = nullptr;
    Type resultTy = valueTy;
    if (predIsUsed) {
      returnValueAndIsValidAttr = rewriter.getUnitAttr();
      resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                  {valueTy, predTy});
    }
    Value shfl = rewriter.create<NVVM::ShflOp>(
        loc, resultTy, activeMask, adaptor.getValue(), adaptor.getOffset(),
        maskAndClamp, convertShflKind(op.getMode()), returnValueAndIsValidAttr);
    if (predIsUsed) {
      Value shflValue = rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 0);
      Value isActiveSrcLane =
          rewriter.create<LLVM::ExtractValueOp>(loc, shfl, 1);
      rewriter.replaceOp(op, {shflValue, isActiveSrcLane});
    } else {
      rewriter.replaceOp(op, {shfl, nullptr});
    }
    return success();
  }
};

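/// Lowers gpu.lane_id to NVVM's lane id intrinsic op (NVVM::LaneIdOp), then
/// sign-extends or truncates the 32-bit result to the index bitwidth chosen
/// by the LLVMTypeConverter options.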
struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    MLIRContext *context = rewriter.getContext();
    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
    // Truncate or extend the result depending on the index bitwidth specified
    // by the LLVMTypeConverter options.
    const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
    if (indexBitwidth > 32) {
      newOp = rewriter.create<LLVM::SExtOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    } else if (indexBitwidth < 32) {
      newOp = rewriter.create<LLVM::TruncOp>(
          loc, IntegerType::get(context, indexBitwidth), newOp);
    }
    rewriter.replaceOp(op, {newOp});
    return success();
  }
};

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

/// A pass that replaces all occurrences of GPU device operations with their
/// corresponding NVVM equivalent.
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct LowerGpuOpsToNVVMOpsPass
    : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
  using Base::Base;

  void runOnOperation() override {
    gpu::GPUModuleOp m = getOperation();

    // Request C wrapper emission.
    for (auto func : m.getOps<func::FuncOp>()) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(&getContext()));
    }

    // Customize the bitwidth used for the device side index computations.
    LowerToLLVMOptions options(
        m.getContext(),
        DataLayout(cast<DataLayoutOpInterface>(m.getOperation())));
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);
    options.useBarePtrCallConv = useBarePtrCallConv;

    // Apply in-dialect lowering. In-dialect lowering will replace
    // ops which need to be lowered further, which is not supported by a
    // single conversion pass.
    {
      RewritePatternSet patterns(m.getContext());
      populateGpuRewritePatterns(patterns);
      if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
        return signalPassFailure();
    }

    LLVMTypeConverter converter(m.getContext(), options);
    // NVVM uses alloca in the default address space to represent private
    // memory allocations, so drop private annotations. NVVM uses address
    // space 3 for shared memory. NVVM uses the default address space to
    // represent global memory.
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
    // Lowering for MMAMatrixType.
    converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
      return convertMMAToLLVMType(type);
    });
    RewritePatternSet llvmPatterns(m.getContext());

    arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
    cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
    populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
    populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
    populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
    populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
    populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
    if (this->hasRedux)
      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
    LLVMConversionTarget target(getContext());
    configureGpuToNVVMConversionLegality(target);
    if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
      signalPassFailure();
  }
};

} // namespace

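/// Configures `target` for GPU-to-NVVM conversion: the LLVM and NVVM dialects
/// become legal, func.func and the GPU dialect become illegal, and a set of
/// LLVM math ops is marked illegal so that they are lowered to libdevice
/// calls instead. A few structural GPU ops remain legal (see the TODO below).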
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
  target.addIllegalOp<func::FuncOp>();
  target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
  target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
  target.addIllegalDialect<gpu::GPUDialect>();
  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
                      LLVM::SqrtOp>();

  // TODO: Remove once we support replacing non-root ops.
  target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
}

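/// Appends patterns that scalarize vector operands of `OpTy` and lower the
/// scalar op to a call of the given libdevice function, picking `f32Func` or
/// `f64Func` based on the operand's element type.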
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns, StringRef f32Func,
                               StringRef f64Func) {
  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
}

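/// Appends the pattern lowering gpu.subgroup_reduce to the NVVM redux op (see
/// GPUSubgroupReduceOpLowering above).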
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<GPUSubgroupReduceOpLowering>(converter);
}

void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
  patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
  patterns.add<
      GPUIndexIntrinsicOpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
                                  NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
                                  NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
                                  NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
                                  NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
      GPUIndexIntrinsicOpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
                                  NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
      GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                  NVVM::GridDimYOp, NVVM::GridDimZOp>,
      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);

  patterns.add<GPUDynamicSharedMemoryOpLowering>(
      converter, NVVM::kSharedMemoryAlignmentBit);

  // Explicitly drop memory space when lowering private memory
  // attributions since NVVM models it as `alloca`s in the default
  // memory space and does not support `alloca`s with addrspace(5).
  patterns.add<GPUFuncOpLowering>(
      converter, /*allocaAddrSpace=*/0,
      /*workgroupAddrSpace=*/
      static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
      StringAttr::get(&converter.getContext(),
                      NVVM::NVVMDialect::getKernelFuncAttrName()),
      StringAttr::get(&converter.getContext(),
                      NVVM::NVVMDialect::getMaxntidAttrName()));

  populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                   "__nv_fabs");
  populateOpPatterns<math::AtanOp>(converter, patterns, "__nv_atanf",
                                   "__nv_atan");
  populateOpPatterns<math::Atan2Op>(converter, patterns, "__nv_atan2f",
                                    "__nv_atan2");
  populateOpPatterns<math::CbrtOp>(converter, patterns, "__nv_cbrtf",
                                   "__nv_cbrt");
  populateOpPatterns<math::CeilOp>(converter, patterns, "__nv_ceilf",
                                   "__nv_ceil");
  populateOpPatterns<math::CosOp>(converter, patterns, "__nv_cosf", "__nv_cos");
  populateOpPatterns<math::ErfOp>(converter, patterns, "__nv_erff", "__nv_erf");
  populateOpPatterns<math::ExpOp>(converter, patterns, "__nv_expf", "__nv_exp");
  populateOpPatterns<math::Exp2Op>(converter, patterns, "__nv_exp2f",
                                   "__nv_exp2");
  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__nv_expm1f",
                                    "__nv_expm1");
  populateOpPatterns<math::FloorOp>(converter, patterns, "__nv_floorf",
                                    "__nv_floor");
  populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                    "__nv_fmod");
  populateOpPatterns<math::LogOp>(converter, patterns, "__nv_logf", "__nv_log");
  populateOpPatterns<math::Log1pOp>(converter, patterns, "__nv_log1pf",
                                    "__nv_log1p");
  populateOpPatterns<math::Log10Op>(converter, patterns, "__nv_log10f",
                                    "__nv_log10");
  populateOpPatterns<math::Log2Op>(converter, patterns, "__nv_log2f",
                                   "__nv_log2");
  populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf",
                                   "__nv_pow");
  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__nv_rsqrtf",
                                    "__nv_rsqrt");
  populateOpPatterns<math::SinOp>(converter, patterns, "__nv_sinf", "__nv_sin");
  populateOpPatterns<math::SqrtOp>(converter, patterns, "__nv_sqrtf",
                                   "__nv_sqrt");
  populateOpPatterns<math::TanhOp>(converter, patterns, "__nv_tanhf",
                                   "__nv_tanh");
  populateOpPatterns<math::TanOp>(converter, patterns, "__nv_tanf", "__nv_tan");
}