| 1 | //===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file implements a translation of Mesh communication ops to MPI ops.
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Conversion/MeshToMPI/MeshToMPI.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| 18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 19 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| 20 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 21 | #include "mlir/Dialect/MPI/IR/MPI.h" |
| 22 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 23 | #include "mlir/Dialect/Mesh/IR/MeshDialect.h" |
| 24 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
| 25 | #include "mlir/Dialect/Mesh/Transforms/Simplifications.h" |
| 26 | #include "mlir/Dialect/Mesh/Transforms/Transforms.h" |
| 27 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 28 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 29 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 30 | #include "mlir/IR/Builders.h" |
| 31 | #include "mlir/IR/BuiltinAttributes.h" |
| 32 | #include "mlir/IR/BuiltinTypes.h" |
| 33 | #include "mlir/IR/PatternMatch.h" |
| 34 | #include "mlir/IR/SymbolTable.h" |
| 35 | #include "mlir/Transforms/DialectConversion.h" |
| 36 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 37 | |
| 38 | #define DEBUG_TYPE "mesh-to-mpi" |
| 39 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 40 | |
| 41 | namespace mlir { |
| 42 | #define GEN_PASS_DEF_CONVERTMESHTOMPIPASS |
| 43 | #include "mlir/Conversion/Passes.h.inc" |
| 44 | } // namespace mlir |
| 45 | |
| 46 | using namespace mlir; |
| 47 | using namespace mesh; |
| 48 | |
| 49 | namespace { |
| 50 | /// Converts a vector of OpFoldResults (ints) into vector of Values of the |
| 51 | /// provided type. |
| 52 | static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc, |
| 53 | llvm::ArrayRef<int64_t> statics, |
| 54 | ValueRange dynamics, |
| 55 | Type type = Type()) { |
| 56 | SmallVector<Value> values; |
| 57 | auto dyn = dynamics.begin(); |
| 58 | Type i64 = b.getI64Type(); |
| 59 | if (!type) |
| 60 | type = i64; |
| 61 | assert((i64 == type || b.getIndexType() == type) && |
| 62 | "expected an i64 or an intex type" ); |
| 63 | for (auto s : statics) { |
| 64 | if (s == ShapedType::kDynamic) { |
| 65 | values.emplace_back(Args: *(dyn++)); |
| 66 | } else { |
| 67 | TypedAttr val = type == i64 ? b.getI64IntegerAttr(value: s) : b.getIndexAttr(value: s); |
| 68 | values.emplace_back(Args: b.create<arith::ConstantOp>(location: loc, args&: type, args&: val)); |
| 69 | } |
| 70 | } |
| 71 | return values; |
| 72 | } |
| 73 | |
| 74 | /// Create operations converting a linear index to a multi-dimensional index. |
| 75 | static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, |
| 76 | Value linearIndex, |
| 77 | ValueRange dimensions) { |
| 78 | int n = dimensions.size(); |
| 79 | SmallVector<Value> multiIndex(n); |
| 80 | |
| 81 | for (int i = n - 1; i >= 0; --i) { |
| 82 | multiIndex[i] = b.create<arith::RemSIOp>(location: loc, args&: linearIndex, args: dimensions[i]); |
| 83 | if (i > 0) |
| 84 | linearIndex = b.create<arith::DivSIOp>(location: loc, args&: linearIndex, args: dimensions[i]); |
| 85 | } |
| 86 | |
| 87 | return multiIndex; |
| 88 | } |
| 89 | |
| 90 | /// Create operations converting a multi-dimensional index to a linear index. |
| 91 | Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, |
| 92 | ValueRange dimensions) { |
| 93 | |
| 94 | Value linearIndex = b.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 95 | Value stride = b.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 96 | |
| 97 | for (int i = multiIndex.size() - 1; i >= 0; --i) { |
| 98 | Value off = b.create<arith::MulIOp>(location: loc, args: multiIndex[i], args&: stride); |
| 99 | linearIndex = b.create<arith::AddIOp>(location: loc, args&: linearIndex, args&: off); |
| 100 | stride = b.create<arith::MulIOp>(location: loc, args&: stride, args: dimensions[i]); |
| 101 | } |
| 102 | |
| 103 | return linearIndex; |
| 104 | } |
| 105 | |
| 106 | /// Replace GetShardingOp with related/dependent ShardingOp. |
| 107 | struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> { |
| 108 | using OpConversionPattern::OpConversionPattern; |
| 109 | |
| 110 | LogicalResult |
| 111 | matchAndRewrite(GetShardingOp op, OpAdaptor adaptor, |
| 112 | ConversionPatternRewriter &rewriter) const override { |
| 113 | auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>(); |
| 114 | if (!shardOp) |
| 115 | return failure(); |
| 116 | auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>(); |
| 117 | if (!shardingOp) |
| 118 | return failure(); |
| 119 | |
| 120 | rewriter.replaceOp(op, newValues: shardingOp.getResult()); |
| 121 | return success(); |
| 122 | } |
| 123 | }; |
| 124 | |
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    // The 2d split-axes tensor must be wide enough for the largest group.
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // explicitly write values into tensor row by row
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      // Count dimensions that are actually split; used to size the halo and
      // offsets tensors below.
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    // No halo sizes -> empty 0x0 tensor as the placeholder component.
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create Tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs using collectiveProcessGroupSize (which needs the
    // MeshOp)
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      // No offsets given -> empty 0x0 placeholder tensor.
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      // Row width is the largest number of shards over any split dimension.
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      // NOTE(review): despite its name, `zero` holds ShapedType::kDynamic,
      // used here as the padding sentinel for unused trailing slots, while
      // the comment above says -1 — confirm which sentinel is intended.
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      // `curr` walks the flat offsets list row by row.
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // return a tuple of tensors as defined by type converter
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    // Cast each component to the converter-mandated tensor type before
    // packing them into the tuple.
    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};
| 251 | |
| 252 | struct ConvertProcessMultiIndexOp |
| 253 | : public OpConversionPattern<ProcessMultiIndexOp> { |
| 254 | using OpConversionPattern::OpConversionPattern; |
| 255 | |
| 256 | LogicalResult |
| 257 | matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor, |
| 258 | ConversionPatternRewriter &rewriter) const override { |
| 259 | |
| 260 | // Currently converts its linear index to a multi-dimensional index. |
| 261 | |
| 262 | SymbolTableCollection symbolTableCollection; |
| 263 | Location loc = op.getLoc(); |
| 264 | auto meshOp = getMesh(op, symbolTableCollection); |
| 265 | // For now we only support static mesh shapes |
| 266 | if (ShapedType::isDynamicShape(dSizes: meshOp.getShape())) |
| 267 | return failure(); |
| 268 | |
| 269 | SmallVector<Value> dims; |
| 270 | llvm::transform( |
| 271 | Range: meshOp.getShape(), d_first: std::back_inserter(x&: dims), F: [&](int64_t i) { |
| 272 | return rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i).getResult(); |
| 273 | }); |
| 274 | Value rank = rewriter.create<ProcessLinearIndexOp>(location: op.getLoc(), args&: meshOp); |
| 275 | auto mIdx = linearToMultiIndex(loc, b: rewriter, linearIndex: rank, dimensions: dims); |
| 276 | |
| 277 | // optionally extract subset of mesh axes |
| 278 | auto axes = adaptor.getAxes(); |
| 279 | if (!axes.empty()) { |
| 280 | SmallVector<Value> subIndex; |
| 281 | for (auto axis : axes) { |
| 282 | subIndex.emplace_back(Args&: mIdx[axis]); |
| 283 | } |
| 284 | mIdx = std::move(subIndex); |
| 285 | } |
| 286 | |
| 287 | rewriter.replaceOp(op, newValues: mIdx); |
| 288 | return success(); |
| 289 | } |
| 290 | }; |
| 291 | |
| 292 | class ConvertProcessLinearIndexOp |
| 293 | : public OpConversionPattern<ProcessLinearIndexOp> { |
| 294 | |
| 295 | public: |
| 296 | using OpConversionPattern::OpConversionPattern; |
| 297 | |
| 298 | LogicalResult |
| 299 | matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor, |
| 300 | ConversionPatternRewriter &rewriter) const override { |
| 301 | // Create mpi::CommRankOp |
| 302 | Location loc = op.getLoc(); |
| 303 | auto ctx = op.getContext(); |
| 304 | Value commWorld = |
| 305 | rewriter.create<mpi::CommWorldOp>(location: loc, args: mpi::CommType::get(ctx)); |
| 306 | auto rank = |
| 307 | rewriter |
| 308 | .create<mpi::CommRankOp>( |
| 309 | location: loc, |
| 310 | args: TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, |
| 311 | args&: commWorld) |
| 312 | .getRank(); |
| 313 | rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, args: rewriter.getIndexType(), |
| 314 | args&: rank); |
| 315 | return success(); |
| 316 | } |
| 317 | }; |
| 318 | |
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors indices along a split axis by simply
    // adding/subtracting 1 to the current index in that dimension.
    // Assigns -1 if neighbor is out of bounds.

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    // This process's multi-dimensional index; orgIdx is its coordinate on
    // the (single) split axis.
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    // Materialize the static mesh extents as index constants.
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    // "Down" neighbor: coordinate - 1 on the split axis, or -1 when already
    // at the lower border (orgIdx <= 0).
    Value atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          // NOTE(review): the ops below are created via the outer `rewriter`
          // (not the region `builder`), so they are inserted before the
          // scf.if rather than inside this branch; values defined outside the
          // region may be captured, so this is legal — confirm the hoisting
          // is intended.
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    // "Up" neighbor: coordinate + 1 on the split axis, or -1 when at the
    // upper border (orgIdx >= dimSz - 1).
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          // NOTE(review): same rewriter-vs-builder hoisting as above.
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    // Results: linearized {down, up} neighbor ranks.
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};
| 384 | |
struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The sharding must be statically traceable to its defining ShardingOp.
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      // NOTE(review): "SharingOp" in this diagnostic looks like a typo for
      // "ShardingOp" (runtime string; left unchanged here).
      return op->emitError()
             << "Expected SharingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the mesh axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(sharding, symbolTableCollection);
    // For now we only support static mesh shapes
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
    }

    // With static mesh shape the sizes of the split axes are known.
    // Hence the start/pos for each split axes in shardDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split Axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardDimsOffsets.
        Value posVal =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the mesh axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
          Value nextOff =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          // sz = dim / numShards
          //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
          auto cond = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        // Advance to the next dimension's row in shardedDimsOffs.
        pos += numShards + 1; // add one for the total size.
      } // else no sharding if split axis is empty or no split axis
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};
| 521 | |
| 522 | static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) { |
| 523 | auto ctx = kind.getContext(); |
| 524 | auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) { |
| 525 | return mpi::MPI_ReductionOpEnumAttr::get(context: ctx, val: redOp); |
| 526 | }; |
| 527 | |
| 528 | switch (kind.getValue()) { |
| 529 | case ReductionKind::Sum: |
| 530 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM); |
| 531 | case ReductionKind::Product: |
| 532 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD); |
| 533 | case ReductionKind::Min: |
| 534 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN); |
| 535 | case ReductionKind::Max: |
| 536 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX); |
| 537 | case ReductionKind::BitwiseAnd: |
| 538 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND); |
| 539 | case ReductionKind::BitwiseOr: |
| 540 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR); |
| 541 | case ReductionKind::BitwiseXor: |
| 542 | return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR); |
| 543 | default: |
| 544 | llvm_unreachable("Unknown/unsupported reduction kind" ); |
| 545 | } |
| 546 | } |
| 547 | |
| 548 | struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { |
| 549 | using OpConversionPattern::OpConversionPattern; |
| 550 | |
| 551 | LogicalResult |
| 552 | matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, |
| 553 | ConversionPatternRewriter &rewriter) const override { |
| 554 | SymbolTableCollection symbolTableCollection; |
| 555 | auto mesh = adaptor.getMesh(); |
| 556 | mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); |
| 557 | if (!meshOp) |
| 558 | return op->emitError() << "No mesh found for AllReduceOp" ; |
| 559 | if (ShapedType::isDynamicShape(dSizes: meshOp.getShape())) |
| 560 | return op->emitError() |
| 561 | << "Dynamic mesh shape not supported in AllReduceOp" ; |
| 562 | |
| 563 | ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); |
| 564 | Value input = adaptor.getInput(); |
| 565 | auto inputShape = cast<ShapedType>(Val: input.getType()).getShape(); |
| 566 | |
| 567 | // If the source is a memref, cast it to a tensor. |
| 568 | if (isa<RankedTensorType>(Val: input.getType())) { |
| 569 | auto memrefType = MemRefType::get( |
| 570 | shape: inputShape, elementType: cast<ShapedType>(Val: input.getType()).getElementType()); |
| 571 | input = iBuilder.create<bufferization::ToBufferOp>(args&: memrefType, args&: input); |
| 572 | } |
| 573 | MemRefType inType = cast<MemRefType>(Val: input.getType()); |
| 574 | |
| 575 | // Get the actual shape to allocate the buffer. |
| 576 | SmallVector<OpFoldResult> shape(inType.getRank()); |
| 577 | for (auto i = 0; i < inType.getRank(); ++i) { |
| 578 | auto s = inputShape[i]; |
| 579 | if (ShapedType::isDynamic(dValue: s)) |
| 580 | shape[i] = iBuilder.create<memref::DimOp>(args&: input, args&: s).getResult(); |
| 581 | else |
| 582 | shape[i] = iBuilder.getIndexAttr(value: s); |
| 583 | } |
| 584 | |
| 585 | // Allocate buffer and copy input to buffer. |
| 586 | Value buffer = iBuilder.create<memref::AllocOp>( |
| 587 | args&: shape, args: cast<ShapedType>(Val: op.getType()).getElementType()); |
| 588 | iBuilder.create<linalg::CopyOp>(args&: input, args&: buffer); |
| 589 | |
| 590 | // Get an MPI_Comm_split for the AllReduce operation. |
| 591 | // The color is the linear index of the process in the mesh along the |
| 592 | // non-reduced axes. The key is the linear index of the process in the mesh |
| 593 | // along the reduced axes. |
| 594 | SmallVector<Type> indexResultTypes(meshOp.getShape().size(), |
| 595 | iBuilder.getIndexType()); |
| 596 | SmallVector<Value> myMultiIndex = |
| 597 | iBuilder.create<ProcessMultiIndexOp>(args&: indexResultTypes, args&: mesh) |
| 598 | .getResult(); |
| 599 | Value zero = iBuilder.create<arith::ConstantIndexOp>(args: 0); |
| 600 | SmallVector<Value> multiKey(myMultiIndex.size(), zero); |
| 601 | |
| 602 | auto redAxes = adaptor.getMeshAxes(); |
| 603 | for (auto axis : redAxes) { |
| 604 | multiKey[axis] = myMultiIndex[axis]; |
| 605 | myMultiIndex[axis] = zero; |
| 606 | } |
| 607 | |
| 608 | Value color = |
| 609 | createProcessLinearIndex(mesh, processInGroupMultiIndex: myMultiIndex, meshAxes: redAxes, builder&: iBuilder); |
| 610 | color = iBuilder.create<arith::IndexCastOp>(args: iBuilder.getI32Type(), args&: color); |
| 611 | Value key = createProcessLinearIndex(mesh, processInGroupMultiIndex: multiKey, meshAxes: redAxes, builder&: iBuilder); |
| 612 | key = iBuilder.create<arith::IndexCastOp>(args: iBuilder.getI32Type(), args&: key); |
| 613 | |
| 614 | // Finally split the communicator |
| 615 | auto commType = mpi::CommType::get(ctx: op->getContext()); |
| 616 | Value commWorld = iBuilder.create<mpi::CommWorldOp>(args&: commType); |
| 617 | auto comm = |
| 618 | iBuilder.create<mpi::CommSplitOp>(args&: commType, args&: commWorld, args&: color, args&: key) |
| 619 | .getNewcomm(); |
| 620 | |
| 621 | Value buffer1d = buffer; |
| 622 | // Collapse shape to 1d if needed |
| 623 | if (inType.getRank() > 1) { |
| 624 | ReassociationIndices reassociation(inType.getRank()); |
| 625 | std::iota(first: reassociation.begin(), last: reassociation.end(), value: 0); |
| 626 | buffer1d = iBuilder.create<memref::CollapseShapeOp>( |
| 627 | args&: buffer, args: ArrayRef<ReassociationIndices>(reassociation)); |
| 628 | } |
| 629 | |
| 630 | // Create the MPI AllReduce operation. |
| 631 | iBuilder.create<mpi::AllReduceOp>( |
| 632 | args: TypeRange(), args&: buffer1d, args&: buffer1d, |
| 633 | args: getMPIReductionOp(kind: adaptor.getReductionAttr()), args&: comm); |
| 634 | |
| 635 | // If the destination is a memref, cast it to a tensor |
| 636 | if (isa<RankedTensorType>(Val: op.getType())) |
| 637 | buffer = iBuilder.create<bufferization::ToTensorOp>(args: op.getType(), args&: buffer, |
| 638 | args: true); |
| 639 | |
| 640 | rewriter.replaceOp(op, newValues: buffer); |
| 641 | return success(); |
| 642 | } |
| 643 | }; |
| 644 | |
| 645 | struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { |
| 646 | using OpConversionPattern::OpConversionPattern; |
| 647 | |
| 648 | LogicalResult |
| 649 | matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor, |
| 650 | ConversionPatternRewriter &rewriter) const override { |
| 651 | |
| 652 | // The input/output memref is assumed to be in C memory order. |
| 653 | // Halos are exchanged as 2 blocks per dimension (one for each side: down |
| 654 | // and up). For each haloed dimension `d`, the exchanged blocks are |
| 655 | // expressed as multi-dimensional subviews. The subviews include potential |
| 656 | // halos of higher dimensions `dh > d`, no halos for the lower dimensions |
| 657 | // `dl < d` and for dimension `d` the currently exchanged halo only. |
| 658 | // By iterating form higher to lower dimensions this also updates the halos |
| 659 | // in the 'corners'. |
| 660 | // memref.subview is used to read and write the halo data from and to the |
| 661 | // local data. Because subviews and halos can have mixed dynamic and static |
| 662 | // shapes, OpFoldResults are used whenever possible. |
| 663 | |
| 664 | auto haloSizes = getMixedValues(staticValues: adaptor.getStaticHaloSizes(), |
| 665 | dynamicValues: adaptor.getHaloSizes(), b&: rewriter); |
| 666 | if (haloSizes.empty()) { |
| 667 | // no halos -> nothing to do |
| 668 | rewriter.replaceOp(op, newValues: adaptor.getDestination()); |
| 669 | return success(); |
| 670 | } |
| 671 | |
| 672 | SymbolTableCollection symbolTableCollection; |
| 673 | Location loc = op.getLoc(); |
| 674 | |
| 675 | // convert a OpFoldResult into a Value |
| 676 | auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { |
| 677 | if (auto value = dyn_cast<Value>(Val&: v)) |
| 678 | return value; |
| 679 | return rewriter.create<arith::ConstantOp>( |
| 680 | location: loc, args: rewriter.getIndexAttr( |
| 681 | value: cast<IntegerAttr>(Val: cast<Attribute>(Val&: v)).getInt())); |
| 682 | }; |
| 683 | |
| 684 | auto dest = adaptor.getDestination(); |
| 685 | auto dstShape = cast<ShapedType>(Val: dest.getType()).getShape(); |
| 686 | Value array = dest; |
| 687 | if (isa<RankedTensorType>(Val: array.getType())) { |
| 688 | // If the destination is a memref, we need to cast it to a tensor |
| 689 | auto mmemrefType = MemRefType::get( |
| 690 | shape: dstShape, elementType: cast<ShapedType>(Val: array.getType()).getElementType()); |
| 691 | array = |
| 692 | rewriter.create<bufferization::ToBufferOp>(location: loc, args&: mmemrefType, args&: array); |
| 693 | } |
| 694 | auto rank = cast<ShapedType>(Val: array.getType()).getRank(); |
| 695 | auto opSplitAxes = adaptor.getSplitAxes().getAxes(); |
| 696 | auto mesh = adaptor.getMesh(); |
| 697 | auto meshOp = getMesh(op, symbolTableCollection); |
| 698 | // subviews need Index values |
| 699 | for (auto &sz : haloSizes) { |
| 700 | if (auto value = dyn_cast<Value>(Val&: sz)) |
| 701 | sz = |
| 702 | rewriter |
| 703 | .create<arith::IndexCastOp>(location: loc, args: rewriter.getIndexType(), args&: value) |
| 704 | .getResult(); |
| 705 | } |
| 706 | |
| 707 | // most of the offset/size/stride data is the same for all dims |
| 708 | SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0)); |
| 709 | SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1)); |
| 710 | SmallVector<OpFoldResult> shape(rank), dimSizes(rank); |
| 711 | auto currHaloDim = -1; // halo sizes are provided for split dimensions only |
| 712 | // we need the actual shape to compute offsets and sizes |
| 713 | for (auto i = 0; i < rank; ++i) { |
| 714 | auto s = dstShape[i]; |
| 715 | if (ShapedType::isDynamic(dValue: s)) |
| 716 | shape[i] = rewriter.create<memref::DimOp>(location: loc, args&: array, args&: s).getResult(); |
| 717 | else |
| 718 | shape[i] = rewriter.getIndexAttr(value: s); |
| 719 | |
| 720 | if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) { |
| 721 | ++currHaloDim; |
| 722 | // the offsets for lower dim sstarts after their down halo |
| 723 | offsets[i] = haloSizes[currHaloDim * 2]; |
| 724 | |
| 725 | // prepare shape and offsets of highest dim's halo exchange |
| 726 | Value _haloSz = rewriter.create<arith::AddIOp>( |
| 727 | location: loc, args: toValue(haloSizes[currHaloDim * 2]), |
| 728 | args: toValue(haloSizes[currHaloDim * 2 + 1])); |
| 729 | // the halo shape of lower dims exlude the halos |
| 730 | dimSizes[i] = |
| 731 | rewriter.create<arith::SubIOp>(location: loc, args: toValue(shape[i]), args&: _haloSz) |
| 732 | .getResult(); |
| 733 | } else { |
| 734 | dimSizes[i] = shape[i]; |
| 735 | } |
| 736 | } |
| 737 | |
| 738 | auto tagAttr = rewriter.getI32IntegerAttr(value: 91); // we just pick something |
| 739 | auto tag = rewriter.create<arith::ConstantOp>(location: loc, args&: tagAttr); |
| 740 | auto zeroAttr = rewriter.getI32IntegerAttr(value: 0); // for detecting v<0 |
| 741 | auto zero = rewriter.create<arith::ConstantOp>(location: loc, args&: zeroAttr); |
| 742 | |
| 743 | SmallVector<Type> indexResultTypes(meshOp.getShape().size(), |
| 744 | rewriter.getIndexType()); |
| 745 | auto myMultiIndex = |
| 746 | rewriter.create<ProcessMultiIndexOp>(location: loc, args&: indexResultTypes, args&: mesh) |
| 747 | .getResult(); |
| 748 | // traverse all split axes from high to low dim |
| 749 | for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { |
| 750 | auto splitAxes = opSplitAxes[dim]; |
| 751 | if (splitAxes.empty()) |
| 752 | continue; |
| 753 | assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); |
| 754 | // Get the linearized ids of the neighbors (down and up) for the |
| 755 | // given split |
| 756 | auto tmp = rewriter |
| 757 | .create<NeighborsLinearIndicesOp>(location: loc, args&: mesh, args&: myMultiIndex, |
| 758 | args&: splitAxes) |
| 759 | .getResults(); |
| 760 | // MPI operates on i32... |
| 761 | Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>( |
| 762 | location: loc, args: rewriter.getI32Type(), args: tmp[0]), |
| 763 | rewriter.create<arith::IndexCastOp>( |
| 764 | location: loc, args: rewriter.getI32Type(), args: tmp[1])}; |
| 765 | |
| 766 | auto lowerRecvOffset = rewriter.getIndexAttr(value: 0); |
| 767 | auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); |
| 768 | auto upperRecvOffset = rewriter.create<arith::SubIOp>( |
| 769 | location: loc, args: toValue(shape[dim]), args: toValue(haloSizes[currHaloDim * 2 + 1])); |
| 770 | auto upperSendOffset = rewriter.create<arith::SubIOp>( |
| 771 | location: loc, args&: upperRecvOffset, args: toValue(haloSizes[currHaloDim * 2])); |
| 772 | |
| 773 | Value commWorld = rewriter.create<mpi::CommWorldOp>( |
| 774 | location: loc, args: mpi::CommType::get(ctx: op->getContext())); |
| 775 | |
| 776 | // Make sure we send/recv in a way that does not lead to a dead-lock. |
| 777 | // The current approach is by far not optimal, this should be at least |
| 778 | // be a red-black pattern or using MPI_sendrecv. |
| 779 | // Also, buffers should be re-used. |
| 780 | // Still using temporary contiguous buffers for MPI communication... |
| 781 | // Still yielding a "serialized" communication pattern... |
| 782 | auto genSendRecv = [&](bool upperHalo) { |
| 783 | auto orgOffset = offsets[dim]; |
| 784 | dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] |
| 785 | : haloSizes[currHaloDim * 2]; |
| 786 | // Check if we need to send and/or receive |
| 787 | // Processes on the mesh borders have only one neighbor |
| 788 | auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; |
| 789 | auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; |
| 790 | auto hasFrom = rewriter.create<arith::CmpIOp>( |
| 791 | location: loc, args: arith::CmpIPredicate::sge, args&: from, args&: zero); |
| 792 | auto hasTo = rewriter.create<arith::CmpIOp>( |
| 793 | location: loc, args: arith::CmpIPredicate::sge, args&: to, args&: zero); |
| 794 | auto buffer = rewriter.create<memref::AllocOp>( |
| 795 | location: loc, args&: dimSizes, args: cast<ShapedType>(Val: array.getType()).getElementType()); |
| 796 | // if has neighbor: copy halo data from array to buffer and send |
| 797 | rewriter.create<scf::IfOp>( |
| 798 | location: loc, args&: hasTo, args: [&](OpBuilder &builder, Location loc) { |
| 799 | offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) |
| 800 | : OpFoldResult(upperSendOffset); |
| 801 | auto subview = builder.create<memref::SubViewOp>( |
| 802 | location: loc, args&: array, args&: offsets, args&: dimSizes, args&: strides); |
| 803 | builder.create<memref::CopyOp>(location: loc, args&: subview, args&: buffer); |
| 804 | builder.create<mpi::SendOp>(location: loc, args: TypeRange{}, args&: buffer, args&: tag, args&: to, |
| 805 | args&: commWorld); |
| 806 | builder.create<scf::YieldOp>(location: loc); |
| 807 | }); |
| 808 | // if has neighbor: receive halo data into buffer and copy to array |
| 809 | rewriter.create<scf::IfOp>( |
| 810 | location: loc, args&: hasFrom, args: [&](OpBuilder &builder, Location loc) { |
| 811 | offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) |
| 812 | : OpFoldResult(lowerRecvOffset); |
| 813 | builder.create<mpi::RecvOp>(location: loc, args: TypeRange{}, args&: buffer, args&: tag, args&: from, |
| 814 | args&: commWorld); |
| 815 | auto subview = builder.create<memref::SubViewOp>( |
| 816 | location: loc, args&: array, args&: offsets, args&: dimSizes, args&: strides); |
| 817 | builder.create<memref::CopyOp>(location: loc, args&: buffer, args&: subview); |
| 818 | builder.create<scf::YieldOp>(location: loc); |
| 819 | }); |
| 820 | rewriter.create<memref::DeallocOp>(location: loc, args&: buffer); |
| 821 | offsets[dim] = orgOffset; |
| 822 | }; |
| 823 | |
| 824 | auto doSendRecv = [&](int upOrDown) { |
| 825 | OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown]; |
| 826 | Value haloSz = dyn_cast<Value>(Val&: v); |
| 827 | if (!haloSz) |
| 828 | haloSz = rewriter.create<arith::ConstantOp>( |
| 829 | location: loc, args: rewriter.getI32IntegerAttr( |
| 830 | value: cast<IntegerAttr>(Val: cast<Attribute>(Val&: v)).getInt())); |
| 831 | auto hasSize = rewriter.create<arith::CmpIOp>( |
| 832 | location: loc, args: arith::CmpIPredicate::sgt, args&: haloSz, args&: zero); |
| 833 | rewriter.create<scf::IfOp>(location: loc, args&: hasSize, |
| 834 | args: [&](OpBuilder &builder, Location loc) { |
| 835 | genSendRecv(upOrDown > 0); |
| 836 | builder.create<scf::YieldOp>(location: loc); |
| 837 | }); |
| 838 | }; |
| 839 | |
| 840 | doSendRecv(0); |
| 841 | doSendRecv(1); |
| 842 | |
| 843 | // the shape for lower dims include higher dims' halos |
| 844 | dimSizes[dim] = shape[dim]; |
| 845 | // -> the offset for higher dims is always 0 |
| 846 | offsets[dim] = rewriter.getIndexAttr(value: 0); |
| 847 | // on to next halo |
| 848 | --currHaloDim; |
| 849 | } |
| 850 | |
| 851 | if (isa<MemRefType>(Val: op.getResult().getType())) { |
| 852 | rewriter.replaceOp(op, newValues: array); |
| 853 | } else { |
| 854 | assert(isa<RankedTensorType>(op.getResult().getType())); |
| 855 | rewriter.replaceOp(op, newOp: rewriter.create<bufferization::ToTensorOp>( |
| 856 | location: loc, args: op.getResult().getType(), args&: array, |
| 857 | /*restrict=*/args: true, /*writable=*/args: true)); |
| 858 | } |
| 859 | return success(); |
| 860 | } |
| 861 | }; |
| 862 | |
/// Pass lowering the Mesh dialect to MPI: runs the conversion patterns
/// registered below, type-converts mesh::ShardingType through function
/// signatures/calls/returns, and finally applies mesh folding patterns.
struct ConvertMeshToMPIPass
    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert mesh::ShardingType,
    // mostly for use in return operations.
    TypeConverter typeConverter;
    // Fallback: every type not handled by a later conversion is kept as-is.
    typeConverter.addConversion(callback: [](Type type) { return type; });

    // Convert mesh::ShardingType to a tuple of RankedTensorTypes:
    // three dynamically-shaped tensors (?x? i16, ?x? i64, ?x? i64).
    typeConverter.addConversion(
        callback: [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(context: type.getContext(), width: 16);
          auto i64 = IntegerType::get(context: type.getContext(), width: 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(Args: RankedTensorType::get(shape: shp, elementType: i16));
          results.emplace_back(Args: RankedTensorType::get(shape: shp, elementType: i64)); // actually ?x2
          results.emplace_back(Args: RankedTensorType::get(shape: shp, elementType: i64));
          return success();
        });

    // To 'extract' components, a UnrealizedConversionCastOp is expected
    // to define the input; its operands become the materialized values.
    typeConverter.addTargetMaterialization(
        callback: [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(Val: inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with elements of the tuple/castOp.
          // Returning an empty vector signals materialization failure.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(Val: oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(Args&: oprnd);
          }
          return results;
        });

    // No mesh dialect ops should be left after conversion...
    target.addIllegalDialect<mesh::MeshDialect>();
    // ...except the global MeshOp and MeshShapeOp, which will get folded
    // later.
    target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
    // Allow all the stuff that our patterns will convert to
    target.addLegalDialect<
        BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
        tensor::TensorDialect, bufferization::BufferizationDialect,
        linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
    // Make sure the function signature, calls etc. are legal
    target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(ty: op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        callback: [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
                 ConvertProcessLinearIndexOp>(arg&: typeConverter, args&: ctxt);

    // Also rewrite function signatures, calls and returns so converted
    // sharding types flow through function boundaries.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, converter: typeConverter);
    populateCallOpTypeConversionPattern(patterns, converter: typeConverter);
    populateReturnOpTypeConversionPattern(patterns, converter: typeConverter);

    // Partial conversion: the result is deliberately ignored; remaining
    // (legal-marked) mesh ops are handled by the folding pass below.
    (void)applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns));

    // Folding patterns cannot be mixed with conversion patterns -> extra pass.
    patterns.clear();
    SymbolTableCollection symbolTableCollection;
    mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
    (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
  }
};
| 949 | |
| 950 | } // namespace |
| 951 | |