| 1 | //===- MPIToLLVM.cpp - MPI to LLVM dialect conversion ---------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | // |
| 10 | // Copyright (C) by Argonne National Laboratory |
| 11 | // See COPYRIGHT in top-level directory |
| 12 | // of MPICH source repository. |
| 13 | // |
| 14 | |
| 15 | #include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" |
| 16 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| 17 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| 18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 19 | #include "mlir/Dialect/DLTI/DLTI.h" |
| 20 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 21 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| 22 | #include "mlir/Dialect/MPI/IR/MPI.h" |
| 23 | #include "mlir/Transforms/DialectConversion.h" |
| 24 | #include <memory> |
| 25 | |
| 26 | using namespace mlir; |
| 27 | |
| 28 | namespace { |
| 29 | |
| 30 | template <typename Op, typename... Args> |
| 31 | static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, |
| 32 | ConversionPatternRewriter &rewriter, StringRef name, |
| 33 | Args &&...args) { |
| 34 | Op ret; |
| 35 | if (!(ret = moduleOp.lookupSymbol<Op>(name))) { |
| 36 | ConversionPatternRewriter::InsertionGuard guard(rewriter); |
| 37 | rewriter.setInsertionPointToStart(moduleOp.getBody()); |
| 38 | ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); |
| 39 | } |
| 40 | return ret; |
| 41 | } |
| 42 | |
| 43 | static LLVM::LLVMFuncOp getOrDefineFunction(ModuleOp &moduleOp, |
| 44 | const Location loc, |
| 45 | ConversionPatternRewriter &rewriter, |
| 46 | StringRef name, |
| 47 | LLVM::LLVMFunctionType type) { |
| 48 | return getOrDefineGlobal<LLVM::LLVMFuncOp>( |
| 49 | moduleOp, loc, rewriter, name, name, type, LLVM::Linkage::External); |
| 50 | } |
| 51 | |
| 52 | std::pair<Value, Value> getRawPtrAndSize(const Location loc, |
| 53 | ConversionPatternRewriter &rewriter, |
| 54 | Value memRef, Type elType) { |
| 55 | Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 56 | Value dataPtr = |
| 57 | rewriter.create<LLVM::ExtractValueOp>(loc, ptrType, memRef, 1); |
| 58 | Value offset = rewriter.create<LLVM::ExtractValueOp>( |
| 59 | loc, rewriter.getI64Type(), memRef, 2); |
| 60 | Value resPtr = |
| 61 | rewriter.create<LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset); |
| 62 | Value size; |
| 63 | if (cast<LLVM::LLVMStructType>(memRef.getType()).getBody().size() > 3) { |
| 64 | size = rewriter.create<LLVM::ExtractValueOp>(loc, memRef, |
| 65 | ArrayRef<int64_t>{3, 0}); |
| 66 | size = rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), size); |
| 67 | } else { |
| 68 | size = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: 32); |
| 69 | } |
| 70 | return {resPtr, size}; |
| 71 | } |
| 72 | |
| 73 | /// When lowering the mpi dialect to functions calls certain details |
| 74 | /// differ between various MPI implementations. This class will provide |
| 75 | /// these in a generic way, depending on the MPI implementation that got |
| 76 | /// selected by the DLTI attribute on the module. |
| 77 | class MPIImplTraits { |
| 78 | ModuleOp &moduleOp; |
| 79 | |
| 80 | public: |
| 81 | /// Instantiate a new MPIImplTraits object according to the DLTI attribute |
| 82 | /// on the given module. Default to MPICH if no attribute is present or |
| 83 | /// the value is unknown. |
| 84 | static std::unique_ptr<MPIImplTraits> get(ModuleOp &moduleOp); |
| 85 | |
| 86 | explicit MPIImplTraits(ModuleOp &moduleOp) : moduleOp(moduleOp) {} |
| 87 | |
| 88 | virtual ~MPIImplTraits() = default; |
| 89 | |
| 90 | ModuleOp &getModuleOp() { return moduleOp; } |
| 91 | |
| 92 | /// Gets or creates MPI_COMM_WORLD as a Value. |
| 93 | /// Different MPI implementations have different communicator types. |
| 94 | /// Using i64 as a portable, intermediate type. |
| 95 | /// Appropriate cast needs to take place before calling MPI functions. |
| 96 | virtual Value getCommWorld(const Location loc, |
| 97 | ConversionPatternRewriter &rewriter) = 0; |
| 98 | |
| 99 | /// Type converter provides i64 type for communicator type. |
| 100 | /// Converts to native type, which might be ptr or int or whatever. |
| 101 | virtual Value castComm(const Location loc, |
| 102 | ConversionPatternRewriter &rewriter, Value comm) = 0; |
| 103 | |
| 104 | /// Get the MPI_STATUS_IGNORE value (typically a pointer type). |
| 105 | virtual intptr_t getStatusIgnore() = 0; |
| 106 | |
| 107 | /// Get the MPI_IN_PLACE value (void *). |
| 108 | virtual void *getInPlace() = 0; |
| 109 | |
| 110 | /// Gets or creates an MPI datatype as a value which corresponds to the given |
| 111 | /// type. |
| 112 | virtual Value getDataType(const Location loc, |
| 113 | ConversionPatternRewriter &rewriter, Type type) = 0; |
| 114 | |
| 115 | /// Gets or creates an MPI_Op value which corresponds to the given |
| 116 | /// enum value. |
| 117 | virtual Value getMPIOp(const Location loc, |
| 118 | ConversionPatternRewriter &rewriter, |
| 119 | mpi::MPI_OpClassEnum opAttr) = 0; |
| 120 | }; |
| 121 | |
| 122 | //===----------------------------------------------------------------------===// |
| 123 | // Implementation details for MPICH ABI compatible MPI implementations |
| 124 | //===----------------------------------------------------------------------===// |
| 125 | |
| 126 | class MPICHImplTraits : public MPIImplTraits { |
| 127 | static constexpr int MPI_FLOAT = 0x4c00040a; |
| 128 | static constexpr int MPI_DOUBLE = 0x4c00080b; |
| 129 | static constexpr int MPI_INT8_T = 0x4c000137; |
| 130 | static constexpr int MPI_INT16_T = 0x4c000238; |
| 131 | static constexpr int MPI_INT32_T = 0x4c000439; |
| 132 | static constexpr int MPI_INT64_T = 0x4c00083a; |
| 133 | static constexpr int MPI_UINT8_T = 0x4c00013b; |
| 134 | static constexpr int MPI_UINT16_T = 0x4c00023c; |
| 135 | static constexpr int MPI_UINT32_T = 0x4c00043d; |
| 136 | static constexpr int MPI_UINT64_T = 0x4c00083e; |
| 137 | static constexpr int MPI_MAX = 0x58000001; |
| 138 | static constexpr int MPI_MIN = 0x58000002; |
| 139 | static constexpr int MPI_SUM = 0x58000003; |
| 140 | static constexpr int MPI_PROD = 0x58000004; |
| 141 | static constexpr int MPI_LAND = 0x58000005; |
| 142 | static constexpr int MPI_BAND = 0x58000006; |
| 143 | static constexpr int MPI_LOR = 0x58000007; |
| 144 | static constexpr int MPI_BOR = 0x58000008; |
| 145 | static constexpr int MPI_LXOR = 0x58000009; |
| 146 | static constexpr int MPI_BXOR = 0x5800000a; |
| 147 | static constexpr int MPI_MINLOC = 0x5800000b; |
| 148 | static constexpr int MPI_MAXLOC = 0x5800000c; |
| 149 | static constexpr int MPI_REPLACE = 0x5800000d; |
| 150 | static constexpr int MPI_NO_OP = 0x5800000e; |
| 151 | |
| 152 | public: |
| 153 | using MPIImplTraits::MPIImplTraits; |
| 154 | |
| 155 | ~MPICHImplTraits() override = default; |
| 156 | |
| 157 | Value getCommWorld(const Location loc, |
| 158 | ConversionPatternRewriter &rewriter) override { |
| 159 | static constexpr int MPI_COMM_WORLD = 0x44000000; |
| 160 | return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), |
| 161 | MPI_COMM_WORLD); |
| 162 | } |
| 163 | |
| 164 | Value castComm(const Location loc, ConversionPatternRewriter &rewriter, |
| 165 | Value comm) override { |
| 166 | return rewriter.create<LLVM::TruncOp>(loc, rewriter.getI32Type(), comm); |
| 167 | } |
| 168 | |
| 169 | intptr_t getStatusIgnore() override { return 1; } |
| 170 | |
| 171 | void *getInPlace() override { return reinterpret_cast<void *>(-1); } |
| 172 | |
| 173 | Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, |
| 174 | Type type) override { |
| 175 | int32_t mtype = 0; |
| 176 | if (type.isF32()) |
| 177 | mtype = MPI_FLOAT; |
| 178 | else if (type.isF64()) |
| 179 | mtype = MPI_DOUBLE; |
| 180 | else if (type.isInteger(width: 64) && !type.isUnsignedInteger()) |
| 181 | mtype = MPI_INT64_T; |
| 182 | else if (type.isInteger(width: 64)) |
| 183 | mtype = MPI_UINT64_T; |
| 184 | else if (type.isInteger(width: 32) && !type.isUnsignedInteger()) |
| 185 | mtype = MPI_INT32_T; |
| 186 | else if (type.isInteger(width: 32)) |
| 187 | mtype = MPI_UINT32_T; |
| 188 | else if (type.isInteger(width: 16) && !type.isUnsignedInteger()) |
| 189 | mtype = MPI_INT16_T; |
| 190 | else if (type.isInteger(width: 16)) |
| 191 | mtype = MPI_UINT16_T; |
| 192 | else if (type.isInteger(width: 8) && !type.isUnsignedInteger()) |
| 193 | mtype = MPI_INT8_T; |
| 194 | else if (type.isInteger(width: 8)) |
| 195 | mtype = MPI_UINT8_T; |
| 196 | else |
| 197 | assert(false && "unsupported type" ); |
| 198 | return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), mtype); |
| 199 | } |
| 200 | |
| 201 | Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, |
| 202 | mpi::MPI_OpClassEnum opAttr) override { |
| 203 | int32_t op = MPI_NO_OP; |
| 204 | switch (opAttr) { |
| 205 | case mpi::MPI_OpClassEnum::MPI_OP_NULL: |
| 206 | op = MPI_NO_OP; |
| 207 | break; |
| 208 | case mpi::MPI_OpClassEnum::MPI_MAX: |
| 209 | op = MPI_MAX; |
| 210 | break; |
| 211 | case mpi::MPI_OpClassEnum::MPI_MIN: |
| 212 | op = MPI_MIN; |
| 213 | break; |
| 214 | case mpi::MPI_OpClassEnum::MPI_SUM: |
| 215 | op = MPI_SUM; |
| 216 | break; |
| 217 | case mpi::MPI_OpClassEnum::MPI_PROD: |
| 218 | op = MPI_PROD; |
| 219 | break; |
| 220 | case mpi::MPI_OpClassEnum::MPI_LAND: |
| 221 | op = MPI_LAND; |
| 222 | break; |
| 223 | case mpi::MPI_OpClassEnum::MPI_BAND: |
| 224 | op = MPI_BAND; |
| 225 | break; |
| 226 | case mpi::MPI_OpClassEnum::MPI_LOR: |
| 227 | op = MPI_LOR; |
| 228 | break; |
| 229 | case mpi::MPI_OpClassEnum::MPI_BOR: |
| 230 | op = MPI_BOR; |
| 231 | break; |
| 232 | case mpi::MPI_OpClassEnum::MPI_LXOR: |
| 233 | op = MPI_LXOR; |
| 234 | break; |
| 235 | case mpi::MPI_OpClassEnum::MPI_BXOR: |
| 236 | op = MPI_BXOR; |
| 237 | break; |
| 238 | case mpi::MPI_OpClassEnum::MPI_MINLOC: |
| 239 | op = MPI_MINLOC; |
| 240 | break; |
| 241 | case mpi::MPI_OpClassEnum::MPI_MAXLOC: |
| 242 | op = MPI_MAXLOC; |
| 243 | break; |
| 244 | case mpi::MPI_OpClassEnum::MPI_REPLACE: |
| 245 | op = MPI_REPLACE; |
| 246 | break; |
| 247 | } |
| 248 | return rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), op); |
| 249 | } |
| 250 | }; |
| 251 | |
| 252 | //===----------------------------------------------------------------------===// |
| 253 | // Implementation details for OpenMPI |
| 254 | //===----------------------------------------------------------------------===// |
| 255 | class OMPIImplTraits : public MPIImplTraits { |
| 256 | LLVM::GlobalOp getOrDefineExternalStruct(const Location loc, |
| 257 | ConversionPatternRewriter &rewriter, |
| 258 | StringRef name, |
| 259 | LLVM::LLVMStructType type) { |
| 260 | |
| 261 | return getOrDefineGlobal<LLVM::GlobalOp>( |
| 262 | getModuleOp(), loc, rewriter, name, type, /*isConstant=*/false, |
| 263 | LLVM::Linkage::External, name, |
| 264 | /*value=*/Attribute(), /*alignment=*/0, 0); |
| 265 | } |
| 266 | |
| 267 | public: |
| 268 | using MPIImplTraits::MPIImplTraits; |
| 269 | |
| 270 | ~OMPIImplTraits() override = default; |
| 271 | |
| 272 | Value getCommWorld(const Location loc, |
| 273 | ConversionPatternRewriter &rewriter) override { |
| 274 | auto context = rewriter.getContext(); |
| 275 | // get external opaque struct pointer type |
| 276 | auto commStructT = |
| 277 | LLVM::LLVMStructType::getOpaque("ompi_communicator_t" , context); |
| 278 | StringRef name = "ompi_mpi_comm_world" ; |
| 279 | |
| 280 | // make sure global op definition exists |
| 281 | getOrDefineExternalStruct(loc, rewriter, name, commStructT); |
| 282 | |
| 283 | // get address of symbol |
| 284 | auto comm = rewriter.create<LLVM::AddressOfOp>( |
| 285 | loc, LLVM::LLVMPointerType::get(context), |
| 286 | SymbolRefAttr::get(context, name)); |
| 287 | return rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), comm); |
| 288 | } |
| 289 | |
| 290 | Value castComm(const Location loc, ConversionPatternRewriter &rewriter, |
| 291 | Value comm) override { |
| 292 | return rewriter.create<LLVM::IntToPtrOp>( |
| 293 | loc, LLVM::LLVMPointerType::get(rewriter.getContext()), comm); |
| 294 | } |
| 295 | |
| 296 | intptr_t getStatusIgnore() override { return 0; } |
| 297 | |
| 298 | void *getInPlace() override { return reinterpret_cast<void *>(1); } |
| 299 | |
| 300 | Value getDataType(const Location loc, ConversionPatternRewriter &rewriter, |
| 301 | Type type) override { |
| 302 | StringRef mtype; |
| 303 | if (type.isF32()) |
| 304 | mtype = "ompi_mpi_float" ; |
| 305 | else if (type.isF64()) |
| 306 | mtype = "ompi_mpi_double" ; |
| 307 | else if (type.isInteger(width: 64) && !type.isUnsignedInteger()) |
| 308 | mtype = "ompi_mpi_int64_t" ; |
| 309 | else if (type.isInteger(width: 64)) |
| 310 | mtype = "ompi_mpi_uint64_t" ; |
| 311 | else if (type.isInteger(width: 32) && !type.isUnsignedInteger()) |
| 312 | mtype = "ompi_mpi_int32_t" ; |
| 313 | else if (type.isInteger(width: 32)) |
| 314 | mtype = "ompi_mpi_uint32_t" ; |
| 315 | else if (type.isInteger(width: 16) && !type.isUnsignedInteger()) |
| 316 | mtype = "ompi_mpi_int16_t" ; |
| 317 | else if (type.isInteger(width: 16)) |
| 318 | mtype = "ompi_mpi_uint16_t" ; |
| 319 | else if (type.isInteger(width: 8) && !type.isUnsignedInteger()) |
| 320 | mtype = "ompi_mpi_int8_t" ; |
| 321 | else if (type.isInteger(width: 8)) |
| 322 | mtype = "ompi_mpi_uint8_t" ; |
| 323 | else |
| 324 | assert(false && "unsupported type" ); |
| 325 | |
| 326 | auto context = rewriter.getContext(); |
| 327 | // get external opaque struct pointer type |
| 328 | auto typeStructT = |
| 329 | LLVM::LLVMStructType::getOpaque("ompi_predefined_datatype_t" , context); |
| 330 | // make sure global op definition exists |
| 331 | getOrDefineExternalStruct(loc, rewriter, mtype, typeStructT); |
| 332 | // get address of symbol |
| 333 | return rewriter.create<LLVM::AddressOfOp>( |
| 334 | loc, LLVM::LLVMPointerType::get(context), |
| 335 | SymbolRefAttr::get(context, mtype)); |
| 336 | } |
| 337 | |
| 338 | Value getMPIOp(const Location loc, ConversionPatternRewriter &rewriter, |
| 339 | mpi::MPI_OpClassEnum opAttr) override { |
| 340 | StringRef op; |
| 341 | switch (opAttr) { |
| 342 | case mpi::MPI_OpClassEnum::MPI_OP_NULL: |
| 343 | op = "ompi_mpi_no_op" ; |
| 344 | break; |
| 345 | case mpi::MPI_OpClassEnum::MPI_MAX: |
| 346 | op = "ompi_mpi_max" ; |
| 347 | break; |
| 348 | case mpi::MPI_OpClassEnum::MPI_MIN: |
| 349 | op = "ompi_mpi_min" ; |
| 350 | break; |
| 351 | case mpi::MPI_OpClassEnum::MPI_SUM: |
| 352 | op = "ompi_mpi_sum" ; |
| 353 | break; |
| 354 | case mpi::MPI_OpClassEnum::MPI_PROD: |
| 355 | op = "ompi_mpi_prod" ; |
| 356 | break; |
| 357 | case mpi::MPI_OpClassEnum::MPI_LAND: |
| 358 | op = "ompi_mpi_land" ; |
| 359 | break; |
| 360 | case mpi::MPI_OpClassEnum::MPI_BAND: |
| 361 | op = "ompi_mpi_band" ; |
| 362 | break; |
| 363 | case mpi::MPI_OpClassEnum::MPI_LOR: |
| 364 | op = "ompi_mpi_lor" ; |
| 365 | break; |
| 366 | case mpi::MPI_OpClassEnum::MPI_BOR: |
| 367 | op = "ompi_mpi_bor" ; |
| 368 | break; |
| 369 | case mpi::MPI_OpClassEnum::MPI_LXOR: |
| 370 | op = "ompi_mpi_lxor" ; |
| 371 | break; |
| 372 | case mpi::MPI_OpClassEnum::MPI_BXOR: |
| 373 | op = "ompi_mpi_bxor" ; |
| 374 | break; |
| 375 | case mpi::MPI_OpClassEnum::MPI_MINLOC: |
| 376 | op = "ompi_mpi_minloc" ; |
| 377 | break; |
| 378 | case mpi::MPI_OpClassEnum::MPI_MAXLOC: |
| 379 | op = "ompi_mpi_maxloc" ; |
| 380 | break; |
| 381 | case mpi::MPI_OpClassEnum::MPI_REPLACE: |
| 382 | op = "ompi_mpi_replace" ; |
| 383 | break; |
| 384 | } |
| 385 | auto context = rewriter.getContext(); |
| 386 | // get external opaque struct pointer type |
| 387 | auto opStructT = |
| 388 | LLVM::LLVMStructType::getOpaque("ompi_predefined_op_t" , context); |
| 389 | // make sure global op definition exists |
| 390 | getOrDefineExternalStruct(loc, rewriter, op, opStructT); |
| 391 | // get address of symbol |
| 392 | return rewriter.create<LLVM::AddressOfOp>( |
| 393 | loc, LLVM::LLVMPointerType::get(context), |
| 394 | SymbolRefAttr::get(context, op)); |
| 395 | } |
| 396 | }; |
| 397 | |
| 398 | std::unique_ptr<MPIImplTraits> MPIImplTraits::get(ModuleOp &moduleOp) { |
| 399 | auto attr = dlti::query(*&moduleOp, {"MPI:Implementation" }, true); |
| 400 | if (failed(attr)) |
| 401 | return std::make_unique<MPICHImplTraits>(args&: moduleOp); |
| 402 | auto strAttr = dyn_cast<StringAttr>(attr.value()); |
| 403 | if (strAttr && strAttr.getValue() == "OpenMPI" ) |
| 404 | return std::make_unique<OMPIImplTraits>(args&: moduleOp); |
| 405 | if (!strAttr || strAttr.getValue() != "MPICH" ) |
| 406 | moduleOp.emitWarning() << "Unknown \"MPI:Implementation\" value in DLTI (" |
| 407 | << strAttr.getValue() << "), defaulting to MPICH" ; |
| 408 | return std::make_unique<MPICHImplTraits>(args&: moduleOp); |
| 409 | } |
| 410 | |
| 411 | //===----------------------------------------------------------------------===// |
| 412 | // InitOpLowering |
| 413 | //===----------------------------------------------------------------------===// |
| 414 | |
| 415 | struct InitOpLowering : public ConvertOpToLLVMPattern<mpi::InitOp> { |
| 416 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 417 | |
| 418 | LogicalResult |
| 419 | matchAndRewrite(mpi::InitOp op, OpAdaptor adaptor, |
| 420 | ConversionPatternRewriter &rewriter) const override { |
| 421 | Location loc = op.getLoc(); |
| 422 | |
| 423 | // ptrType `!llvm.ptr` |
| 424 | Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext()); |
| 425 | |
| 426 | // instantiate nullptr `%nullptr = llvm.mlir.zero : !llvm.ptr` |
| 427 | auto nullPtrOp = rewriter.create<LLVM::ZeroOp>(loc, ptrType); |
| 428 | Value llvmnull = nullPtrOp.getRes(); |
| 429 | |
| 430 | // grab a reference to the global module op: |
| 431 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 432 | |
| 433 | // LLVM Function type representing `i32 MPI_Init(ptr, ptr)` |
| 434 | auto initFuncType = |
| 435 | LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType}); |
| 436 | // get or create function declaration: |
| 437 | LLVM::LLVMFuncOp initDecl = |
| 438 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Init" , initFuncType); |
| 439 | |
| 440 | // replace init with function call |
| 441 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, |
| 442 | ValueRange{llvmnull, llvmnull}); |
| 443 | |
| 444 | return success(); |
| 445 | } |
| 446 | }; |
| 447 | |
| 448 | //===----------------------------------------------------------------------===// |
| 449 | // FinalizeOpLowering |
| 450 | //===----------------------------------------------------------------------===// |
| 451 | |
| 452 | struct FinalizeOpLowering : public ConvertOpToLLVMPattern<mpi::FinalizeOp> { |
| 453 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 454 | |
| 455 | LogicalResult |
| 456 | matchAndRewrite(mpi::FinalizeOp op, OpAdaptor adaptor, |
| 457 | ConversionPatternRewriter &rewriter) const override { |
| 458 | // get loc |
| 459 | Location loc = op.getLoc(); |
| 460 | |
| 461 | // grab a reference to the global module op: |
| 462 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 463 | |
| 464 | // LLVM Function type representing `i32 MPI_Finalize()` |
| 465 | auto initFuncType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}); |
| 466 | // get or create function declaration: |
| 467 | LLVM::LLVMFuncOp initDecl = getOrDefineFunction( |
| 468 | moduleOp, loc, rewriter, "MPI_Finalize" , initFuncType); |
| 469 | |
| 470 | // replace init with function call |
| 471 | rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, initDecl, ValueRange{}); |
| 472 | |
| 473 | return success(); |
| 474 | } |
| 475 | }; |
| 476 | |
| 477 | //===----------------------------------------------------------------------===// |
| 478 | // CommWorldOpLowering |
| 479 | //===----------------------------------------------------------------------===// |
| 480 | |
| 481 | struct CommWorldOpLowering : public ConvertOpToLLVMPattern<mpi::CommWorldOp> { |
| 482 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 483 | |
| 484 | LogicalResult |
| 485 | matchAndRewrite(mpi::CommWorldOp op, OpAdaptor adaptor, |
| 486 | ConversionPatternRewriter &rewriter) const override { |
| 487 | // grab a reference to the global module op: |
| 488 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 489 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
| 490 | // get MPI_COMM_WORLD |
| 491 | rewriter.replaceOp(op, mpiTraits->getCommWorld(op.getLoc(), rewriter)); |
| 492 | |
| 493 | return success(); |
| 494 | } |
| 495 | }; |
| 496 | |
| 497 | //===----------------------------------------------------------------------===// |
| 498 | // CommSplitOpLowering |
| 499 | //===----------------------------------------------------------------------===// |
| 500 | |
| 501 | struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> { |
| 502 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 503 | |
| 504 | LogicalResult |
| 505 | matchAndRewrite(mpi::CommSplitOp op, OpAdaptor adaptor, |
| 506 | ConversionPatternRewriter &rewriter) const override { |
| 507 | // grab a reference to the global module op: |
| 508 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 509 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
| 510 | Type i32 = rewriter.getI32Type(); |
| 511 | Type ptrType = LLVM::LLVMPointerType::get(op->getContext()); |
| 512 | Location loc = op.getLoc(); |
| 513 | |
| 514 | // get communicator |
| 515 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
| 516 | auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); |
| 517 | auto outPtr = |
| 518 | rewriter.create<LLVM::AllocaOp>(loc, ptrType, comm.getType(), one); |
| 519 | |
| 520 | // int MPI_Comm_split(MPI_Comm comm, int color, int key, MPI_Comm * newcomm) |
| 521 | auto funcType = |
| 522 | LLVM::LLVMFunctionType::get(i32, {comm.getType(), i32, i32, ptrType}); |
| 523 | // get or create function declaration: |
| 524 | LLVM::LLVMFuncOp funcDecl = getOrDefineFunction(moduleOp, loc, rewriter, |
| 525 | "MPI_Comm_split" , funcType); |
| 526 | |
| 527 | auto callOp = rewriter.create<LLVM::CallOp>( |
| 528 | loc, funcDecl, |
| 529 | ValueRange{comm, adaptor.getColor(), adaptor.getKey(), |
| 530 | outPtr.getRes()}); |
| 531 | |
| 532 | // load the communicator into a register |
| 533 | Value res = rewriter.create<LLVM::LoadOp>(loc, i32, outPtr.getResult()); |
| 534 | res = rewriter.create<LLVM::SExtOp>(loc, rewriter.getI64Type(), res); |
| 535 | |
| 536 | // if retval is checked, replace uses of retval with the results from the |
| 537 | // call op |
| 538 | SmallVector<Value> replacements; |
| 539 | if (op.getRetval()) |
| 540 | replacements.push_back(Elt: callOp.getResult()); |
| 541 | |
| 542 | // replace op |
| 543 | replacements.push_back(Elt: res); |
| 544 | rewriter.replaceOp(op, replacements); |
| 545 | |
| 546 | return success(); |
| 547 | } |
| 548 | }; |
| 549 | |
| 550 | //===----------------------------------------------------------------------===// |
| 551 | // CommRankOpLowering |
| 552 | //===----------------------------------------------------------------------===// |
| 553 | |
| 554 | struct CommRankOpLowering : public ConvertOpToLLVMPattern<mpi::CommRankOp> { |
| 555 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 556 | |
| 557 | LogicalResult |
| 558 | matchAndRewrite(mpi::CommRankOp op, OpAdaptor adaptor, |
| 559 | ConversionPatternRewriter &rewriter) const override { |
| 560 | // get some helper vars |
| 561 | Location loc = op.getLoc(); |
| 562 | MLIRContext *context = rewriter.getContext(); |
| 563 | Type i32 = rewriter.getI32Type(); |
| 564 | |
| 565 | // ptrType `!llvm.ptr` |
| 566 | Type ptrType = LLVM::LLVMPointerType::get(context); |
| 567 | |
| 568 | // grab a reference to the global module op: |
| 569 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 570 | |
| 571 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
| 572 | // get communicator |
| 573 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
| 574 | |
| 575 | // LLVM Function type representing `i32 MPI_Comm_rank(ptr, ptr)` |
| 576 | auto rankFuncType = |
| 577 | LLVM::LLVMFunctionType::get(i32, {comm.getType(), ptrType}); |
| 578 | // get or create function declaration: |
| 579 | LLVM::LLVMFuncOp initDecl = getOrDefineFunction( |
| 580 | moduleOp, loc, rewriter, "MPI_Comm_rank" , rankFuncType); |
| 581 | |
| 582 | // replace with function call |
| 583 | auto one = rewriter.create<LLVM::ConstantOp>(loc, i32, 1); |
| 584 | auto rankptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, i32, one); |
| 585 | auto callOp = rewriter.create<LLVM::CallOp>( |
| 586 | loc, initDecl, ValueRange{comm, rankptr.getRes()}); |
| 587 | |
| 588 | // load the rank into a register |
| 589 | auto loadedRank = |
| 590 | rewriter.create<LLVM::LoadOp>(loc, i32, rankptr.getResult()); |
| 591 | |
| 592 | // if retval is checked, replace uses of retval with the results from the |
| 593 | // call op |
| 594 | SmallVector<Value> replacements; |
| 595 | if (op.getRetval()) |
| 596 | replacements.push_back(Elt: callOp.getResult()); |
| 597 | |
| 598 | // replace all uses, then erase op |
| 599 | replacements.push_back(Elt: loadedRank.getRes()); |
| 600 | rewriter.replaceOp(op, replacements); |
| 601 | |
| 602 | return success(); |
| 603 | } |
| 604 | }; |
| 605 | |
| 606 | //===----------------------------------------------------------------------===// |
| 607 | // SendOpLowering |
| 608 | //===----------------------------------------------------------------------===// |
| 609 | |
| 610 | struct SendOpLowering : public ConvertOpToLLVMPattern<mpi::SendOp> { |
| 611 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 612 | |
| 613 | LogicalResult |
| 614 | matchAndRewrite(mpi::SendOp op, OpAdaptor adaptor, |
| 615 | ConversionPatternRewriter &rewriter) const override { |
| 616 | // get some helper vars |
| 617 | Location loc = op.getLoc(); |
| 618 | MLIRContext *context = rewriter.getContext(); |
| 619 | Type i32 = rewriter.getI32Type(); |
| 620 | Type elemType = op.getRef().getType().getElementType(); |
| 621 | |
| 622 | // ptrType `!llvm.ptr` |
| 623 | Type ptrType = LLVM::LLVMPointerType::get(context); |
| 624 | |
| 625 | // grab a reference to the global module op: |
| 626 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 627 | |
| 628 | // get MPI_COMM_WORLD, dataType and pointer |
| 629 | auto [dataPtr, size] = |
| 630 | getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType); |
| 631 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
| 632 | Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); |
| 633 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
| 634 | |
| 635 | // LLVM Function type representing `i32 MPI_send(data, count, datatype, dst, |
| 636 | // tag, comm)` |
| 637 | auto funcType = LLVM::LLVMFunctionType::get( |
| 638 | i32, {ptrType, i32, dataType.getType(), i32, i32, comm.getType()}); |
| 639 | // get or create function declaration: |
| 640 | LLVM::LLVMFuncOp funcDecl = |
| 641 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Send" , funcType); |
| 642 | |
| 643 | // replace op with function call |
| 644 | auto funcCall = rewriter.create<LLVM::CallOp>( |
| 645 | loc, funcDecl, |
| 646 | ValueRange{dataPtr, size, dataType, adaptor.getDest(), adaptor.getTag(), |
| 647 | comm}); |
| 648 | if (op.getRetval()) |
| 649 | rewriter.replaceOp(op, funcCall.getResult()); |
| 650 | else |
| 651 | rewriter.eraseOp(op: op); |
| 652 | |
| 653 | return success(); |
| 654 | } |
| 655 | }; |
| 656 | |
| 657 | //===----------------------------------------------------------------------===// |
| 658 | // RecvOpLowering |
| 659 | //===----------------------------------------------------------------------===// |
| 660 | |
| 661 | struct RecvOpLowering : public ConvertOpToLLVMPattern<mpi::RecvOp> { |
| 662 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 663 | |
| 664 | LogicalResult |
| 665 | matchAndRewrite(mpi::RecvOp op, OpAdaptor adaptor, |
| 666 | ConversionPatternRewriter &rewriter) const override { |
| 667 | // get some helper vars |
| 668 | Location loc = op.getLoc(); |
| 669 | MLIRContext *context = rewriter.getContext(); |
| 670 | Type i32 = rewriter.getI32Type(); |
| 671 | Type i64 = rewriter.getI64Type(); |
| 672 | Type elemType = op.getRef().getType().getElementType(); |
| 673 | |
| 674 | // ptrType `!llvm.ptr` |
| 675 | Type ptrType = LLVM::LLVMPointerType::get(context); |
| 676 | |
| 677 | // grab a reference to the global module op: |
| 678 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 679 | |
| 680 | // get MPI_COMM_WORLD, dataType, status_ignore and pointer |
| 681 | auto [dataPtr, size] = |
| 682 | getRawPtrAndSize(loc, rewriter, adaptor.getRef(), elemType); |
| 683 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
| 684 | Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); |
| 685 | Value comm = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
| 686 | Value statusIgnore = rewriter.create<LLVM::ConstantOp>( |
| 687 | loc, i64, mpiTraits->getStatusIgnore()); |
| 688 | statusIgnore = |
| 689 | rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, statusIgnore); |
| 690 | |
| 691 | // LLVM Function type representing `i32 MPI_Recv(data, count, datatype, dst, |
| 692 | // tag, comm)` |
| 693 | auto funcType = |
| 694 | LLVM::LLVMFunctionType::get(i32, {ptrType, i32, dataType.getType(), i32, |
| 695 | i32, comm.getType(), ptrType}); |
| 696 | // get or create function declaration: |
| 697 | LLVM::LLVMFuncOp funcDecl = |
| 698 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Recv" , funcType); |
| 699 | |
| 700 | // replace op with function call |
| 701 | auto funcCall = rewriter.create<LLVM::CallOp>( |
| 702 | loc, funcDecl, |
| 703 | ValueRange{dataPtr, size, dataType, adaptor.getSource(), |
| 704 | adaptor.getTag(), comm, statusIgnore}); |
| 705 | if (op.getRetval()) |
| 706 | rewriter.replaceOp(op, funcCall.getResult()); |
| 707 | else |
| 708 | rewriter.eraseOp(op: op); |
| 709 | |
| 710 | return success(); |
| 711 | } |
| 712 | }; |
| 713 | |
| 714 | //===----------------------------------------------------------------------===// |
| 715 | // AllReduceOpLowering |
| 716 | //===----------------------------------------------------------------------===// |
| 717 | |
| 718 | struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> { |
| 719 | using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; |
| 720 | |
| 721 | LogicalResult |
| 722 | matchAndRewrite(mpi::AllReduceOp op, OpAdaptor adaptor, |
| 723 | ConversionPatternRewriter &rewriter) const override { |
| 724 | Location loc = op.getLoc(); |
| 725 | MLIRContext *context = rewriter.getContext(); |
| 726 | Type i32 = rewriter.getI32Type(); |
| 727 | Type i64 = rewriter.getI64Type(); |
| 728 | Type elemType = op.getSendbuf().getType().getElementType(); |
| 729 | |
| 730 | // ptrType `!llvm.ptr` |
| 731 | Type ptrType = LLVM::LLVMPointerType::get(context); |
| 732 | auto moduleOp = op->getParentOfType<ModuleOp>(); |
| 733 | auto mpiTraits = MPIImplTraits::get(moduleOp&: moduleOp); |
| 734 | auto [sendPtr, sendSize] = |
| 735 | getRawPtrAndSize(loc, rewriter, adaptor.getSendbuf(), elemType); |
| 736 | auto [recvPtr, recvSize] = |
| 737 | getRawPtrAndSize(loc, rewriter, adaptor.getRecvbuf(), elemType); |
| 738 | |
| 739 | // If input and output are the same, request in-place operation. |
| 740 | if (adaptor.getSendbuf() == adaptor.getRecvbuf()) { |
| 741 | sendPtr = rewriter.create<LLVM::ConstantOp>( |
| 742 | loc, i64, reinterpret_cast<int64_t>(mpiTraits->getInPlace())); |
| 743 | sendPtr = rewriter.create<LLVM::IntToPtrOp>(loc, ptrType, sendPtr); |
| 744 | } |
| 745 | |
| 746 | Value dataType = mpiTraits->getDataType(loc, rewriter, elemType); |
| 747 | Value mpiOp = mpiTraits->getMPIOp(loc, rewriter, op.getOp()); |
| 748 | Value commWorld = mpiTraits->castComm(loc, rewriter, adaptor.getComm()); |
| 749 | |
| 750 | // 'int MPI_Allreduce(const void *sendbuf, void *recvbuf, int count, |
| 751 | // MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)' |
| 752 | auto funcType = LLVM::LLVMFunctionType::get( |
| 753 | i32, {ptrType, ptrType, i32, dataType.getType(), mpiOp.getType(), |
| 754 | commWorld.getType()}); |
| 755 | // get or create function declaration: |
| 756 | LLVM::LLVMFuncOp funcDecl = |
| 757 | getOrDefineFunction(moduleOp, loc, rewriter, "MPI_Allreduce" , funcType); |
| 758 | |
| 759 | // replace op with function call |
| 760 | auto funcCall = rewriter.create<LLVM::CallOp>( |
| 761 | loc, funcDecl, |
| 762 | ValueRange{sendPtr, recvPtr, sendSize, dataType, mpiOp, commWorld}); |
| 763 | |
| 764 | if (op.getRetval()) |
| 765 | rewriter.replaceOp(op, funcCall.getResult()); |
| 766 | else |
| 767 | rewriter.eraseOp(op: op); |
| 768 | |
| 769 | return success(); |
| 770 | } |
| 771 | }; |
| 772 | |
| 773 | //===----------------------------------------------------------------------===// |
| 774 | // ConvertToLLVMPatternInterface implementation |
| 775 | //===----------------------------------------------------------------------===// |
| 776 | |
| 777 | /// Implement the interface to convert Func to LLVM. |
| 778 | struct FuncToLLVMDialectInterface : public ConvertToLLVMPatternInterface { |
| 779 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
| 780 | /// Hook for derived dialect interface to provide conversion patterns |
| 781 | /// and mark dialect legal for the conversion target. |
| 782 | void populateConvertToLLVMConversionPatterns( |
| 783 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
| 784 | RewritePatternSet &patterns) const final { |
| 785 | mpi::populateMPIToLLVMConversionPatterns(converter&: typeConverter, patterns); |
| 786 | } |
| 787 | }; |
| 788 | } // namespace |
| 789 | |
| 790 | //===----------------------------------------------------------------------===// |
| 791 | // Pattern Population |
| 792 | //===----------------------------------------------------------------------===// |
| 793 | |
| 794 | void mpi::populateMPIToLLVMConversionPatterns(LLVMTypeConverter &converter, |
| 795 | RewritePatternSet &patterns) { |
| 796 | // Using i64 as a portable, intermediate type for !mpi.comm. |
| 797 | // It would be nicer to somehow get the right type directly, but TLDI is not |
| 798 | // available here. |
| 799 | converter.addConversion(callback: [](mpi::CommType type) { |
| 800 | return IntegerType::get(type.getContext(), 64); |
| 801 | }); |
| 802 | patterns.add<CommRankOpLowering, CommSplitOpLowering, CommWorldOpLowering, |
| 803 | FinalizeOpLowering, InitOpLowering, SendOpLowering, |
| 804 | RecvOpLowering, AllReduceOpLowering>(arg&: converter); |
| 805 | } |
| 806 | |
| 807 | void mpi::registerConvertMPIToLLVMInterface(DialectRegistry ®istry) { |
| 808 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, mpi::MPIDialect *dialect) { |
| 809 | dialect->addInterfaces<FuncToLLVMDialectInterface>(); |
| 810 | }); |
| 811 | } |
| 812 | |