1 | //===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | // |
10 | // Copyright (C) by Argonne National Laboratory |
11 | // See COPYRIGHT in top-level directory |
12 | // of MPICH source repository. |
13 | // |
14 | |
#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/ErrorHandling.h"
#include <memory>
25 | |
26 | using namespace mlir; |
27 | |
28 | namespace { |
29 | |
30 | template <typename Op, typename... Args> |
31 | static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, |
32 | ConversionPatternRewriter &rewriter, StringRef name, |
33 | Args &&...args) { |
34 | Op ret; |
35 | if (!(ret = moduleOp.lookupSymbol<Op>(name))) { |
36 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
37 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
38 | ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); |
39 | } |
40 | return ret; |
41 | } |
42 | |
43 | static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, |
44 | const Location loc, |
45 | ConversionPatternRewriter &rewriter, |
46 | StringRef name, |
47 | LLVM::LLVMFunctionType type) { |
48 | return getOrDefineGlobal<LLVM::LLVMFuncOp>( |
49 | moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External); |
50 | } |
51 | |
52 | std::pair<Value, Value> getRawPtrAndSize(const Location loc, |
53 | ConversionPatternRewriter &rewriter, |
54 | Value memRef, Type elType) { |
55 | Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
56 | Value dataPtr = |
57 | rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1); |
58 | Value offset = rewriter.create<LLVM::ExtractValueOp>( |
59 | loc, rewriter.getI64Type(), memRef, 2); |
60 | Value resPtr = |
61 | rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset); |
62 | Value size; |
63 | if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) { |
64 | size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef, |
65 | ArrayRef<int64_t>{3, 0}); |
66 | size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size); |
67 | } else { |
68 | size = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: 32); |
69 | } |
70 | return {resPtr, size}; |
71 | } |
72 | |
73 | /// When lowering the mpi dialect to functions calls certain details |
74 | /// differ between various MPI implementations. This class will provide |
75 | /// these in a generic way, depending on the MPI implementation that got |
76 | /// selected by the DLTI attribute on the module. |
77 | class MPIImplTraits { |
78 | ModuleOp &moduleOp; |
79 | |
80 | public: |
81 | /// Instantiate a new MPIImplTraits object according to the DLTI attribute |
82 | /// on the given module. Default to MPICH if no attribute is present or |
83 | /// the value is unknown. |
84 | static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp); |
85 | |
86 | explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {} |
87 | |
88 | virtual ~MPIImplTraits() = default; |
89 | |
90 | ModuleOp &getModuleOp() { return moduleOp; } |
91 | |
92 | /// Gets or creates MPI_COMM_WORLD as a Value. |
93 | /// Different MPI implementations have different communicator types. |
94 | /// Using i64 as a portable, intermediate type. |
95 | /// Appropriate cast needs to take place before calling MPI functions. |
96 | virtual Value getCommWorld(const Location loc, |
97 | ConversionPatternRewriter &rewriter) = 0; |
98 | |
99 | /// Type converter provides i64 type for communicator type. |
100 | /// Converts to native type, which might be ptr or int or whatever. |
101 | virtual Value castComm(const Location loc, |
102 | ConversionPatternRewriter &rewriter, Value comm) = 0; |
103 | |
104 | /// Get the MPI_STATUS_IGNORE value (typically a pointer type). |
105 | virtual intptr_t getStatusIgnore() = 0; |
106 | |
107 | /// Get the MPI_IN_PLACE value (void *). |
108 | virtual void *getInPlace() = 0; |
109 | |
110 | /// Gets or creates an MPI datatype as a value which corresponds to the given |
111 | /// type. |
112 | virtual Value getDataType(const Location loc, |
113 | ConversionPatternRewriter &rewriter, Type type) = 0; |
114 | |
115 | /// Gets or creates an MPI_Op value which corresponds to the given |
116 | /// enum value. |
117 | virtual Value getMPIOp(const Location loc, |
118 | ConversionPatternRewriter &rewriter, |
119 | mpi::MPI_OpClassEnum opAttr) = 0; |
120 | }; |
121 | |
122 | //===----------------------------------------------------------------------===// |
123 | // Implementation details for MPICH ABI compatible MPI implementations |
124 | //===----------------------------------------------------------------------===// |
125 | |
126 | class MPICHImplTraits : public MPIImplTraits { |
127 | static constexpr int MPI_FLOAT = 0x4c00040a; |
128 | static constexpr int MPI_DOUBLE = 0x4c00080b; |
129 | static constexpr int MPI_INT8_T = 0x4c000137; |
130 | static constexpr int MPI_INT16_T = 0x4c000238; |
131 | static constexpr int MPI_INT32_T = 0x4c000439; |
132 | static constexpr int MPI_INT64_T = 0x4c00083a; |
133 | static constexpr int MPI_UINT8_T = 0x4c00013b; |
134 | static constexpr int MPI_UINT16_T = 0x4c00023c; |
135 | static constexpr int MPI_UINT32_T = 0x4c00043d; |
136 | static constexpr int MPI_UINT64_T = 0x4c00083e; |
137 | static constexpr int MPI_MAX = 0x58000001; |
138 | static constexpr int MPI_MIN = 0x58000002; |
139 | static constexpr int MPI_SUM = 0x58000003; |
140 | static constexpr int MPI_PROD = 0x58000004; |
141 | static constexpr int MPI_LAND = 0x58000005; |
142 | static constexpr int MPI_BAND = 0x58000006; |
143 | static constexpr int MPI_LOR = 0x58000007; |
144 | static constexpr int MPI_BOR = 0x58000008; |
145 | static constexpr int MPI_LXOR = 0x58000009; |
146 | static constexpr int MPI_BXOR = 0x5800000a; |
147 | static constexpr int MPI_MINLOC = 0x5800000b; |
148 | static constexpr int MPI_MAXLOC = 0x5800000c; |
149 | static constexpr int MPI_REPLACE = 0x5800000d; |
150 | static constexpr int MPI_NO_OP = 0x5800000e; |
151 | |
152 | public: |
153 | using MPIImplTraits::MPIImplTraits; |
154 | |
155 | ~MPICHImplTraits() override = default; |
156 | |
157 | Value getCommWorld(const Location loc, |
158 | ConversionPatternRewriter &rewriter) override { |
159 | static constexpr int MPI_COMM_WORLD = 0x44000000; |
160 | return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), |
161 | MPI_COMM_WORLD); |
162 | } |
163 | |
164 | Value castComm(const Location loc, ConversionPatternRewriter &rewriter, |
165 | Value comm) override { |
166 | return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm); |
167 | } |
168 | |
169 | intptr_t getStatusIgnore() override { return 1; } |
170 | |
171 | void *getInPlace() override { return reinterpret_cast<void *>(-1); } |
172 | |
173 | Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, |
174 | Type type) override { |
175 | int32_t mtype = 0; |
176 | if (type.isF32()) |
177 | mtype = MPI_FLOAT; |
178 | else if (type.isF64()) |
179 | mtype = MPI_DOUBLE; |
180 | else if (type.isInteger(width: 64) && !type.isUnsignedInteger()) |
181 | mtype = MPI_INT64_T; |
182 | else if (type.isInteger(width: 64)) |
183 | mtype = MPI_UINT64_T; |
184 | else if (type.isInteger(width: 32) && !type.isUnsignedInteger()) |
185 | mtype = MPI_INT32_T; |
186 | else if (type.isInteger(width: 32)) |
187 | mtype = MPI_UINT32_T; |
188 | else if (type.isInteger(width: 16) && !type.isUnsignedInteger()) |
189 | mtype = MPI_INT16_T; |
190 | else if (type.isInteger(width: 16)) |
191 | mtype = MPI_UINT16_T; |
192 | else if (type.isInteger(width: 8) && !type.isUnsignedInteger()) |
193 | mtype = MPI_INT8_T; |
194 | else if (type.isInteger(width: 8)) |
195 | mtype = MPI_UINT8_T; |
196 | else |
197 | assert(false && "unsupported type" ); |
198 | return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype); |
199 | } |
200 | |
201 | Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, |
202 | mpi::MPI_OpClassEnum opAttr) override { |
203 | int32_t op = MPI_NO_OP; |
204 | switch (opAttr) { |
205 | case mpi::MPI_OpClassEnum::MPI_OP_NULL: |
206 | op = MPI_NO_OP; |
207 | break; |
208 | case mpi::MPI_OpClassEnum::MPI_MAX: |
209 | op = MPI_MAX; |
210 | break; |
211 | case mpi::MPI_OpClassEnum::MPI_MIN: |
212 | op = MPI_MIN; |
213 | break; |
214 | case mpi::MPI_OpClassEnum::MPI_SUM: |
215 | op = MPI_SUM; |
216 | break; |
217 | case mpi::MPI_OpClassEnum::MPI_PROD: |
218 | op = MPI_PROD; |
219 | break; |
220 | case mpi::MPI_OpClassEnum::MPI_LAND: |
221 | op = MPI_LAND; |
222 | break; |
223 | case mpi::MPI_OpClassEnum::MPI_BAND: |
224 | op = MPI_BAND; |
225 | break; |
226 | case mpi::MPI_OpClassEnum::MPI_LOR: |
227 | op = MPI_LOR; |
228 | break; |
229 | case mpi::MPI_OpClassEnum::MPI_BOR: |
230 | op = MPI_BOR; |
231 | break; |
232 | case mpi::MPI_OpClassEnum::MPI_LXOR: |
233 | op = MPI_LXOR; |
234 | break; |
235 | case mpi::MPI_OpClassEnum::MPI_BXOR: |
236 | op = MPI_BXOR; |
237 | break; |
238 | case mpi::MPI_OpClassEnum::MPI_MINLOC: |
239 | op = MPI_MINLOC; |
240 | break; |
241 | case mpi::MPI_OpClassEnum::MPI_MAXLOC: |
242 | op = MPI_MAXLOC; |
243 | break; |
244 | case mpi::MPI_OpClassEnum::MPI_REPLACE: |
245 | op = MPI_REPLACE; |
246 | break; |
247 | } |
248 | return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op); |
249 | } |
250 | }; |
251 | |
252 | //===----------------------------------------------------------------------===// |
253 | // Implementation details for OpenMPI |
254 | //===----------------------------------------------------------------------===// |
255 | class OMPIImplTraits : public MPIImplTraits { |
256 | LLVM::GlobalOp getOrDefineExternalStruct(const Location loc, |
257 | ConversionPatternRewriter &rewriter, |
258 | StringRef name, |
259 | LLVM::LLVMStructType type) { |
260 | |
261 | return getOrDefineGlobal<LLVM::GlobalOp>( |
262 | getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false, |
263 | LLVM::Linkage::External, name, |
264 | /*value=*/Attribute(), /*alignment=*/0, 0); |
265 | } |
266 | |
267 | public: |
268 | using MPIImplTraits::MPIImplTraits; |
269 | |
270 | ~OMPIImplTraits() override = default; |
271 | |
272 | Value getCommWorld(const Location loc, |
273 | ConversionPatternRewriter &rewriter) override { |
274 | auto context = rewriter.getContext(); |
275 | // get external opaque struct pointer type |
276 | auto commStructT = |
277 | LLVM::LLVMStructType::getOpaque("ompi_communicator_t" , context); |
278 | StringRef name = "ompi_mpi_comm_world" ; |
279 | |
280 | // make sure global op definition exists |
281 | getOrDefineExternalStruct(loc, rewriter, name, commStructT); |
282 | |
283 | // get address of symbol |
284 | auto comm = rewriter.create<LLVM::AddressOfOp>( |
285 | loc, LLVM::LLVMPointerType::get(context), |
286 | SymbolRefAttr::get(context, name)); |
287 | return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm); |
288 | } |
289 | |
290 | Value castComm(const Location loc, ConversionPatternRewriter &rewriter, |
291 | Value comm) override { |
292 | return rewriter.create<LLVM::IntToPtrOp>( |
293 | loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); |
294 | } |
295 | |
296 | intptr_t getStatusIgnore() override { return 0; } |
297 | |
298 | void *getInPlace() override { return reinterpret_cast<void *>(1); } |
299 | |
300 | Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, |
301 | Type type) override { |
302 | StringRef mtype; |
303 | if (type.isF32()) |
304 | mtype = "ompi_mpi_float" ; |
305 | else if (type.isF64()) |
306 | mtype = "ompi_mpi_double" ; |
307 | else if (type.isInteger(width: 64) && !type.isUnsignedInteger()) |
308 | mtype = "ompi_mpi_int64_t" ; |
309 | else if (type.isInteger(width: 64)) |
310 | mtype = "ompi_mpi_uint64_t" ; |
311 | else if (type.isInteger(width: 32) && !type.isUnsignedInteger()) |
312 | mtype = "ompi_mpi_int32_t" ; |
313 | else if (type.isInteger(width: 32)) |
314 | mtype = "ompi_mpi_uint32_t" ; |
315 | else if (type.isInteger(width: 16) && !type.isUnsignedInteger()) |
316 | mtype = "ompi_mpi_int16_t" ; |
317 | else if (type.isInteger(width: 16)) |
318 | mtype = "ompi_mpi_uint16_t" ; |
319 | else if (type.isInteger(width: 8) && !type.isUnsignedInteger()) |
320 | mtype = "ompi_mpi_int8_t" ; |
321 | else if (type.isInteger(width: 8)) |
322 | mtype = "ompi_mpi_uint8_t" ; |
323 | else |
324 | assert(false && "unsupported type" ); |
325 | |
326 | auto context = rewriter.getContext(); |
327 | // get external opaque struct pointer type |
328 | auto typeStructT = |
329 | LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t" , context); |
330 | // make sure global op definition exists |
331 | getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT); |
332 | // get address of symbol |
333 | return rewriter.create<LLVM::AddressOfOp>( |
334 | loc, LLVM::LLVMPointerType::get(context), |
335 | SymbolRefAttr::get(context, mtype)); |
336 | } |
337 | |
338 | Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, |
339 | mpi::MPI_OpClassEnum opAttr) override { |
340 | StringRef op; |
341 | switch (opAttr) { |
342 | case mpi::MPI_OpClassEnum::MPI_OP_NULL: |
343 | op = "ompi_mpi_no_op" ; |
344 | break; |
345 | case mpi::MPI_OpClassEnum::MPI_MAX: |
346 | op = "ompi_mpi_max" ; |
347 | break; |
348 | case mpi::MPI_OpClassEnum::MPI_MIN: |
349 | op = "ompi_mpi_min" ; |
350 | break; |
351 | case mpi::MPI_OpClassEnum::MPI_SUM: |
352 | op = "ompi_mpi_sum" ; |
353 | break; |
354 | case mpi::MPI_OpClassEnum::MPI_PROD: |
355 | op = "ompi_mpi_prod" ; |
356 | break; |
357 | case mpi::MPI_OpClassEnum::MPI_LAND: |
358 | op = "ompi_mpi_land" ; |
359 | break; |
360 | case mpi::MPI_OpClassEnum::MPI_BAND: |
361 | op = "ompi_mpi_band" ; |
362 | break; |
363 | case mpi::MPI_OpClassEnum::MPI_LOR: |
364 | op = "ompi_mpi_lor" ; |
365 | break; |
366 | case mpi::MPI_OpClassEnum::MPI_BOR: |
367 | op = "ompi_mpi_bor" ; |
368 | break; |
369 | case mpi::MPI_OpClassEnum::MPI_LXOR: |
370 | op = "ompi_mpi_lxor" ; |
371 | break; |
372 | case mpi::MPI_OpClassEnum::MPI_BXOR: |
373 | op = "ompi_mpi_bxor" ; |
374 | break; |
375 | case mpi::MPI_OpClassEnum::MPI_MINLOC: |
376 | op = "ompi_mpi_minloc" ; |
377 | break; |
378 | case mpi::MPI_OpClassEnum::MPI_MAXLOC: |
379 | op = "ompi_mpi_maxloc" ; |
380 | break; |
381 | case mpi::MPI_OpClassEnum::MPI_REPLACE: |
382 | op = "ompi_mpi_replace" ; |
383 | break; |
384 | } |
385 | auto context = rewriter.getContext(); |
386 | // get external opaque struct pointer type |
387 | auto opStructT = |
388 | LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t" , context); |
389 | // make sure global op definition exists |
390 | getOrDefineExternalStruct(loc, rewriter, op, opStructT); |
391 | // get address of symbol |
392 | return rewriter.create<LLVM::AddressOfOp>( |
393 | loc, LLVM::LLVMPointerType::get(context), |
394 | SymbolRefAttr::get(context, op)); |
395 | } |
396 | }; |
397 | |
398 | std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) { |
399 | auto attr = dlti::query(*&moduleOp, {"MPI:Implementation" }, true); |
400 | if (failed(attr)) |
401 | return std::make_unique<MPICHImplTraits>(args&: moduleOp); |
402 | auto strAttr = dyn_cast<StringAttr>(attr.value()); |
403 | if (strAttr && strAttr.getValue() == "OpenMPI" ) |
404 | return std::make_unique<OMPIImplTraits>(args&: moduleOp); |
405 | if (!strAttr || strAttr.getValue() != "MPICH" ) |
406 | moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI (" |
407 | << strAttr.getValue() << "), defaulting to MPICH" ; |
408 | return std::make_unique<MPICHImplTraits>(args&: moduleOp); |
409 | } |
410 | |
411 | //===----------------------------------------------------------------------===// |
412 | // InitOpLowering |
413 | //===----------------------------------------------------------------------===// |
414 | |
415 | struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> { |
416 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
417 | |
418 | LogicalResult |
419 | matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor, |
420 | ConversionPatternRewriter &rewriter) const override { |
421 | Location loc = op.getLoc(); |
422 | |
423 | // ptrType `!llvm.ptr` |
424 | Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
425 | |
426 | // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` |
427 | auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType); |
428 | Value llvmnull = nullPtrOp.getRes(); |
429 | |
430 | // grab a reference to the global module op: |
431 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
432 | |
433 | // LLVM Function type representing `i32 MPI_Init(ptr, ptr)` |
434 | auto initFuncType = |
435 | LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); |
436 | // get or create function declaration: |
437 | LLVM::LLVMFuncOp initDecl = |
438 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init" , initFuncType); |
439 | |
440 | // replace init with function call |
441 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, |
442 | ValueRange{llvmnull, llvmnull}); |
443 | |
444 | return success(); |
445 | } |
446 | }; |
447 | |
448 | //===----------------------------------------------------------------------===// |
449 | // FinalizeOpLowering |
450 | //===----------------------------------------------------------------------===// |
451 | |
452 | struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> { |
453 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
454 | |
455 | LogicalResult |
456 | matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor, |
457 | ConversionPatternRewriter &rewriter) const override { |
458 | // get loc |
459 | Location loc = op.getLoc(); |
460 | |
461 | // grab a reference to the global module op: |
462 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
463 | |
464 | // LLVM Function type representing `i32 MPI_Finalize()` |
465 | auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}); |
466 | // get or create function declaration: |
467 | LLVM::LLVMFuncOp initDecl = getOrDefineFunction( |
468 | moduleOp, loc, rewriter, "MPI_Finalize" , initFuncType); |
469 | |
470 | // replace init with function call |
471 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{}); |
472 | |
473 | return success(); |
474 | } |
475 | }; |
476 | |
477 | //===----------------------------------------------------------------------===// |
478 | // CommWorldOpLowering |
479 | //===----------------------------------------------------------------------===// |
480 | |
481 | struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> { |
482 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
483 | |
484 | LogicalResult |
485 | matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor, |
486 | ConversionPatternRewriter &rewriter) const override { |
487 | // grab a reference to the global module op: |
488 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
489 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
490 | // get MPI_COMM_WORLD |
491 | rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter)); |
492 | |
493 | return success(); |
494 | } |
495 | }; |
496 | |
497 | //===----------------------------------------------------------------------===// |
498 | // CommSplitOpLowering |
499 | //===----------------------------------------------------------------------===// |
500 | |
501 | struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { |
502 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
503 | |
504 | LogicalResult |
505 | matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor, |
506 | ConversionPatternRewriter &rewriter) const override { |
507 | // grab a reference to the global module op: |
508 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
509 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
510 | Type i32 = rewriter.getI32Type(); |
511 | Type ptrType = LLVM::LLVMPointerType::get(op->getContext()); |
512 | Location loc = op.getLoc(); |
513 | |
514 | // get communicator |
515 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
516 | auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); |
517 | auto outPtr = |
518 | rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one); |
519 | |
520 | // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) |
521 | auto funcType = |
522 | LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType}); |
523 | // get or create function declaration: |
524 | LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, |
525 | "MPI_Comm_split" , funcType); |
526 | |
527 | auto callOp = rewriter.create<LLVM::CallOp>( |
528 | loc, funcDecl, |
529 | ValueRange{comm, adaptor.getColor(), adaptor.getKey(), |
530 | outPtr.getRes()}); |
531 | |
532 | // load the communicator into a register |
533 | Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult()); |
534 | res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res); |
535 | |
536 | // if retval is checked, replace uses of retval with the results from the |
537 | // call op |
538 | SmallVector<Value> replacements; |
539 | if (op.getRetval()) |
540 | replacements.push_back(Elt: callOp.getResult()); |
541 | |
542 | // replace op |
543 | replacements.push_back(Elt: res); |
544 | rewriter.replaceOp(op, replacements); |
545 | |
546 | return success(); |
547 | } |
548 | }; |
549 | |
550 | //===----------------------------------------------------------------------===// |
551 | // CommRankOpLowering |
552 | //===----------------------------------------------------------------------===// |
553 | |
554 | struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> { |
555 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
556 | |
557 | LogicalResult |
558 | matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor, |
559 | ConversionPatternRewriter &rewriter) const override { |
560 | // get some helper vars |
561 | Location loc = op.getLoc(); |
562 | MLIRContext *context = rewriter.getContext(); |
563 | Type i32 = rewriter.getI32Type(); |
564 | |
565 | // ptrType `!llvm.ptr` |
566 | Type ptrType = LLVM::LLVMPointerType::get(context); |
567 | |
568 | // grab a reference to the global module op: |
569 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
570 | |
571 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
572 | // get communicator |
573 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
574 | |
575 | // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)` |
576 | auto rankFuncType = |
577 | LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType}); |
578 | // get or create function declaration: |
579 | LLVM::LLVMFuncOp initDecl = getOrDefineFunction( |
580 | moduleOp, loc, rewriter, "MPI_Comm_rank" , rankFuncType); |
581 | |
582 | // replace with function call |
583 | auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); |
584 | auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one); |
585 | auto callOp = rewriter.create<LLVM::CallOp>( |
586 | loc, initDecl, ValueRange{comm, rankptr.getRes()}); |
587 | |
588 | // load the rank into a register |
589 | auto loadedRank = |
590 | rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult()); |
591 | |
592 | // if retval is checked, replace uses of retval with the results from the |
593 | // call op |
594 | SmallVector<Value> replacements; |
595 | if (op.getRetval()) |
596 | replacements.push_back(Elt: callOp.getResult()); |
597 | |
598 | // replace all uses, then erase op |
599 | replacements.push_back(Elt: loadedRank.getRes()); |
600 | rewriter.replaceOp(op, replacements); |
601 | |
602 | return success(); |
603 | } |
604 | }; |
605 | |
606 | //===----------------------------------------------------------------------===// |
607 | // SendOpLowering |
608 | //===----------------------------------------------------------------------===// |
609 | |
610 | struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> { |
611 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
612 | |
613 | LogicalResult |
614 | matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor, |
615 | ConversionPatternRewriter &rewriter) const override { |
616 | // get some helper vars |
617 | Location loc = op.getLoc(); |
618 | MLIRContext *context = rewriter.getContext(); |
619 | Type i32 = rewriter.getI32Type(); |
620 | Type elemType = op.getRef().getType().getElementType(); |
621 | |
622 | // ptrType `!llvm.ptr` |
623 | Type ptrType = LLVM::LLVMPointerType::get(context); |
624 | |
625 | // grab a reference to the global module op: |
626 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
627 | |
628 | // get MPI_COMM_WORLD, dataType and pointer |
629 | auto [dataPtr, size] = |
630 | getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType); |
631 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
632 | Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); |
633 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
634 | |
635 | // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst, |
636 | // tag, comm)` |
637 | auto funcType = LLVM::LLVMFunctionType::get( |
638 | i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()}); |
639 | // get or create function declaration: |
640 | LLVM::LLVMFuncOp funcDecl = |
641 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send" , funcType); |
642 | |
643 | // replace op with function call |
644 | auto funcCall = rewriter.create<LLVM::CallOp>( |
645 | loc, funcDecl, |
646 | ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), |
647 | comm}); |
648 | if (op.getRetval()) |
649 | rewriter.replaceOp(op, funcCall.getResult()); |
650 | else |
651 | rewriter.eraseOp(op: op); |
652 | |
653 | return success(); |
654 | } |
655 | }; |
656 | |
657 | //===----------------------------------------------------------------------===// |
658 | // RecvOpLowering |
659 | //===----------------------------------------------------------------------===// |
660 | |
661 | struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { |
662 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
663 | |
664 | LogicalResult |
665 | matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor, |
666 | ConversionPatternRewriter &rewriter) const override { |
667 | // get some helper vars |
668 | Location loc = op.getLoc(); |
669 | MLIRContext *context = rewriter.getContext(); |
670 | Type i32 = rewriter.getI32Type(); |
671 | Type i64 = rewriter.getI64Type(); |
672 | Type elemType = op.getRef().getType().getElementType(); |
673 | |
674 | // ptrType `!llvm.ptr` |
675 | Type ptrType = LLVM::LLVMPointerType::get(context); |
676 | |
677 | // grab a reference to the global module op: |
678 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
679 | |
680 | // get MPI_COMM_WORLD, dataType, status_ignore and pointer |
681 | auto [dataPtr, size] = |
682 | getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType); |
683 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
684 | Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); |
685 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
686 | Value statusIgnore = rewriter.create<LLVM::ConstantOp>( |
687 | loc, i64, mpiTraits->getStatusIgnore()); |
688 | statusIgnore = |
689 | rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore); |
690 | |
691 | // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst, |
692 | // tag, comm)` |
693 | auto funcType = |
694 | LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32, |
695 | i32, comm.getType(), ptrType}); |
696 | // get or create function declaration: |
697 | LLVM::LLVMFuncOp funcDecl = |
698 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv" , funcType); |
699 | |
700 | // replace op with function call |
701 | auto funcCall = rewriter.create<LLVM::CallOp>( |
702 | loc, funcDecl, |
703 | ValueRange{dataPtr, size, dataType, adaptor.getSource(), |
704 | adaptor.getTag(), comm, statusIgnore}); |
705 | if (op.getRetval()) |
706 | rewriter.replaceOp(op, funcCall.getResult()); |
707 | else |
708 | rewriter.eraseOp(op: op); |
709 | |
710 | return success(); |
711 | } |
712 | }; |
713 | |
714 | //===----------------------------------------------------------------------===// |
715 | // AllReduceOpLowering |
716 | //===----------------------------------------------------------------------===// |
717 | |
718 | struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { |
719 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
720 | |
721 | LogicalResult |
722 | matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor, |
723 | ConversionPatternRewriter &rewriter) const override { |
724 | Location loc = op.getLoc(); |
725 | MLIRContext *context = rewriter.getContext(); |
726 | Type i32 = rewriter.getI32Type(); |
727 | Type i64 = rewriter.getI64Type(); |
728 | Type elemType = op.getSendbuf().getType().getElementType(); |
729 | |
730 | // ptrType `!llvm.ptr` |
731 | Type ptrType = LLVM::LLVMPointerType::get(context); |
732 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
733 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
734 | auto [sendPtr, sendSize] = |
735 | getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType); |
736 | auto [recvPtr, recvSize] = |
737 | getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType); |
738 | |
739 | // If input and output are the same, request in-place operation. |
740 | if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { |
741 | sendPtr = rewriter.create<LLVM::ConstantOp>( |
742 | loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace())); |
743 | sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr); |
744 | } |
745 | |
746 | Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); |
747 | Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp()); |
748 | Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
749 | |
750 | // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, |
751 | // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)' |
752 | auto funcType = LLVM::LLVMFunctionType::get( |
753 | i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(), |
754 | commWorld.getType()}); |
755 | // get or create function declaration: |
756 | LLVM::LLVMFuncOp funcDecl = |
757 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce" , funcType); |
758 | |
759 | // replace op with function call |
760 | auto funcCall = rewriter.create<LLVM::CallOp>( |
761 | loc, funcDecl, |
762 | ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld}); |
763 | |
764 | if (op.getRetval()) |
765 | rewriter.replaceOp(op, funcCall.getResult()); |
766 | else |
767 | rewriter.eraseOp(op: op); |
768 | |
769 | return success(); |
770 | } |
771 | }; |
772 | |
773 | //===----------------------------------------------------------------------===// |
774 | // ConvertToLLVMPatternInterface implementation |
775 | //===----------------------------------------------------------------------===// |
776 | |
777 | /// Implement the interface to convert Func to LLVM. |
778 | struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
779 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
780 | /// Hook for derived dialect interface to provide conversion patterns |
781 | /// and mark dialect legal for the conversion target. |
782 | void populateConvertToLLVMConversionPatterns( |
783 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
784 | RewritePatternSet &patterns) const final { |
785 | mpi::populateMPIToLLVMConversionPatterns(converter&: typeConverter, patterns); |
786 | } |
787 | }; |
788 | } // namespace |
789 | |
790 | //===----------------------------------------------------------------------===// |
791 | // Pattern Population |
792 | //===----------------------------------------------------------------------===// |
793 | |
794 | void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, |
795 | RewritePatternSet &patterns) { |
796 | // Using i64 as a portable, intermediate type for !mpi.comm. |
797 | // It would be nicer to somehow get the right type directly, but TLDI is not |
798 | // available here. |
799 | converter.addConversion(callback: [](mpi::CommType type) { |
800 | return IntegerType::get(type.getContext(), 64); |
801 | }); |
802 | patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering, |
803 | FinalizeOpLowering, InitOpLowering, SendOpLowering, |
804 | RecvOpLowering, AllReduceOpLowering>(arg&: converter); |
805 | } |
806 | |
807 | void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) { |
808 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, mpi::MPIDialect *dialect) { |
809 | dialect->addInterfaces<FuncToLLVMDialectInterface>(); |
810 | }); |
811 | } |
812 | |