1//===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9//
10// Copyright (C) by Argonne National Laboratory
11// See COPYRIGHT in top-level directory
12// of MPICH source repository.
13//
14
15#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
16#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17#include "mlir/Conversion/LLVMCommon/Pattern.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/DLTI/DLTI.h"
20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
22#include "mlir/Dialect/MPI/IR/MPI.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include <memory>
25
26using namespace mlir;
27
28namespace {
29
30template <typename Op, typename... Args>
31static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc,
32 ConversionPatternRewriter &rewriter, StringRef name,
33 Args &&...args) {
34 Op ret;
35 if (!(ret = moduleOp.lookupSymbol<Op>(name))) {
36 ConversionPatternRewriter::InsertionGuard guard(rewriter);
37 rewriter.setInsertionPointToStart(moduleOp.getBody());
38 ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...);
39 }
40 return ret;
41}
42
43static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp,
44 const Location loc,
45 ConversionPatternRewriter &rewriter,
46 StringRef name,
47 LLVM::LLVMFunctionType type) {
48 return getOrDefineGlobal<LLVM::LLVMFuncOp>(
49 moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External);
50}
51
52std::pair<Value, Value> getRawPtrAndSize(const Location loc,
53 ConversionPatternRewriter &rewriter,
54 Value memRef, Type elType) {
55 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
56 Value dataPtr =
57 rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1);
58 Value offset = rewriter.create<LLVM::ExtractValueOp>(
59 loc, rewriter.getI64Type(), memRef, 2);
60 Value resPtr =
61 rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
62 Value size;
63 if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) {
64 size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef,
65 ArrayRef<int64_t>{3, 0});
66 size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size);
67 } else {
68 size = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: 32);
69 }
70 return {resPtr, size};
71}
72
73/// When lowering the mpi dialect to functions calls certain details
74/// differ between various MPI implementations. This class will provide
75/// these in a generic way, depending on the MPI implementation that got
76/// selected by the DLTI attribute on the module.
77class MPIImplTraits {
78 ModuleOp &moduleOp;
79
80public:
81 /// Instantiate a new MPIImplTraits object according to the DLTI attribute
82 /// on the given module. Default to MPICH if no attribute is present or
83 /// the value is unknown.
84 static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp);
85
86 explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {}
87
88 virtual ~MPIImplTraits() = default;
89
90 ModuleOp &getModuleOp() { return moduleOp; }
91
92 /// Gets or creates MPI_COMM_WORLD as a Value.
93 /// Different MPI implementations have different communicator types.
94 /// Using i64 as a portable, intermediate type.
95 /// Appropriate cast needs to take place before calling MPI functions.
96 virtual Value getCommWorld(const Location loc,
97 ConversionPatternRewriter &rewriter) = 0;
98
99 /// Type converter provides i64 type for communicator type.
100 /// Converts to native type, which might be ptr or int or whatever.
101 virtual Value castComm(const Location loc,
102 ConversionPatternRewriter &rewriter, Value comm) = 0;
103
104 /// Get the MPI_STATUS_IGNORE value (typically a pointer type).
105 virtual intptr_t getStatusIgnore() = 0;
106
107 /// Get the MPI_IN_PLACE value (void *).
108 virtual void *getInPlace() = 0;
109
110 /// Gets or creates an MPI datatype as a value which corresponds to the given
111 /// type.
112 virtual Value getDataType(const Location loc,
113 ConversionPatternRewriter &rewriter, Type type) = 0;
114
115 /// Gets or creates an MPI_Op value which corresponds to the given
116 /// enum value.
117 virtual Value getMPIOp(const Location loc,
118 ConversionPatternRewriter &rewriter,
119 mpi::MPI_OpClassEnum opAttr) = 0;
120};
121
122//===----------------------------------------------------------------------===//
123// Implementation details for MPICH ABI compatible MPI implementations
124//===----------------------------------------------------------------------===//
125
126class MPICHImplTraits : public MPIImplTraits {
127 static constexpr int MPI_FLOAT = 0x4c00040a;
128 static constexpr int MPI_DOUBLE = 0x4c00080b;
129 static constexpr int MPI_INT8_T = 0x4c000137;
130 static constexpr int MPI_INT16_T = 0x4c000238;
131 static constexpr int MPI_INT32_T = 0x4c000439;
132 static constexpr int MPI_INT64_T = 0x4c00083a;
133 static constexpr int MPI_UINT8_T = 0x4c00013b;
134 static constexpr int MPI_UINT16_T = 0x4c00023c;
135 static constexpr int MPI_UINT32_T = 0x4c00043d;
136 static constexpr int MPI_UINT64_T = 0x4c00083e;
137 static constexpr int MPI_MAX = 0x58000001;
138 static constexpr int MPI_MIN = 0x58000002;
139 static constexpr int MPI_SUM = 0x58000003;
140 static constexpr int MPI_PROD = 0x58000004;
141 static constexpr int MPI_LAND = 0x58000005;
142 static constexpr int MPI_BAND = 0x58000006;
143 static constexpr int MPI_LOR = 0x58000007;
144 static constexpr int MPI_BOR = 0x58000008;
145 static constexpr int MPI_LXOR = 0x58000009;
146 static constexpr int MPI_BXOR = 0x5800000a;
147 static constexpr int MPI_MINLOC = 0x5800000b;
148 static constexpr int MPI_MAXLOC = 0x5800000c;
149 static constexpr int MPI_REPLACE = 0x5800000d;
150 static constexpr int MPI_NO_OP = 0x5800000e;
151
152public:
153 using MPIImplTraits::MPIImplTraits;
154
155 ~MPICHImplTraits() override = default;
156
157 Value getCommWorld(const Location loc,
158 ConversionPatternRewriter &rewriter) override {
159 static constexpr int MPI_COMM_WORLD = 0x44000000;
160 return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
161 MPI_COMM_WORLD);
162 }
163
164 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
165 Value comm) override {
166 return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm);
167 }
168
169 intptr_t getStatusIgnore() override { return 1; }
170
171 void *getInPlace() override { return reinterpret_cast<void *>(-1); }
172
173 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
174 Type type) override {
175 int32_t mtype = 0;
176 if (type.isF32())
177 mtype = MPI_FLOAT;
178 else if (type.isF64())
179 mtype = MPI_DOUBLE;
180 else if (type.isInteger(width: 64) && !type.isUnsignedInteger())
181 mtype = MPI_INT64_T;
182 else if (type.isInteger(width: 64))
183 mtype = MPI_UINT64_T;
184 else if (type.isInteger(width: 32) && !type.isUnsignedInteger())
185 mtype = MPI_INT32_T;
186 else if (type.isInteger(width: 32))
187 mtype = MPI_UINT32_T;
188 else if (type.isInteger(width: 16) && !type.isUnsignedInteger())
189 mtype = MPI_INT16_T;
190 else if (type.isInteger(width: 16))
191 mtype = MPI_UINT16_T;
192 else if (type.isInteger(width: 8) && !type.isUnsignedInteger())
193 mtype = MPI_INT8_T;
194 else if (type.isInteger(width: 8))
195 mtype = MPI_UINT8_T;
196 else
197 assert(false && "unsupported type");
198 return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype);
199 }
200
201 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
202 mpi::MPI_OpClassEnum opAttr) override {
203 int32_t op = MPI_NO_OP;
204 switch (opAttr) {
205 case mpi::MPI_OpClassEnum::MPI_OP_NULL:
206 op = MPI_NO_OP;
207 break;
208 case mpi::MPI_OpClassEnum::MPI_MAX:
209 op = MPI_MAX;
210 break;
211 case mpi::MPI_OpClassEnum::MPI_MIN:
212 op = MPI_MIN;
213 break;
214 case mpi::MPI_OpClassEnum::MPI_SUM:
215 op = MPI_SUM;
216 break;
217 case mpi::MPI_OpClassEnum::MPI_PROD:
218 op = MPI_PROD;
219 break;
220 case mpi::MPI_OpClassEnum::MPI_LAND:
221 op = MPI_LAND;
222 break;
223 case mpi::MPI_OpClassEnum::MPI_BAND:
224 op = MPI_BAND;
225 break;
226 case mpi::MPI_OpClassEnum::MPI_LOR:
227 op = MPI_LOR;
228 break;
229 case mpi::MPI_OpClassEnum::MPI_BOR:
230 op = MPI_BOR;
231 break;
232 case mpi::MPI_OpClassEnum::MPI_LXOR:
233 op = MPI_LXOR;
234 break;
235 case mpi::MPI_OpClassEnum::MPI_BXOR:
236 op = MPI_BXOR;
237 break;
238 case mpi::MPI_OpClassEnum::MPI_MINLOC:
239 op = MPI_MINLOC;
240 break;
241 case mpi::MPI_OpClassEnum::MPI_MAXLOC:
242 op = MPI_MAXLOC;
243 break;
244 case mpi::MPI_OpClassEnum::MPI_REPLACE:
245 op = MPI_REPLACE;
246 break;
247 }
248 return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op);
249 }
250};
251
252//===----------------------------------------------------------------------===//
253// Implementation details for OpenMPI
254//===----------------------------------------------------------------------===//
255class OMPIImplTraits : public MPIImplTraits {
256 LLVM::GlobalOp getOrDefineExternalStruct(const Location loc,
257 ConversionPatternRewriter &rewriter,
258 StringRef name,
259 LLVM::LLVMStructType type) {
260
261 return getOrDefineGlobal<LLVM::GlobalOp>(
262 getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false,
263 LLVM::Linkage::External, name,
264 /*value=*/Attribute(), /*alignment=*/0, 0);
265 }
266
267public:
268 using MPIImplTraits::MPIImplTraits;
269
270 ~OMPIImplTraits() override = default;
271
272 Value getCommWorld(const Location loc,
273 ConversionPatternRewriter &rewriter) override {
274 auto context = rewriter.getContext();
275 // get external opaque struct pointer type
276 auto commStructT =
277 LLVM::LLVMStructType::getOpaque("ompi_communicator_t", context);
278 StringRef name = "ompi_mpi_comm_world";
279
280 // make sure global op definition exists
281 getOrDefineExternalStruct(loc, rewriter, name, commStructT);
282
283 // get address of symbol
284 auto comm = rewriter.create<LLVM::AddressOfOp>(
285 loc, LLVM::LLVMPointerType::get(context),
286 SymbolRefAttr::get(context, name));
287 return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm);
288 }
289
290 Value castComm(const Location loc, ConversionPatternRewriter &rewriter,
291 Value comm) override {
292 return rewriter.create<LLVM::IntToPtrOp>(
293 loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm);
294 }
295
296 intptr_t getStatusIgnore() override { return 0; }
297
298 void *getInPlace() override { return reinterpret_cast<void *>(1); }
299
300 Value getDataType(const Location loc, ConversionPatternRewriter &rewriter,
301 Type type) override {
302 StringRef mtype;
303 if (type.isF32())
304 mtype = "ompi_mpi_float";
305 else if (type.isF64())
306 mtype = "ompi_mpi_double";
307 else if (type.isInteger(width: 64) && !type.isUnsignedInteger())
308 mtype = "ompi_mpi_int64_t";
309 else if (type.isInteger(width: 64))
310 mtype = "ompi_mpi_uint64_t";
311 else if (type.isInteger(width: 32) && !type.isUnsignedInteger())
312 mtype = "ompi_mpi_int32_t";
313 else if (type.isInteger(width: 32))
314 mtype = "ompi_mpi_uint32_t";
315 else if (type.isInteger(width: 16) && !type.isUnsignedInteger())
316 mtype = "ompi_mpi_int16_t";
317 else if (type.isInteger(width: 16))
318 mtype = "ompi_mpi_uint16_t";
319 else if (type.isInteger(width: 8) && !type.isUnsignedInteger())
320 mtype = "ompi_mpi_int8_t";
321 else if (type.isInteger(width: 8))
322 mtype = "ompi_mpi_uint8_t";
323 else
324 assert(false && "unsupported type");
325
326 auto context = rewriter.getContext();
327 // get external opaque struct pointer type
328 auto typeStructT =
329 LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t", context);
330 // make sure global op definition exists
331 getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT);
332 // get address of symbol
333 return rewriter.create<LLVM::AddressOfOp>(
334 loc, LLVM::LLVMPointerType::get(context),
335 SymbolRefAttr::get(context, mtype));
336 }
337
338 Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter,
339 mpi::MPI_OpClassEnum opAttr) override {
340 StringRef op;
341 switch (opAttr) {
342 case mpi::MPI_OpClassEnum::MPI_OP_NULL:
343 op = "ompi_mpi_no_op";
344 break;
345 case mpi::MPI_OpClassEnum::MPI_MAX:
346 op = "ompi_mpi_max";
347 break;
348 case mpi::MPI_OpClassEnum::MPI_MIN:
349 op = "ompi_mpi_min";
350 break;
351 case mpi::MPI_OpClassEnum::MPI_SUM:
352 op = "ompi_mpi_sum";
353 break;
354 case mpi::MPI_OpClassEnum::MPI_PROD:
355 op = "ompi_mpi_prod";
356 break;
357 case mpi::MPI_OpClassEnum::MPI_LAND:
358 op = "ompi_mpi_land";
359 break;
360 case mpi::MPI_OpClassEnum::MPI_BAND:
361 op = "ompi_mpi_band";
362 break;
363 case mpi::MPI_OpClassEnum::MPI_LOR:
364 op = "ompi_mpi_lor";
365 break;
366 case mpi::MPI_OpClassEnum::MPI_BOR:
367 op = "ompi_mpi_bor";
368 break;
369 case mpi::MPI_OpClassEnum::MPI_LXOR:
370 op = "ompi_mpi_lxor";
371 break;
372 case mpi::MPI_OpClassEnum::MPI_BXOR:
373 op = "ompi_mpi_bxor";
374 break;
375 case mpi::MPI_OpClassEnum::MPI_MINLOC:
376 op = "ompi_mpi_minloc";
377 break;
378 case mpi::MPI_OpClassEnum::MPI_MAXLOC:
379 op = "ompi_mpi_maxloc";
380 break;
381 case mpi::MPI_OpClassEnum::MPI_REPLACE:
382 op = "ompi_mpi_replace";
383 break;
384 }
385 auto context = rewriter.getContext();
386 // get external opaque struct pointer type
387 auto opStructT =
388 LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t", context);
389 // make sure global op definition exists
390 getOrDefineExternalStruct(loc, rewriter, op, opStructT);
391 // get address of symbol
392 return rewriter.create<LLVM::AddressOfOp>(
393 loc, LLVM::LLVMPointerType::get(context),
394 SymbolRefAttr::get(context, op));
395 }
396};
397
398std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) {
399 auto attr = dlti::query(*&moduleOp, {"MPI:Implementation"}, true);
400 if (failed(attr))
401 return std::make_unique<MPICHImplTraits>(args&: moduleOp);
402 auto strAttr = dyn_cast<StringAttr>(attr.value());
403 if (strAttr && strAttr.getValue() == "OpenMPI")
404 return std::make_unique<OMPIImplTraits>(args&: moduleOp);
405 if (!strAttr || strAttr.getValue() != "MPICH")
406 moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI ("
407 << strAttr.getValue() << "), defaulting to MPICH";
408 return std::make_unique<MPICHImplTraits>(args&: moduleOp);
409}
410
411//===----------------------------------------------------------------------===//
412// InitOpLowering
413//===----------------------------------------------------------------------===//
414
415struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> {
416 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
417
418 LogicalResult
419 matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter) const override {
421 Location loc = op.getLoc();
422
423 // ptrType `!llvm.ptr`
424 Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
425
426 // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr`
427 auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType);
428 Value llvmnull = nullPtrOp.getRes();
429
430 // grab a reference to the global module op:
431 auto moduleOp = op->getParentOfType<ModuleOp>();
432
433 // LLVM Function type representing `i32 MPI_Init(ptr, ptr)`
434 auto initFuncType =
435 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
436 // get or create function declaration:
437 LLVM::LLVMFuncOp initDecl =
438 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init", initFuncType);
439
440 // replace init with function call
441 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl,
442 ValueRange{llvmnull, llvmnull});
443
444 return success();
445 }
446};
447
448//===----------------------------------------------------------------------===//
449// FinalizeOpLowering
450//===----------------------------------------------------------------------===//
451
452struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> {
453 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
454
455 LogicalResult
456 matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor,
457 ConversionPatternRewriter &rewriter) const override {
458 // get loc
459 Location loc = op.getLoc();
460
461 // grab a reference to the global module op:
462 auto moduleOp = op->getParentOfType<ModuleOp>();
463
464 // LLVM Function type representing `i32 MPI_Finalize()`
465 auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {});
466 // get or create function declaration:
467 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
468 moduleOp, loc, rewriter, "MPI_Finalize", initFuncType);
469
470 // replace init with function call
471 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{});
472
473 return success();
474 }
475};
476
477//===----------------------------------------------------------------------===//
478// CommWorldOpLowering
479//===----------------------------------------------------------------------===//
480
481struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> {
482 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
483
484 LogicalResult
485 matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter) const override {
487 // grab a reference to the global module op:
488 auto moduleOp = op->getParentOfType<ModuleOp>();
489 auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp);
490 // get MPI_COMM_WORLD
491 rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter));
492
493 return success();
494 }
495};
496
497//===----------------------------------------------------------------------===//
498// CommSplitOpLowering
499//===----------------------------------------------------------------------===//
500
501struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
502 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
503
504 LogicalResult
505 matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor,
506 ConversionPatternRewriter &rewriter) const override {
507 // grab a reference to the global module op:
508 auto moduleOp = op->getParentOfType<ModuleOp>();
509 auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp);
510 Type i32 = rewriter.getI32Type();
511 Type ptrType = LLVM::LLVMPointerType::get(op->getContext());
512 Location loc = op.getLoc();
513
514 // get communicator
515 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
516 auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
517 auto outPtr =
518 rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one);
519
520 // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm)
521 auto funcType =
522 LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType});
523 // get or create function declaration:
524 LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter,
525 "MPI_Comm_split", funcType);
526
527 auto callOp = rewriter.create<LLVM::CallOp>(
528 loc, funcDecl,
529 ValueRange{comm, adaptor.getColor(), adaptor.getKey(),
530 outPtr.getRes()});
531
532 // load the communicator into a register
533 Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult());
534 res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res);
535
536 // if retval is checked, replace uses of retval with the results from the
537 // call op
538 SmallVector<Value> replacements;
539 if (op.getRetval())
540 replacements.push_back(Elt: callOp.getResult());
541
542 // replace op
543 replacements.push_back(Elt: res);
544 rewriter.replaceOp(op, replacements);
545
546 return success();
547 }
548};
549
550//===----------------------------------------------------------------------===//
551// CommRankOpLowering
552//===----------------------------------------------------------------------===//
553
554struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> {
555 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
556
557 LogicalResult
558 matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const override {
560 // get some helper vars
561 Location loc = op.getLoc();
562 MLIRContext *context = rewriter.getContext();
563 Type i32 = rewriter.getI32Type();
564
565 // ptrType `!llvm.ptr`
566 Type ptrType = LLVM::LLVMPointerType::get(context);
567
568 // grab a reference to the global module op:
569 auto moduleOp = op->getParentOfType<ModuleOp>();
570
571 auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp);
572 // get communicator
573 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
574
575 // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)`
576 auto rankFuncType =
577 LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType});
578 // get or create function declaration:
579 LLVM::LLVMFuncOp initDecl = getOrDefineFunction(
580 moduleOp, loc, rewriter, "MPI_Comm_rank", rankFuncType);
581
582 // replace with function call
583 auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1);
584 auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one);
585 auto callOp = rewriter.create<LLVM::CallOp>(
586 loc, initDecl, ValueRange{comm, rankptr.getRes()});
587
588 // load the rank into a register
589 auto loadedRank =
590 rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult());
591
592 // if retval is checked, replace uses of retval with the results from the
593 // call op
594 SmallVector<Value> replacements;
595 if (op.getRetval())
596 replacements.push_back(Elt: callOp.getResult());
597
598 // replace all uses, then erase op
599 replacements.push_back(Elt: loadedRank.getRes());
600 rewriter.replaceOp(op, replacements);
601
602 return success();
603 }
604};
605
606//===----------------------------------------------------------------------===//
607// SendOpLowering
608//===----------------------------------------------------------------------===//
609
610struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> {
611 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
612
613 LogicalResult
614 matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor,
615 ConversionPatternRewriter &rewriter) const override {
616 // get some helper vars
617 Location loc = op.getLoc();
618 MLIRContext *context = rewriter.getContext();
619 Type i32 = rewriter.getI32Type();
620 Type elemType = op.getRef().getType().getElementType();
621
622 // ptrType `!llvm.ptr`
623 Type ptrType = LLVM::LLVMPointerType::get(context);
624
625 // grab a reference to the global module op:
626 auto moduleOp = op->getParentOfType<ModuleOp>();
627
628 // get MPI_COMM_WORLD, dataType and pointer
629 auto [dataPtr, size] =
630 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
631 auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp);
632 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
633 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
634
635 // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst,
636 // tag, comm)`
637 auto funcType = LLVM::LLVMFunctionType::get(
638 i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()});
639 // get or create function declaration:
640 LLVM::LLVMFuncOp funcDecl =
641 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send", funcType);
642
643 // replace op with function call
644 auto funcCall = rewriter.create<LLVM::CallOp>(
645 loc, funcDecl,
646 ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(),
647 comm});
648 if (op.getRetval())
649 rewriter.replaceOp(op, funcCall.getResult());
650 else
651 rewriter.eraseOp(op: op);
652
653 return success();
654 }
655};
656
657//===----------------------------------------------------------------------===//
658// RecvOpLowering
659//===----------------------------------------------------------------------===//
660
661struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> {
662 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
663
664 LogicalResult
665 matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor,
666 ConversionPatternRewriter &rewriter) const override {
667 // get some helper vars
668 Location loc = op.getLoc();
669 MLIRContext *context = rewriter.getContext();
670 Type i32 = rewriter.getI32Type();
671 Type i64 = rewriter.getI64Type();
672 Type elemType = op.getRef().getType().getElementType();
673
674 // ptrType `!llvm.ptr`
675 Type ptrType = LLVM::LLVMPointerType::get(context);
676
677 // grab a reference to the global module op:
678 auto moduleOp = op->getParentOfType<ModuleOp>();
679
680 // get MPI_COMM_WORLD, dataType, status_ignore and pointer
681 auto [dataPtr, size] =
682 getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType);
683 auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp);
684 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
685 Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
686 Value statusIgnore = rewriter.create<LLVM::ConstantOp>(
687 loc, i64, mpiTraits->getStatusIgnore());
688 statusIgnore =
689 rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore);
690
691 // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst,
692 // tag, comm)`
693 auto funcType =
694 LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32,
695 i32, comm.getType(), ptrType});
696 // get or create function declaration:
697 LLVM::LLVMFuncOp funcDecl =
698 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv", funcType);
699
700 // replace op with function call
701 auto funcCall = rewriter.create<LLVM::CallOp>(
702 loc, funcDecl,
703 ValueRange{dataPtr, size, dataType, adaptor.getSource(),
704 adaptor.getTag(), comm, statusIgnore});
705 if (op.getRetval())
706 rewriter.replaceOp(op, funcCall.getResult());
707 else
708 rewriter.eraseOp(op: op);
709
710 return success();
711 }
712};
713
714//===----------------------------------------------------------------------===//
715// AllReduceOpLowering
716//===----------------------------------------------------------------------===//
717
718struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
719 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
720
721 LogicalResult
722 matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor,
723 ConversionPatternRewriter &rewriter) const override {
724 Location loc = op.getLoc();
725 MLIRContext *context = rewriter.getContext();
726 Type i32 = rewriter.getI32Type();
727 Type i64 = rewriter.getI64Type();
728 Type elemType = op.getSendbuf().getType().getElementType();
729
730 // ptrType `!llvm.ptr`
731 Type ptrType = LLVM::LLVMPointerType::get(context);
732 auto moduleOp = op->getParentOfType<ModuleOp>();
733 auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp);
734 auto [sendPtr, sendSize] =
735 getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType);
736 auto [recvPtr, recvSize] =
737 getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType);
738
739 // If input and output are the same, request in-place operation.
740 if (adaptor.getSendbuf() == adaptor.getRecvbuf()) {
741 sendPtr = rewriter.create<LLVM::ConstantOp>(
742 loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace()));
743 sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
744 }
745
746 Value dataType = mpiTraits->getDataType(loc, rewriter, elemType);
747 Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp());
748 Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm());
749
750 // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count,
751 // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)'
752 auto funcType = LLVM::LLVMFunctionType::get(
753 i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(),
754 commWorld.getType()});
755 // get or create function declaration:
756 LLVM::LLVMFuncOp funcDecl =
757 getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce", funcType);
758
759 // replace op with function call
760 auto funcCall = rewriter.create<LLVM::CallOp>(
761 loc, funcDecl,
762 ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld});
763
764 if (op.getRetval())
765 rewriter.replaceOp(op, funcCall.getResult());
766 else
767 rewriter.eraseOp(op: op);
768
769 return success();
770 }
771};
772
773//===----------------------------------------------------------------------===//
774// ConvertToLLVMPatternInterface implementation
775//===----------------------------------------------------------------------===//
776
777/// Implement the interface to convert Func to LLVM.
778struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
779 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
780 /// Hook for derived dialect interface to provide conversion patterns
781 /// and mark dialect legal for the conversion target.
782 void populateConvertToLLVMConversionPatterns(
783 ConversionTarget &target, LLVMTypeConverter &typeConverter,
784 RewritePatternSet &patterns) const final {
785 mpi::populateMPIToLLVMConversionPatterns(converter&: typeConverter, patterns);
786 }
787};
788} // namespace
789
790//===----------------------------------------------------------------------===//
791// Pattern Population
792//===----------------------------------------------------------------------===//
793
794void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter,
795 RewritePatternSet &patterns) {
796 // Using i64 as a portable, intermediate type for !mpi.comm.
797 // It would be nicer to somehow get the right type directly, but TLDI is not
798 // available here.
799 converter.addConversion(callback: [](mpi::CommType type) {
800 return IntegerType::get(type.getContext(), 64);
801 });
802 patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering,
803 FinalizeOpLowering, InitOpLowering, SendOpLowering,
804 RecvOpLowering, AllReduceOpLowering>(arg&: converter);
805}
806
807void mpi::registerConvertMPIToLLVMInterface(DialectRegistry &registry) {
808 registry.addExtension(extensionFn: +[](MLIRContext *ctx, mpi::MPIDialect *dialect) {
809 dialect->addInterfaces<FuncToLLVMDialectInterface>();
810 });
811}
812

// source code of mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp