1//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a translation of Mesh communication ops to MPI ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/MeshToMPI/MeshToMPI.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
20#include "mlir/Dialect/Linalg/IR/Linalg.h"
21#include "mlir/Dialect/MPI/IR/MPI.h"
22#include "mlir/Dialect/MemRef/IR/MemRef.h"
23#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
24#include "mlir/Dialect/Mesh/IR/MeshOps.h"
25#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
26#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
27#include "mlir/Dialect/SCF/IR/SCF.h"
28#include "mlir/Dialect/Tensor/IR/Tensor.h"
29#include "mlir/Dialect/Utils/StaticValueUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/BuiltinAttributes.h"
32#include "mlir/IR/BuiltinTypes.h"
33#include "mlir/IR/PatternMatch.h"
34#include "mlir/IR/SymbolTable.h"
35#include "mlir/Transforms/DialectConversion.h"
36#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37
38#define DEBUG_TYPE "mesh-to-mpi"
39#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
40
41namespace mlir {
42#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS
43#include "mlir/Conversion/Passes.h.inc"
44} // namespace mlir
45
46using namespace mlir;
47using namespace mesh;
48
49namespace {
50/// Converts a vector of OpFoldResults (ints) into vector of Values of the
51/// provided type.
52static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc,
53 llvm::ArrayRef<int64_t> statics,
54 ValueRange dynamics,
55 Type type = Type()) {
56 SmallVector<Value> values;
57 auto dyn = dynamics.begin();
58 Type i64 = b.getI64Type();
59 if (!type)
60 type = i64;
61 assert((i64 == type || b.getIndexType() == type) &&
62 "expected an i64 or an intex type");
63 for (auto s : statics) {
64 if (s == ShapedType::kDynamic) {
65 values.emplace_back(Args: *(dyn++));
66 } else {
67 TypedAttr val = type == i64 ? b.getI64IntegerAttr(value: s) : b.getIndexAttr(value: s);
68 values.emplace_back(Args: b.create<arith::ConstantOp>(location: loc, args&: type, args&: val));
69 }
70 }
71 return values;
72}
73
74/// Create operations converting a linear index to a multi-dimensional index.
75static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b,
76 Value linearIndex,
77 ValueRange dimensions) {
78 int n = dimensions.size();
79 SmallVector<Value> multiIndex(n);
80
81 for (int i = n - 1; i >= 0; --i) {
82 multiIndex[i] = b.create<arith::RemSIOp>(location: loc, args&: linearIndex, args: dimensions[i]);
83 if (i > 0)
84 linearIndex = b.create<arith::DivSIOp>(location: loc, args&: linearIndex, args: dimensions[i]);
85 }
86
87 return multiIndex;
88}
89
90/// Create operations converting a multi-dimensional index to a linear index.
91Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex,
92 ValueRange dimensions) {
93
94 Value linearIndex = b.create<arith::ConstantIndexOp>(location: loc, args: 0);
95 Value stride = b.create<arith::ConstantIndexOp>(location: loc, args: 1);
96
97 for (int i = multiIndex.size() - 1; i >= 0; --i) {
98 Value off = b.create<arith::MulIOp>(location: loc, args: multiIndex[i], args&: stride);
99 linearIndex = b.create<arith::AddIOp>(location: loc, args&: linearIndex, args&: off);
100 stride = b.create<arith::MulIOp>(location: loc, args&: stride, args: dimensions[i]);
101 }
102
103 return linearIndex;
104}
105
106/// Replace GetShardingOp with related/dependent ShardingOp.
107struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> {
108 using OpConversionPattern::OpConversionPattern;
109
110 LogicalResult
111 matchAndRewrite(GetShardingOp op, OpAdaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override {
113 auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>();
114 if (!shardOp)
115 return failure();
116 auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>();
117 if (!shardingOp)
118 return failure();
119
120 rewriter.replaceOp(op, newValues: shardingOp.getResult());
121 return success();
122 }
123};
124
/// Convert a sharding op to a tuple of tensors of its components
/// (SplitAxes, HaloSizes, ShardedDimsOffsets)
/// as defined by type converter.
struct ConvertShardingOp : public OpConversionPattern<ShardingOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardingOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto splitAxes = op.getSplitAxes().getAxes();
    // The widest split-group determines the row width of the split-axes
    // tensor built below.
    int64_t maxNAxes = 0;
    for (auto axes : splitAxes)
      maxNAxes = std::max<int64_t>(maxNAxes, axes.size());

    // To hold the split axes, create empty 2d tensor with shape
    // {splitAxes.size(), max-size-of-split-groups}.
    // Set trailing elements for smaller split-groups to -1.
    Location loc = op.getLoc();
    auto i16 = rewriter.getI16Type();
    auto i64 = rewriter.getI64Type();
    std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()),
                                    maxNAxes};
    Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16);
    auto attr = IntegerAttr::get(i16, -1);
    Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr);
    resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes)
                       .getResult(0);

    // explicitly write values into tensor row by row
    std::array<int64_t, 2> strides = {1, 1};
    int64_t nSplits = 0;
    ValueRange empty = {};
    for (auto [i, axes] : llvm::enumerate(splitAxes)) {
      int64_t size = axes.size();
      if (size > 0)
        ++nSplits;
      std::array<int64_t, 2> offs = {(int64_t)i, 0};
      std::array<int64_t, 2> sizes = {1, size};
      auto tensorType = RankedTensorType::get({size}, i16);
      auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef());
      auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs);
      resSplitAxes = rewriter.create<tensor::InsertSliceOp>(
          loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides);
    }

    // To hold halos sizes, create 2d Tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
    SmallVector<Value> haloSizes =
        getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(),
                         adaptor.getDynamicHaloSizes());
    auto type = RankedTensorType::get({nSplits, 2}, i64);
    // No halo sizes -> use an empty 0x0 placeholder tensor instead.
    Value resHaloSizes =
        haloSizes.empty()
            ? rewriter
                  .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0},
                                           i64)
                  .getResult()
            : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes)
                  .getResult();

    // To hold sharded dims offsets, create Tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to -1. Computing the max size of the
    // split groups needs using collectiveProcessGroupSize (which needs the
    // MeshOp)
    Value resOffsets;
    if (adaptor.getStaticShardedDimsOffsets().empty()) {
      // No offsets given -> emit an empty 0x0 placeholder tensor.
      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{0, 0}, i64);
    } else {
      SymbolTableCollection symbolTableCollection;
      auto meshOp = getMesh(op, symbolTableCollection);
      // Find the widest split group (number of shards) to size the rows.
      int64_t maxSplitSize = 0;
      for (auto axes : splitAxes) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic);
        maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize);
      }
      assert(maxSplitSize);
      ++maxSplitSize; // add one for the total size

      resOffsets = rewriter.create<tensor::EmptyOp>(
          loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64);
      // NOTE(review): despite its name, `zero` holds ShapedType::kDynamic;
      // it is the fill marker for unused trailing row elements — confirm
      // against the consumer of this tensor.
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets =
          rewriter.create<linalg::FillOp>(loc, zero, resOffsets).getResult(0);
      SmallVector<Value> offsets =
          getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(),
                           adaptor.getDynamicShardedDimsOffsets());
      // Write each split group's offsets (plus its total) into its row.
      int64_t curr = 0;
      for (auto [i, axes] : llvm::enumerate(splitAxes)) {
        int64_t splitSize =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize);
        ++splitSize; // add one for the total size
        ArrayRef<Value> values(&offsets[curr], splitSize);
        Value vals = rewriter.create<tensor::FromElementsOp>(loc, values);
        std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
        std::array<int64_t, 2> sizes = {1, splitSize};
        resOffsets = rewriter.create<tensor::InsertSliceOp>(
            loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides);
        curr += splitSize;
      }
    }

    // return a tuple of tensors as defined by type converter
    SmallVector<Type> resTypes;
    if (failed(getTypeConverter()->convertType(op.getResult().getType(),
                                               resTypes)))
      return failure();

    resSplitAxes =
        rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes);
    resHaloSizes =
        rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes);
    resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets);

    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
        op, TupleType::get(op.getContext(), resTypes),
        ValueRange{resSplitAxes, resHaloSizes, resOffsets});

    return success();
  }
};
251
252struct ConvertProcessMultiIndexOp
253 : public OpConversionPattern<ProcessMultiIndexOp> {
254 using OpConversionPattern::OpConversionPattern;
255
256 LogicalResult
257 matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor,
258 ConversionPatternRewriter &rewriter) const override {
259
260 // Currently converts its linear index to a multi-dimensional index.
261
262 SymbolTableCollection symbolTableCollection;
263 Location loc = op.getLoc();
264 auto meshOp = getMesh(op, symbolTableCollection);
265 // For now we only support static mesh shapes
266 if (ShapedType::isDynamicShape(dSizes: meshOp.getShape()))
267 return failure();
268
269 SmallVector<Value> dims;
270 llvm::transform(
271 Range: meshOp.getShape(), d_first: std::back_inserter(x&: dims), F: [&](int64_t i) {
272 return rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i).getResult();
273 });
274 Value rank = rewriter.create<ProcessLinearIndexOp>(location: op.getLoc(), args&: meshOp);
275 auto mIdx = linearToMultiIndex(loc, b: rewriter, linearIndex: rank, dimensions: dims);
276
277 // optionally extract subset of mesh axes
278 auto axes = adaptor.getAxes();
279 if (!axes.empty()) {
280 SmallVector<Value> subIndex;
281 for (auto axis : axes) {
282 subIndex.emplace_back(Args&: mIdx[axis]);
283 }
284 mIdx = std::move(subIndex);
285 }
286
287 rewriter.replaceOp(op, newValues: mIdx);
288 return success();
289 }
290};
291
292class ConvertProcessLinearIndexOp
293 : public OpConversionPattern<ProcessLinearIndexOp> {
294
295public:
296 using OpConversionPattern::OpConversionPattern;
297
298 LogicalResult
299 matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor,
300 ConversionPatternRewriter &rewriter) const override {
301 // Create mpi::CommRankOp
302 Location loc = op.getLoc();
303 auto ctx = op.getContext();
304 Value commWorld =
305 rewriter.create<mpi::CommWorldOp>(location: loc, args: mpi::CommType::get(ctx));
306 auto rank =
307 rewriter
308 .create<mpi::CommRankOp>(
309 location: loc,
310 args: TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()},
311 args&: commWorld)
312 .getRank();
313 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, args: rewriter.getIndexType(),
314 args&: rank);
315 return success();
316 }
317};
318
/// Lower NeighborsLinearIndicesOp to the linearized ids of the two
/// neighboring processes (down and up) along a single split axis.
struct ConvertNeighborsLinearIndicesOp
    : public OpConversionPattern<NeighborsLinearIndicesOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Computes the neighbors indices along a split axis by simply
    // adding/subtracting 1 to the current index in that dimension.
    // Assigns -1 if neighbor is out of bounds.

    auto axes = adaptor.getSplitAxes();
    // For now only single axis sharding is supported
    if (axes.size() != 1)
      return failure();

    Location loc = op.getLoc();
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(op, symbolTableCollection);
    // Current multi-dimensional index of this process in the mesh.
    auto mIdx = adaptor.getDevice();
    auto orgIdx = mIdx[axes[0]];
    // Materialize mesh extents as index constants (needed to linearize).
    SmallVector<Value> dims;
    llvm::transform(
        meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) {
          return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult();
        });
    Value dimSz = dims[axes[0]];
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1);
    // Lower neighbor: -1 if we are at index 0 (or below) in the split axis.
    Value atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sle, orgIdx,
        rewriter.create<arith::ConstantIndexOp>(loc, 0));
    // NOTE(review): the branch bodies create arith ops through `rewriter`
    // (not the region `builder`) — presumably intentional so the pattern
    // driver tracks them; confirm before restructuring.
    auto down = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          // Neighbor below: decrement our coordinate and linearize.
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one)
                  .getResult();
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    // Upper neighbor: -1 if we are at the last index in the split axis.
    atBorder = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sge, orgIdx,
        rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult());
    auto up = rewriter.create<scf::IfOp>(
        loc, atBorder,
        [&](OpBuilder &builder, Location loc) {
          builder.create<scf::YieldOp>(loc, minus1);
        },
        [&](OpBuilder &builder, Location loc) {
          // Neighbor above: increment our coordinate and linearize.
          SmallVector<Value> tmp = mIdx;
          tmp[axes[0]] =
              rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one);
          builder.create<scf::YieldOp>(
              loc, multiToLinearIndex(loc, rewriter, tmp, dims));
        });
    rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)});
    return success();
  }
};
384
/// Lower ShardShapeOp to explicit arithmetic computing the local shard's
/// shape from the global shape, the sharding, and this process's device id.
struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto sharding = op.getSharding().getDefiningOp<ShardingOp>();
    if (!sharding) {
      return op->emitError()
             << "Expected SharingOp as defining op for sharding"
             << " but found " << adaptor.getSharding()[0].getDefiningOp();
    }

    // Compute the sharded shape by applying the sharding to the input shape.
    // If shardedDimsOffsets is not defined in the sharding, the shard shape is
    // computed by dividing the dimension size by the number of shards in that
    // dimension (which is given by the size of the mesh axes provided in
    // split-axes). Odd elements get distributed to trailing shards. If a
    // shardedDimsOffsets is provided, the shard shape is computed by
    // subtracting the offset of the current shard from the offset of the next
    // shard.

    Location loc = op.getLoc();
    Type index = rewriter.getIndexType();

    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and device
    // we have a 1:1 conversion.
    // For simpler access fill a vector with the dynamic dims.
    SmallVector<Value> dynDims, dynDevice;
    for (auto dim : adaptor.getDimsDynamic()) {
      // type conversion should be 1:1 for ints
      dynDims.emplace_back(llvm::getSingleElement(dim));
    }
    // same for device
    for (auto device : adaptor.getDeviceDynamic()) {
      dynDevice.emplace_back(llvm::getSingleElement(device));
    }

    // To keep the code simple, convert dims/device to values when they are
    // attributes. Count on canonicalization to fold static values.
    SmallVector<Value> shape =
        getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index);
    SmallVector<Value> multiIdx =
        getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index);

    // Get the MeshOp, the mesh shape is needed to compute the sharded shape.
    SymbolTableCollection symbolTableCollection;
    auto meshOp = getMesh(sharding, symbolTableCollection);
    // For now we only support static mesh shapes
    if (ShapedType::isDynamicShape(meshOp.getShape()))
      return failure();

    auto splitAxes = sharding.getSplitAxes().getAxes();
    // shardedDimsOffsets are optional and might be Values (not attributes).
    // Also, the shardId might be dynamic which means the position in the
    // shardedDimsOffsets is not statically known. Create a tensor of the
    // shardedDimsOffsets and later extract the offsets for computing the
    // local shard-size.
    Value shardedDimsOffs;
    {
      SmallVector<Value> tmp = getMixedAsValues(
          rewriter, loc, sharding.getStaticShardedDimsOffsets(),
          sharding.getDynamicShardedDimsOffsets(), index);
      if (!tmp.empty())
        shardedDimsOffs = rewriter.create<tensor::FromElementsOp>(
            loc, RankedTensorType::get({(int64_t)tmp.size()}, index), tmp);
    }

    // With static mesh shape the sizes of the split axes are known.
    // Hence the start/pos for each split axes in shardDimsOffsets can be
    // computed statically.
    int64_t pos = 0;
    SmallVector<Value> shardShape;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index));
    Value one =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index));

    // Iterate over the dimensions of the tensor shape, get their split Axes,
    // and compute the sharded shape.
    for (auto [i, dim] : llvm::enumerate(shape)) {
      // Trailing dimensions might not be annotated.
      if (i < splitAxes.size() && !splitAxes[i].empty()) {
        auto axes = splitAxes[i];
        // The current dimension might not be sharded.
        // Create a value from the static position in shardDimsOffsets.
        Value posVal =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos));
        // Get the index of the local shard in the mesh axis.
        Value idx = multiIdx[axes[0]];
        auto numShards =
            collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape());
        if (shardedDimsOffs) {
          // If sharded dims offsets are provided, use them to compute the
          // sharded shape.
          if (axes.size() > 1) {
            return op->emitError() << "Only single axis sharding is "
                                   << "supported for each dimension.";
          }
          idx = rewriter.create<arith::AddIOp>(loc, posVal, idx);
          // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx].
          Value off =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          idx = rewriter.create<arith::AddIOp>(loc, idx, one);
          Value nextOff =
              rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx);
          Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off);
          shardShape.emplace_back(sz);
        } else {
          Value numShardsVal = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIndexAttr(numShards));
          // Compute shard dim size by distributing odd elements to trailing
          // shards:
          // sz = dim / numShards
          //      + (idx >= (numShards - (dim % numShards)) ? 1 : 0)
          Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal);
          Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal);
          sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1);
          auto cond = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::sge, idx, sz1);
          Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero);
          sz = rewriter.create<arith::AddIOp>(loc, sz, odd);
          shardShape.emplace_back(sz);
        }
        pos += numShards + 1; // add one for the total size.
      } // else no sharding if split axis is empty or no split axis
      // If no size was added -> no sharding in this dimension.
      if (shardShape.size() <= i)
        shardShape.emplace_back(dim);
    }
    assert(shardShape.size() == shape.size());
    rewriter.replaceOp(op, shardShape);
    return success();
  }
};
521
522static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) {
523 auto ctx = kind.getContext();
524 auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) {
525 return mpi::MPI_ReductionOpEnumAttr::get(context: ctx, val: redOp);
526 };
527
528 switch (kind.getValue()) {
529 case ReductionKind::Sum:
530 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_SUM);
531 case ReductionKind::Product:
532 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_PROD);
533 case ReductionKind::Min:
534 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MIN);
535 case ReductionKind::Max:
536 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_MAX);
537 case ReductionKind::BitwiseAnd:
538 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BAND);
539 case ReductionKind::BitwiseOr:
540 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BOR);
541 case ReductionKind::BitwiseXor:
542 return getReductionOp(mpi::MPI_ReductionOpEnum::MPI_BXOR);
543 default:
544 llvm_unreachable("Unknown/unsupported reduction kind");
545 }
546}
547
548struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> {
549 using OpConversionPattern::OpConversionPattern;
550
551 LogicalResult
552 matchAndRewrite(AllReduceOp op, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter) const override {
554 SymbolTableCollection symbolTableCollection;
555 auto mesh = adaptor.getMesh();
556 mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection);
557 if (!meshOp)
558 return op->emitError() << "No mesh found for AllReduceOp";
559 if (ShapedType::isDynamicShape(dSizes: meshOp.getShape()))
560 return op->emitError()
561 << "Dynamic mesh shape not supported in AllReduceOp";
562
563 ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter);
564 Value input = adaptor.getInput();
565 auto inputShape = cast<ShapedType>(Val: input.getType()).getShape();
566
567 // If the source is a memref, cast it to a tensor.
568 if (isa<RankedTensorType>(Val: input.getType())) {
569 auto memrefType = MemRefType::get(
570 shape: inputShape, elementType: cast<ShapedType>(Val: input.getType()).getElementType());
571 input = iBuilder.create<bufferization::ToBufferOp>(args&: memrefType, args&: input);
572 }
573 MemRefType inType = cast<MemRefType>(Val: input.getType());
574
575 // Get the actual shape to allocate the buffer.
576 SmallVector<OpFoldResult> shape(inType.getRank());
577 for (auto i = 0; i < inType.getRank(); ++i) {
578 auto s = inputShape[i];
579 if (ShapedType::isDynamic(dValue: s))
580 shape[i] = iBuilder.create<memref::DimOp>(args&: input, args&: s).getResult();
581 else
582 shape[i] = iBuilder.getIndexAttr(value: s);
583 }
584
585 // Allocate buffer and copy input to buffer.
586 Value buffer = iBuilder.create<memref::AllocOp>(
587 args&: shape, args: cast<ShapedType>(Val: op.getType()).getElementType());
588 iBuilder.create<linalg::CopyOp>(args&: input, args&: buffer);
589
590 // Get an MPI_Comm_split for the AllReduce operation.
591 // The color is the linear index of the process in the mesh along the
592 // non-reduced axes. The key is the linear index of the process in the mesh
593 // along the reduced axes.
594 SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
595 iBuilder.getIndexType());
596 SmallVector<Value> myMultiIndex =
597 iBuilder.create<ProcessMultiIndexOp>(args&: indexResultTypes, args&: mesh)
598 .getResult();
599 Value zero = iBuilder.create<arith::ConstantIndexOp>(args: 0);
600 SmallVector<Value> multiKey(myMultiIndex.size(), zero);
601
602 auto redAxes = adaptor.getMeshAxes();
603 for (auto axis : redAxes) {
604 multiKey[axis] = myMultiIndex[axis];
605 myMultiIndex[axis] = zero;
606 }
607
608 Value color =
609 createProcessLinearIndex(mesh, processInGroupMultiIndex: myMultiIndex, meshAxes: redAxes, builder&: iBuilder);
610 color = iBuilder.create<arith::IndexCastOp>(args: iBuilder.getI32Type(), args&: color);
611 Value key = createProcessLinearIndex(mesh, processInGroupMultiIndex: multiKey, meshAxes: redAxes, builder&: iBuilder);
612 key = iBuilder.create<arith::IndexCastOp>(args: iBuilder.getI32Type(), args&: key);
613
614 // Finally split the communicator
615 auto commType = mpi::CommType::get(ctx: op->getContext());
616 Value commWorld = iBuilder.create<mpi::CommWorldOp>(args&: commType);
617 auto comm =
618 iBuilder.create<mpi::CommSplitOp>(args&: commType, args&: commWorld, args&: color, args&: key)
619 .getNewcomm();
620
621 Value buffer1d = buffer;
622 // Collapse shape to 1d if needed
623 if (inType.getRank() > 1) {
624 ReassociationIndices reassociation(inType.getRank());
625 std::iota(first: reassociation.begin(), last: reassociation.end(), value: 0);
626 buffer1d = iBuilder.create<memref::CollapseShapeOp>(
627 args&: buffer, args: ArrayRef<ReassociationIndices>(reassociation));
628 }
629
630 // Create the MPI AllReduce operation.
631 iBuilder.create<mpi::AllReduceOp>(
632 args: TypeRange(), args&: buffer1d, args&: buffer1d,
633 args: getMPIReductionOp(kind: adaptor.getReductionAttr()), args&: comm);
634
635 // If the destination is a memref, cast it to a tensor
636 if (isa<RankedTensorType>(Val: op.getType()))
637 buffer = iBuilder.create<bufferization::ToTensorOp>(args: op.getType(), args&: buffer,
638 args: true);
639
640 rewriter.replaceOp(op, newValues: buffer);
641 return success();
642 }
643};
644
645struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> {
646 using OpConversionPattern::OpConversionPattern;
647
648 LogicalResult
649 matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor,
650 ConversionPatternRewriter &rewriter) const override {
651
652 // The input/output memref is assumed to be in C memory order.
653 // Halos are exchanged as 2 blocks per dimension (one for each side: down
654 // and up). For each haloed dimension `d`, the exchanged blocks are
655 // expressed as multi-dimensional subviews. The subviews include potential
656 // halos of higher dimensions `dh > d`, no halos for the lower dimensions
657 // `dl < d` and for dimension `d` the currently exchanged halo only.
658 // By iterating form higher to lower dimensions this also updates the halos
659 // in the 'corners'.
660 // memref.subview is used to read and write the halo data from and to the
661 // local data. Because subviews and halos can have mixed dynamic and static
662 // shapes, OpFoldResults are used whenever possible.
663
664 auto haloSizes = getMixedValues(staticValues: adaptor.getStaticHaloSizes(),
665 dynamicValues: adaptor.getHaloSizes(), b&: rewriter);
666 if (haloSizes.empty()) {
667 // no halos -> nothing to do
668 rewriter.replaceOp(op, newValues: adaptor.getDestination());
669 return success();
670 }
671
672 SymbolTableCollection symbolTableCollection;
673 Location loc = op.getLoc();
674
675 // convert a OpFoldResult into a Value
676 auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value {
677 if (auto value = dyn_cast<Value>(Val&: v))
678 return value;
679 return rewriter.create<arith::ConstantOp>(
680 location: loc, args: rewriter.getIndexAttr(
681 value: cast<IntegerAttr>(Val: cast<Attribute>(Val&: v)).getInt()));
682 };
683
684 auto dest = adaptor.getDestination();
685 auto dstShape = cast<ShapedType>(Val: dest.getType()).getShape();
686 Value array = dest;
687 if (isa<RankedTensorType>(Val: array.getType())) {
688 // If the destination is a memref, we need to cast it to a tensor
689 auto mmemrefType = MemRefType::get(
690 shape: dstShape, elementType: cast<ShapedType>(Val: array.getType()).getElementType());
691 array =
692 rewriter.create<bufferization::ToBufferOp>(location: loc, args&: mmemrefType, args&: array);
693 }
694 auto rank = cast<ShapedType>(Val: array.getType()).getRank();
695 auto opSplitAxes = adaptor.getSplitAxes().getAxes();
696 auto mesh = adaptor.getMesh();
697 auto meshOp = getMesh(op, symbolTableCollection);
698 // subviews need Index values
699 for (auto &sz : haloSizes) {
700 if (auto value = dyn_cast<Value>(Val&: sz))
701 sz =
702 rewriter
703 .create<arith::IndexCastOp>(location: loc, args: rewriter.getIndexType(), args&: value)
704 .getResult();
705 }
706
707 // most of the offset/size/stride data is the same for all dims
708 SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(value: 0));
709 SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(value: 1));
710 SmallVector<OpFoldResult> shape(rank), dimSizes(rank);
711 auto currHaloDim = -1; // halo sizes are provided for split dimensions only
712 // we need the actual shape to compute offsets and sizes
713 for (auto i = 0; i < rank; ++i) {
714 auto s = dstShape[i];
715 if (ShapedType::isDynamic(dValue: s))
716 shape[i] = rewriter.create<memref::DimOp>(location: loc, args&: array, args&: s).getResult();
717 else
718 shape[i] = rewriter.getIndexAttr(value: s);
719
720 if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) {
721 ++currHaloDim;
722 // the offsets for lower dim sstarts after their down halo
723 offsets[i] = haloSizes[currHaloDim * 2];
724
725 // prepare shape and offsets of highest dim's halo exchange
726 Value _haloSz = rewriter.create<arith::AddIOp>(
727 location: loc, args: toValue(haloSizes[currHaloDim * 2]),
728 args: toValue(haloSizes[currHaloDim * 2 + 1]));
729 // the halo shape of lower dims exlude the halos
730 dimSizes[i] =
731 rewriter.create<arith::SubIOp>(location: loc, args: toValue(shape[i]), args&: _haloSz)
732 .getResult();
733 } else {
734 dimSizes[i] = shape[i];
735 }
736 }
737
738 auto tagAttr = rewriter.getI32IntegerAttr(value: 91); // we just pick something
739 auto tag = rewriter.create<arith::ConstantOp>(location: loc, args&: tagAttr);
740 auto zeroAttr = rewriter.getI32IntegerAttr(value: 0); // for detecting v<0
741 auto zero = rewriter.create<arith::ConstantOp>(location: loc, args&: zeroAttr);
742
743 SmallVector<Type> indexResultTypes(meshOp.getShape().size(),
744 rewriter.getIndexType());
745 auto myMultiIndex =
746 rewriter.create<ProcessMultiIndexOp>(location: loc, args&: indexResultTypes, args&: mesh)
747 .getResult();
748 // traverse all split axes from high to low dim
749 for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) {
750 auto splitAxes = opSplitAxes[dim];
751 if (splitAxes.empty())
752 continue;
753 assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2);
754 // Get the linearized ids of the neighbors (down and up) for the
755 // given split
756 auto tmp = rewriter
757 .create<NeighborsLinearIndicesOp>(location: loc, args&: mesh, args&: myMultiIndex,
758 args&: splitAxes)
759 .getResults();
760 // MPI operates on i32...
761 Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>(
762 location: loc, args: rewriter.getI32Type(), args: tmp[0]),
763 rewriter.create<arith::IndexCastOp>(
764 location: loc, args: rewriter.getI32Type(), args: tmp[1])};
765
766 auto lowerRecvOffset = rewriter.getIndexAttr(value: 0);
767 auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]);
768 auto upperRecvOffset = rewriter.create<arith::SubIOp>(
769 location: loc, args: toValue(shape[dim]), args: toValue(haloSizes[currHaloDim * 2 + 1]));
770 auto upperSendOffset = rewriter.create<arith::SubIOp>(
771 location: loc, args&: upperRecvOffset, args: toValue(haloSizes[currHaloDim * 2]));
772
773 Value commWorld = rewriter.create<mpi::CommWorldOp>(
774 location: loc, args: mpi::CommType::get(ctx: op->getContext()));
775
776 // Make sure we send/recv in a way that does not lead to a dead-lock.
777 // The current approach is by far not optimal, this should be at least
778 // be a red-black pattern or using MPI_sendrecv.
779 // Also, buffers should be re-used.
780 // Still using temporary contiguous buffers for MPI communication...
781 // Still yielding a "serialized" communication pattern...
782 auto genSendRecv = [&](bool upperHalo) {
783 auto orgOffset = offsets[dim];
784 dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1]
785 : haloSizes[currHaloDim * 2];
786 // Check if we need to send and/or receive
787 // Processes on the mesh borders have only one neighbor
788 auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1];
789 auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0];
790 auto hasFrom = rewriter.create<arith::CmpIOp>(
791 location: loc, args: arith::CmpIPredicate::sge, args&: from, args&: zero);
792 auto hasTo = rewriter.create<arith::CmpIOp>(
793 location: loc, args: arith::CmpIPredicate::sge, args&: to, args&: zero);
794 auto buffer = rewriter.create<memref::AllocOp>(
795 location: loc, args&: dimSizes, args: cast<ShapedType>(Val: array.getType()).getElementType());
796 // if has neighbor: copy halo data from array to buffer and send
797 rewriter.create<scf::IfOp>(
798 location: loc, args&: hasTo, args: [&](OpBuilder &builder, Location loc) {
799 offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset)
800 : OpFoldResult(upperSendOffset);
801 auto subview = builder.create<memref::SubViewOp>(
802 location: loc, args&: array, args&: offsets, args&: dimSizes, args&: strides);
803 builder.create<memref::CopyOp>(location: loc, args&: subview, args&: buffer);
804 builder.create<mpi::SendOp>(location: loc, args: TypeRange{}, args&: buffer, args&: tag, args&: to,
805 args&: commWorld);
806 builder.create<scf::YieldOp>(location: loc);
807 });
808 // if has neighbor: receive halo data into buffer and copy to array
809 rewriter.create<scf::IfOp>(
810 location: loc, args&: hasFrom, args: [&](OpBuilder &builder, Location loc) {
811 offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset)
812 : OpFoldResult(lowerRecvOffset);
813 builder.create<mpi::RecvOp>(location: loc, args: TypeRange{}, args&: buffer, args&: tag, args&: from,
814 args&: commWorld);
815 auto subview = builder.create<memref::SubViewOp>(
816 location: loc, args&: array, args&: offsets, args&: dimSizes, args&: strides);
817 builder.create<memref::CopyOp>(location: loc, args&: buffer, args&: subview);
818 builder.create<scf::YieldOp>(location: loc);
819 });
820 rewriter.create<memref::DeallocOp>(location: loc, args&: buffer);
821 offsets[dim] = orgOffset;
822 };
823
824 auto doSendRecv = [&](int upOrDown) {
825 OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
826 Value haloSz = dyn_cast<Value>(Val&: v);
827 if (!haloSz)
828 haloSz = rewriter.create<arith::ConstantOp>(
829 location: loc, args: rewriter.getI32IntegerAttr(
830 value: cast<IntegerAttr>(Val: cast<Attribute>(Val&: v)).getInt()));
831 auto hasSize = rewriter.create<arith::CmpIOp>(
832 location: loc, args: arith::CmpIPredicate::sgt, args&: haloSz, args&: zero);
833 rewriter.create<scf::IfOp>(location: loc, args&: hasSize,
834 args: [&](OpBuilder &builder, Location loc) {
835 genSendRecv(upOrDown > 0);
836 builder.create<scf::YieldOp>(location: loc);
837 });
838 };
839
840 doSendRecv(0);
841 doSendRecv(1);
842
843 // the shape for lower dims include higher dims' halos
844 dimSizes[dim] = shape[dim];
845 // -> the offset for higher dims is always 0
846 offsets[dim] = rewriter.getIndexAttr(value: 0);
847 // on to next halo
848 --currHaloDim;
849 }
850
851 if (isa<MemRefType>(Val: op.getResult().getType())) {
852 rewriter.replaceOp(op, newValues: array);
853 } else {
854 assert(isa<RankedTensorType>(op.getResult().getType()));
855 rewriter.replaceOp(op, newOp: rewriter.create<bufferization::ToTensorOp>(
856 location: loc, args: op.getResult().getType(), args&: array,
857 /*restrict=*/args: true, /*writable=*/args: true));
858 }
859 return success();
860 }
861};
862
863struct ConvertMeshToMPIPass
864 : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> {
865 using Base::Base;
866
867 /// Run the dialect converter on the module.
868 void runOnOperation() override {
869 auto *ctxt = &getContext();
870 RewritePatternSet patterns(ctxt);
871 ConversionTarget target(getContext());
872
873 // Define a type converter to convert mesh::ShardingType,
874 // mostly for use in return operations.
875 TypeConverter typeConverter;
876 typeConverter.addConversion(callback: [](Type type) { return type; });
877
878 // convert mesh::ShardingType to a tuple of RankedTensorTypes
879 typeConverter.addConversion(
880 callback: [](ShardingType type,
881 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
882 auto i16 = IntegerType::get(context: type.getContext(), width: 16);
883 auto i64 = IntegerType::get(context: type.getContext(), width: 64);
884 std::array<int64_t, 2> shp = {ShapedType::kDynamic,
885 ShapedType::kDynamic};
886 results.emplace_back(Args: RankedTensorType::get(shape: shp, elementType: i16));
887 results.emplace_back(Args: RankedTensorType::get(shape: shp, elementType: i64)); // actually ?x2
888 results.emplace_back(Args: RankedTensorType::get(shape: shp, elementType: i64));
889 return success();
890 });
891
892 // To 'extract' components, a UnrealizedConversionCastOp is expected
893 // to define the input
894 typeConverter.addTargetMaterialization(
895 callback: [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
896 Location loc) {
897 // Expecting a single input.
898 if (inputs.size() != 1 || !isa<TupleType>(Val: inputs[0].getType()))
899 return SmallVector<Value>();
900 auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>();
901 // Expecting an UnrealizedConversionCastOp.
902 if (!castOp)
903 return SmallVector<Value>();
904 // Fill a vector with elements of the tuple/castOp.
905 SmallVector<Value> results;
906 for (auto oprnd : castOp.getInputs()) {
907 if (!isa<RankedTensorType>(Val: oprnd.getType()))
908 return SmallVector<Value>();
909 results.emplace_back(Args&: oprnd);
910 }
911 return results;
912 });
913
914 // No mesh dialect should left after conversion...
915 target.addIllegalDialect<mesh::MeshDialect>();
916 // ...except the global MeshOp. MeshShapeOp which will get folded later.
917 target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>();
918 // Allow all the stuff that our patterns will convert to
919 target.addLegalDialect<
920 BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect,
921 tensor::TensorDialect, bufferization::BufferizationDialect,
922 linalg::LinalgDialect, memref::MemRefDialect, affine::AffineDialect>();
923 // Make sure the function signature, calls etc. are legal
924 target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) {
925 return typeConverter.isSignatureLegal(ty: op.getFunctionType());
926 });
927 target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
928 callback: [&](Operation *op) { return typeConverter.isLegal(op); });
929
930 patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp,
931 ConvertProcessMultiIndexOp, ConvertGetShardingOp,
932 ConvertShardingOp, ConvertShardShapeOp, ConvertAllReduceOp,
933 ConvertProcessLinearIndexOp>(arg&: typeConverter, args&: ctxt);
934
935 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
936 patterns, converter: typeConverter);
937 populateCallOpTypeConversionPattern(patterns, converter: typeConverter);
938 populateReturnOpTypeConversionPattern(patterns, converter: typeConverter);
939
940 (void)applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns));
941
942 // Folding patterns cannot be mixed with conversion patterns -> extra pass.
943 patterns.clear();
944 SymbolTableCollection symbolTableCollection;
945 mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection);
946 (void)applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns));
947 }
948};
949
950} // namespace
951

// Source file: mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp