//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Mesh communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mesh;

namespace {
/// Converts a mixed list of static (int64_t) and dynamic (Value) integers
/// into a vector of Values of the provided type.
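/// For example, statics = {2, ShapedType::kDynamic, 4} together with one
/// dynamic Value %d yields {constant 2, %d, constant 4}.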
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           llvm::ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
    }
  }
  return values;
}

/// Create operations converting a linear index to a multi-dimensional index.
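/// For example, with dimensions (2, 3, 4) the linear index 17 maps to the
/// multi-index (1, 1, 1): 17 % 4 = 1, 17 / 4 = 4; 4 % 3 = 1, 4 / 3 = 1;
/// 1 % 2 = 1.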
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
    if (i > 0)
      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}

/// Create operations converting a multi-dimensional index to a linear index.
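/// This is the inverse of linearToMultiIndex: with dimensions (2, 3, 4) the
/// multi-index (1, 1, 1) linearizes to 1 * 1 + 1 * 4 + 1 * 12 = 17.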
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                         ValueRange dimensions) {

  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
  }

  return linearIndex;
}

/// Replace GetShardingOp with related/dependent ShardingOp.
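/// For example, given IR of the form (schematically)
///   %sharding = mesh.sharding @mesh split_axes = [[0]] ...
///   %sharded = mesh.shard %arg0 to %sharding ...
///   %s = mesh.get_sharding %sharded ...
/// all uses of %s are replaced with %sharding.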
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};

/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
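/// For example, split_axes = [[0], [], [2, 3]] yields a 3x2 split-axes
/// tensor [[0, -1], [-1, -1], [2, 3]]; rows of shorter split groups are
/// padded with -1.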
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create an empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // Explicitly write values into the tensor row by row.
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs collectiveProcessGroupSize (which needs the
    // MeshOp).
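    // For example, two non-empty split groups spanning 2 and 3 shards
    // (maxSplitSize 3) produce a 2x4 tensor: the first row holds the two
    // shard offsets plus the total size followed by one trailing fill
    // element, the second row holds the three shard offsets plus the total
    // size.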
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value fillVal = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, fillVal, resOffsets)
              .getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // Return a tuple of tensors as defined by the type converter.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};

struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Currently converts its linear index to a multi-dimensional index.

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto meshOp = getMesh(op, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // Optionally extract a subset of mesh axes.
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};

class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0

public:
  using OpConversionPattern::OpConversionPattern;

  // Constructor accepting worldRank.
  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
                              MLIRContext *context, int64_t worldRank = -1)
      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    if (worldRank >= 0) { // if the rank in MPI_COMM_WORLD is known -> use it
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
      return success();
    }

    // Otherwise create an mpi::CommRankOp.
    auto ctx = op.getContext();
    Value commWorld =
        rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
    auto rank =
        rewriter
            .create<mpi::CommRankOp>(
                loc,
                TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                commWorld)
            .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbor indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension.
    // Assigns -1 if a neighbor is out of bounds.
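    // For example, on a 1d mesh of size 4, device 0 gets (-1, 1) and
    // device 3 gets (2, -1) as its (down, up) neighbor indices.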

    auto axes = adaptor.getSplitAxes();
    // For now only single-axis sharding is supported.
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    Value atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};

struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape
    // is computed by dividing the dimension size by the number of shards in
    // that dimension (which is given by the size of the mesh axes provided
    // in split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the
    // next shard.
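    // For example, a dimension of size 10 split across 4 shards yields the
    // local sizes (2, 2, 3, 3): 10 / 4 = 2 and 10 % 4 = 2, so the shards
    // with index >= 4 - 2 = 2 get one extra element.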

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // Type conversion should be 1:1 for ints.
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // Same for device.
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the MeshOp; the mesh shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(sharding, symbolTableCollection);
    // For now we only support static mesh shapes.
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic, which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
    }

    // With a static mesh shape the sizes of the split axes are known.
    // Hence the start/pos for each split axis in shardDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardDimsOffsets.
        Value posVal =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the mesh axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
          Value nextOff =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          //   sz = dim / numShards
          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
          auto cond = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size
      } // else no sharding if split axis is empty or no split axis
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d`, and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the
    // halos in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and
    // static shapes, OpFoldResults are used whenever possible.
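    // For example, on a 2d array the dim-1 halos are exchanged first,
    // excluding the dim-0 halo rows; the subsequent dim-0 exchange then
    // spans entire rows, including the freshly received dim-1 halos, which
    // fills the corner regions.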

    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // No halos -> nothing to do.
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(
                   cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a (writable) memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          rewriter.create<bufferization::ToBufferOp>(loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto mesh = adaptor.getMesh();
    auto meshOp = getMesh(op, symbolTableCollection);
    // Subviews need Index values.
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz =
            rewriter
                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
                .getResult();
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // We need the actual shape to compute offsets and sizes.
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for lower dims start after their down halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // Prepare shape and offsets of highest dim's halo exchange.
        Value _haloSz = rewriter.create<arith::AddIOp>(
            loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // The halo shape of lower dims excludes the halos.
        dimSizes[i] =
            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);

    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
            .getResult();
    // Traverse all split axes from high to low dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split.
      auto tmp = rewriter
                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
                                                       splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[0]),
                               rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = rewriter.create<arith::SubIOp>(
          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = rewriter.create<mpi::CommWorldOp>(
          loc, mpi::CommType::get(op->getContext()));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use
      // a red-black pattern or MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive.
        // Processes on the mesh borders have only one neighbor.
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, to, zero);
        auto buffer = rewriter.create<memref::AllocOp>(
            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
        // If there is a neighbor: copy halo data from array to buffer and
        // send.
        rewriter.create<scf::IfOp>(
            loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, subview, buffer);
              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
                                          commWorld);
              builder.create<scf::YieldOp>(loc);
            });
        // If there is a neighbor: receive halo data into buffer and copy it
        // to array.
        rewriter.create<scf::IfOp>(
            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
                                          commWorld);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, buffer, subview);
              builder.create<scf::YieldOp>(loc);
            });
        rewriter.create<memref::DeallocOp>(loc, buffer);
        offsets[dim] = orgOffset;
      };

      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(
                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sgt, haloSz, zero);
        rewriter.create<scf::IfOp>(loc, hasSize,
                                   [&](OpBuilder &builder, Location loc) {
                                     genSendRecv(upOrDown > 0);
                                     builder.create<scf::YieldOp>(loc);
                                   });
      };

      doSendRecv(0);
      doSendRecv(1);

      // The shape for lower dims includes higher dims' halos.
      dimSizes[dim] = shape[dim];
      // -> the offset for higher dims is always 0.
      offsets[dim] = rewriter.getIndexAttr(0);
      // On to the next halo.
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
                                 loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};

struct ConvertMeshToMPIPass
    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    int64_t worldRank = -1;
    // Try to get the DLTI attribute for MPI:comm_world_rank.
    // If found, set worldRank to the value of the attribute.
    {
      auto dltiAttr =
          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
      if (succeeded(dltiAttr)) {
        if (!isa<IntegerAttr>(dltiAttr.value())) {
          getOperation()->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
          return signalPassFailure();
        }
        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
      }
    }

    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert mesh::ShardingType,
    // mostly for use in return operations.
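    // E.g. a function returning a !mesh.sharding is rewritten to return the
    // three component tensors produced by ConvertShardingOp instead.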
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // Convert mesh::ShardingType to a tuple of RankedTensorTypes.
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
    typeConverter.addTargetMaterialization(
        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with the elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No mesh dialect should be left after conversion...
    target.addIllegalDialect<mesh::MeshDialect>();
    // ...except the global MeshOp.
    target.addLegalOp<mesh::MeshOp>();
    // Allow all the stuff that our patterns will convert to.
    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect,
                           linalg::LinalgDialect, memref::MemRefDialect>();
    // Make sure the function signature, calls etc. are legal.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
    // ConvertProcessLinearIndexOp accepts an optional worldRank.
    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};

} // namespace