//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation of Mesh communication ops to MPI ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "mesh-to-mpi"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mesh;

namespace {
/// Converts a mixed list of static integer values and dynamic Values into a
/// vector of Values of the provided type.
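/// For example (purely illustrative), statics = {2, kDynamic, 4} together
/// with one dynamic value %v yields {constant 2, %v, constant 4}.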
static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
                                           llvm::ArrayRef<int64_t> statics,
                                           ValueRange dynamics,
                                           Type type = Type()) {
  SmallVector<Value> values;
  auto dyn = dynamics.begin();
  Type i64 = b.getI64Type();
  if (!type)
    type = i64;
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
  for (auto s : statics) {
    if (s == ShapedType::kDynamic) {
      values.emplace_back(*(dyn++));
    } else {
      TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s);
      values.emplace_back(b.create<arith::ConstantOp>(loc, type, val));
    }
  }
  return values;
}

/// Create operations converting a linear index to a multi-dimensional index.
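/// For example (illustrative values), with dimensions (2, 3, 4) a linear
/// index of 10 decomposes into the multi-index (0, 2, 2).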
static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
                                             Value linearIndex,
                                             ValueRange dimensions) {
  int n = dimensions.size();
  SmallVector<Value> multiIndex(n);

  for (int i = n - 1; i >= 0; --i) {
    multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]);
    if (i > 0)
      linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]);
  }

  return multiIndex;
}

/// Create operations converting a multi-dimensional index to a linear index.
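/// This is the inverse of linearToMultiIndex: with dimensions (2, 3, 4) the
/// multi-index (0, 2, 2) linearizes back to 0 * 12 + 2 * 4 + 2 = 10
/// (illustrative values).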
Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
                         ValueRange dimensions) {

  Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0);
  Value stride = b.create<arith::ConstantIndexOp>(loc, 1);

  for (int i = multiIndex.size() - 1; i >= 0; --i) {
    Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride);
    linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off);
    stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]);
  }

  return linearIndex;
}

/// Replace GetShardingOp with related/dependent ShardingOp.
struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
    if (!shardOp)
      return failure();
    auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
    if (!shardingOp)
      return failure();

    rewriter.replaceOp(op, shardingOp.getResult());
    return success();
  }
};
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by the type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create an empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // Explicitly write values into the tensor row by row.
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halo sizes, create a 2d tensor with shape {nSplits, 2}
    // and store the halo sizes in it.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold the sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups requires collectiveProcessGroupSize (which needs the
    // MeshOp).
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets = getMixedAsValues(
          rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
          adaptor.getDynamicShardedDimsOffsets());
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // Return a tuple of tensors as defined by the type converter.
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};

struct ConvertProcessMultiIndexOp
    : public OpConversionPattern<ProcessMultiIndexOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Currently converts its linear index to a multi-dimensional index.

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();
    auto meshOp = getMesh(op, symbolTableCollection);
    // For now we only support static mesh shapes
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp);
    auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims);

    // optionally extract subset of mesh axes
    auto axes = adaptor.getAxes();
    if (!axes.empty()) {
      SmallVector<Value> subIndex;
      for (auto axis : axes) {
        subIndex.emplace_back(mIdx[axis]);
      }
      mIdx = std::move(subIndex);
    }

    rewriter.replaceOp(op, mIdx);
    return success();
  }
};

class ConvertProcessLinearIndexOp
    : public OpConversionPattern<ProcessLinearIndexOp> {
  int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0

public:
  using OpConversionPattern::OpConversionPattern;

  // Constructor accepting worldRank
  ConvertProcessLinearIndexOp(const TypeConverter &typeConverter,
                              MLIRContext *context, int64_t worldRank = -1)
      : OpConversionPattern(typeConverter, context), worldRank(worldRank) {}

  LogicalResult
  matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank);
      return success();
    }

    // Otherwise create and use mpi::CommRankOp.
    auto ctx = op.getContext();
    Value commWorld =
        rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx));
    auto rank =
        rewriter
            .create<mpi::CommRankOp>(
                loc,
                TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
                commWorld)
            .getRank();
    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(),
                                                    rank);
    return success();
  }
};

struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors' indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
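    // For example (illustrative values): on a 1-d mesh of size 4, device 0
    // gets neighbors (-1, 1) and device 3 gets (2, -1).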

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    Value atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};

struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected ShardingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the mesh axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.
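    // For example (illustrative values): a dimension of size 10 split across
    // 4 shards yields local sizes 2, 2, 3, 3; the 10 % 4 = 2 extra elements
    // go to the trailing shards.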

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the MeshOp; the mesh shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(sharding, symbolTableCollection);
    // For now we only support static mesh shapes
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic, which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
    }

    // With a static mesh shape the sizes of the split axes are known.
    // Hence the start/pos for each split axis in shardedDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // Create a value from the static position in shardedDimsOffsets.
        Value posVal =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the mesh axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
          Value nextOff =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          //   sz = dim / numShards
          //        + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
          auto cond = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size.
      } // else: no sharding if the split axis is empty or missing
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};

struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // The input/output memref is assumed to be in C memory order.
    // Halos are exchanged as 2 blocks per dimension (one for each side: down
    // and up). For each haloed dimension `d`, the exchanged blocks are
    // expressed as multi-dimensional subviews. The subviews include potential
    // halos of higher dimensions `dh > d`, no halos for the lower dimensions
    // `dl < d` and for dimension `d` the currently exchanged halo only.
    // By iterating from higher to lower dimensions this also updates the halos
    // in the 'corners'.
    // memref.subview is used to read and write the halo data from and to the
    // local data. Because subviews and halos can have mixed dynamic and static
    // shapes, OpFoldResults are used whenever possible.
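    // Layout of a single haloed dimension (sizes illustrative only):
    //   | down halo | owned data | up halo |
    // Each halo is filled with a copy of the adjacent owned data of the
    // neighboring process on that side.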

    auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(),
                                    adaptor.getHaloSizes(), rewriter);
    if (haloSizes.empty()) {
      // no halos -> nothing to do
      rewriter.replaceOp(op, adaptor.getDestination());
      return success();
    }

    SymbolTableCollection symbolTableCollection;
    Location loc = op.getLoc();

    // Convert an OpFoldResult into a Value.
    auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
      if (auto value = dyn_cast<Value>(v))
        return value;
      return rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(
                   cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
    };

    auto dest = adaptor.getDestination();
    auto dstShape = cast<ShapedType>(dest.getType()).getShape();
    Value array = dest;
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, cast it to a memref to operate on.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          rewriter.create<bufferization::ToBufferOp>(loc, memrefType, array);
    }
    auto rank = cast<ShapedType>(array.getType()).getRank();
    auto opSplitAxes = adaptor.getSplitAxes().getAxes();
    auto mesh = adaptor.getMesh();
    auto meshOp = getMesh(op, symbolTableCollection);
    // subviews need Index values
    for (auto &sz : haloSizes) {
      if (auto value = dyn_cast<Value>(sz))
        sz =
            rewriter
                .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value)
                .getResult();
    }

    // Most of the offset/size/stride data is the same for all dims.
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
    auto currHaloDim = -1; // halo sizes are provided for split dimensions only
    // We need the actual shape to compute offsets and sizes.
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);

      if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
        ++currHaloDim;
        // The offsets for lower dims start after their down halo.
        offsets[i] = haloSizes[currHaloDim * 2];

        // Prepare shape and offsets of the highest dim's halo exchange.
        Value _haloSz = rewriter.create<arith::AddIOp>(
            loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // The halo shape of lower dims excludes the halos.
        dimSizes[i] =
            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), _haloSz)
                .getResult();
      } else {
        dimSizes[i] = shape[i];
      }
    }

    auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something
    auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr);
    auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0
    auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);

    SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
                                       rewriter.getIndexType());
    auto myMultiIndex =
        rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh)
            .getResult();
    // Traverse all split axes from high to low dim.
    for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
      auto splitAxes = opSplitAxes[dim];
      if (splitAxes.empty())
        continue;
      assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
      // Get the linearized ids of the neighbors (down and up) for the
      // given split.
      auto tmp = rewriter
                     .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex,
                                                       splitAxes)
                     .getResults();
      // MPI operates on i32...
      Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[0]),
                               rewriter.create<arith::IndexCastOp>(
                                   loc, rewriter.getI32Type(), tmp[1])};

      auto lowerRecvOffset = rewriter.getIndexAttr(0);
      auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
      auto upperRecvOffset = rewriter.create<arith::SubIOp>(
          loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1]));
      auto upperSendOffset = rewriter.create<arith::SubIOp>(
          loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2]));

      Value commWorld = rewriter.create<mpi::CommWorldOp>(
          loc, mpi::CommType::get(op->getContext()));

      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is far from optimal; it should at least be a
      // red-black pattern or use MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
      auto genSendRecv = [&](bool upperHalo) {
        auto orgOffset = offsets[dim];
        dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
                                  : haloSizes[currHaloDim * 2];
        // Check if we need to send and/or receive.
        // Processes on the mesh borders have only one neighbor.
        auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
        auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
        auto hasFrom = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, from, zero);
        auto hasTo = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, to, zero);
        auto buffer = rewriter.create<memref::AllocOp>(
            loc, dimSizes, cast<ShapedType>(array.getType()).getElementType());
        // If there is a neighbor: copy halo data from array to buffer and send.
        rewriter.create<scf::IfOp>(
            loc, hasTo, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
                                       : OpFoldResult(upperSendOffset);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, subview, buffer);
              builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to,
                                          commWorld);
              builder.create<scf::YieldOp>(loc);
            });
        // If there is a neighbor: receive halo data into buffer and copy to
        // array.
        rewriter.create<scf::IfOp>(
            loc, hasFrom, [&](OpBuilder &builder, Location loc) {
              offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
                                       : OpFoldResult(lowerRecvOffset);
              builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from,
                                          commWorld);
              auto subview = builder.create<memref::SubViewOp>(
                  loc, array, offsets, dimSizes, strides);
              builder.create<memref::CopyOp>(loc, buffer, subview);
              builder.create<scf::YieldOp>(loc);
            });
        rewriter.create<memref::DeallocOp>(loc, buffer);
        offsets[dim] = orgOffset;
      };

      auto doSendRecv = [&](int upOrDown) {
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (!haloSz)
          haloSz = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(
                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        auto hasSize = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sgt, haloSz, zero);
        rewriter.create<scf::IfOp>(loc, hasSize,
                                   [&](OpBuilder &builder, Location loc) {
                                     genSendRecv(upOrDown > 0);
                                     builder.create<scf::YieldOp>(loc);
                                   });
      };

      doSendRecv(0);
      doSendRecv(1);

      // The shape for lower dims includes higher dims' halos
      dimSizes[dim] = shape[dim];
      // -> the offset for higher dims is always 0.
      offsets[dim] = rewriter.getIndexAttr(0);
      // on to the next halo
      --currHaloDim;
    }

    if (isa<MemRefType>(op.getResult().getType())) {
      rewriter.replaceOp(op, array);
    } else {
      assert(isa<RankedTensorType>(op.getResult().getType()));
      rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>(
                                 loc, op.getResult().getType(), array,
                                 /*restrict=*/true, /*writable=*/true));
    }
    return success();
  }
};

struct ConvertMeshToMPIPass
    : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
  using Base::Base;

  /// Run the dialect converter on the module.
  void runOnOperation() override {
    int64_t worldRank = -1;
    // Try to get the DLTI attribute for MPI:comm_world_rank.
    // If found, set worldRank to the value of the attribute.
    {
      auto dltiAttr =
          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
      if (succeeded(dltiAttr)) {
        if (!isa<IntegerAttr>(dltiAttr.value())) {
          getOperation()->emitError()
              << "Expected an integer attribute for MPI:comm_world_rank";
          return signalPassFailure();
        }
        worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt();
      }
    }

    auto *ctxt = &getContext();
    RewritePatternSet patterns(ctxt);
    ConversionTarget target(getContext());

    // Define a type converter to convert mesh::ShardingType,
    // mostly for use in return operations.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });

    // Convert mesh::ShardingType to a tuple of RankedTensorTypes.
    typeConverter.addConversion(
        [](ShardingType type,
           SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
          auto i16 = IntegerType::get(type.getContext(), 16);
          auto i64 = IntegerType::get(type.getContext(), 64);
          std::array<int64_t, 2> shp = {ShapedType::kDynamic,
                                        ShapedType::kDynamic};
          results.emplace_back(RankedTensorType::get(shp, i16));
          results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2
          results.emplace_back(RankedTensorType::get(shp, i64));
          return success();
        });

    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
    typeConverter.addTargetMaterialization(
        [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
            Location loc) {
          // Expecting a single input.
          if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType()))
            return SmallVector<Value>();
          auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
          // Expecting an UnrealizedConversionCastOp.
          if (!castOp)
            return SmallVector<Value>();
          // Fill a vector with the elements of the tuple/castOp.
          SmallVector<Value> results;
          for (auto oprnd : castOp.getInputs()) {
            if (!isa<RankedTensorType>(oprnd.getType()))
              return SmallVector<Value>();
            results.emplace_back(oprnd);
          }
          return results;
        });

    // No mesh dialect should be left after conversion...
    target.addIllegalDialect<mesh::MeshDialect>();
    // ...except the global MeshOp.
    target.addLegalOp<mesh::MeshOp>();
    // Allow all dialects that our patterns may create ops from.
    target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect,
                           arith::ArithDialect, tensor::TensorDialect,
                           bufferization::BufferizationDialect,
                           linalg::LinalgDialect, memref::MemRefDialect>();
    // Make sure the function signature, calls etc. are legal.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
        [&](Operation *op) { return typeConverter.isLegal(op); });

    patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
                 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
                 ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt);
    // ConvertProcessLinearIndexOp accepts an optional worldRank.
    patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank);

    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
  }
};

} // namespace
