1 | //===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a translation of Mesh communication ops to MPI ops.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/MeshToMPI/MeshToMPI.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
17 | #include "mlir/Dialect/DLTI/DLTI.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
20 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
21 | #include "mlir/Dialect/MPI/IR/MPI.h" |
22 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
23 | #include "mlir/Dialect/Mesh/IR/MeshDialect.h" |
24 | #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
25 | #include "mlir/Dialect/SCF/IR/SCF.h" |
26 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
27 | #include "mlir/Dialect/Utils/StaticValueUtils.h" |
28 | #include "mlir/IR/Builders.h" |
29 | #include "mlir/IR/BuiltinAttributes.h" |
30 | #include "mlir/IR/BuiltinTypes.h" |
31 | #include "mlir/IR/PatternMatch.h" |
32 | #include "mlir/IR/SymbolTable.h" |
33 | #include "mlir/Transforms/DialectConversion.h" |
34 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
35 | |
36 | #define DEBUG_TYPE "mesh-to-mpi" |
37 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
38 | |
39 | namespace mlir { |
40 | #define GEN_PASS_DEF_CONVERTMESHTOMPIPASS |
41 | #include "mlir/Conversion/Passes.h.inc" |
42 | } // namespace mlir |
43 | |
44 | using namespace mlir; |
45 | using namespace mesh; |
46 | |
47 | namespace { |
/// Converts a vector of OpFoldResults (ints) into a vector of Values of the
/// provided type.
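/// E.g., statics = {2, ShapedType::kDynamic, 4} with one dynamic value %d
/// yields the Values {constant 2, %d, constant 4}.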
50 | static SmallVector<Value> getMixedAsValues(OpBuilder b, const Location &loc, |
51 | llvm::ArrayRef<int64_t> statics, |
52 | ValueRange dynamics, |
53 | Type type = Type()) { |
54 | SmallVector<Value> values; |
55 | auto dyn = dynamics.begin(); |
56 | Type i64 = b.getI64Type(); |
57 | if (!type) |
58 | type = i64; |
  assert((i64 == type || b.getIndexType() == type) &&
         "expected an i64 or an index type");
61 | for (auto s : statics) { |
62 | if (s == ShapedType::kDynamic) { |
63 | values.emplace_back(*(dyn++)); |
64 | } else { |
65 | TypedAttr val = type == i64 ? b.getI64IntegerAttr(s) : b.getIndexAttr(s); |
66 | values.emplace_back(b.create<arith::ConstantOp>(loc, type, val)); |
67 | } |
68 | } |
69 | return values; |
70 | } |
71 | |
72 | /// Create operations converting a linear index to a multi-dimensional index. |
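/// E.g., with dimensions = {2, 3, 4} (row-major strides {12, 4, 1}) the
/// linear index 17 decomposes into the multi-index {1, 1, 1} because
/// 17 = 1*12 + 1*4 + 1*1.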
73 | static SmallVector<Value> linearToMultiIndex(Location loc, OpBuilder b, |
74 | Value linearIndex, |
75 | ValueRange dimensions) { |
76 | int n = dimensions.size(); |
77 | SmallVector<Value> multiIndex(n); |
78 | |
79 | for (int i = n - 1; i >= 0; --i) { |
80 | multiIndex[i] = b.create<arith::RemSIOp>(loc, linearIndex, dimensions[i]); |
81 | if (i > 0) |
82 | linearIndex = b.create<arith::DivSIOp>(loc, linearIndex, dimensions[i]); |
83 | } |
84 | |
85 | return multiIndex; |
86 | } |
87 | |
88 | /// Create operations converting a multi-dimensional index to a linear index. |
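/// E.g., multiIndex = {1, 1, 1} with dimensions = {2, 3, 4} yields
/// 1*12 + 1*4 + 1*1 = 17; the strides are accumulated right-to-left.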
89 | Value multiToLinearIndex(Location loc, OpBuilder b, ValueRange multiIndex, |
90 | ValueRange dimensions) { |
91 | |
92 | Value linearIndex = b.create<arith::ConstantIndexOp>(loc, 0); |
93 | Value stride = b.create<arith::ConstantIndexOp>(loc, 1); |
94 | |
95 | for (int i = multiIndex.size() - 1; i >= 0; --i) { |
96 | Value off = b.create<arith::MulIOp>(loc, multiIndex[i], stride); |
97 | linearIndex = b.create<arith::AddIOp>(loc, linearIndex, off); |
98 | stride = b.create<arith::MulIOp>(loc, stride, dimensions[i]); |
99 | } |
100 | |
101 | return linearIndex; |
102 | } |
103 | |
104 | /// Replace GetShardingOp with related/dependent ShardingOp. |
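/// Schematically (IR abbreviated):
///   %sh = mesh.sharding ...
///   %v = mesh.shard %t to %sh ...
///   %s = mesh.get_sharding %v
/// rewrites %s to %sh.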
105 | struct ConvertGetShardingOp : public OpConversionPattern<GetShardingOp> { |
106 | using OpConversionPattern::OpConversionPattern; |
107 | |
108 | LogicalResult |
109 | matchAndRewrite(GetShardingOp op, OpAdaptor adaptor, |
110 | ConversionPatternRewriter &rewriter) const override { |
111 | auto shardOp = adaptor.getSource().getDefiningOp<ShardOp>(); |
112 | if (!shardOp) |
113 | return failure(); |
114 | auto shardingOp = shardOp.getSharding().getDefiningOp<ShardingOp>(); |
115 | if (!shardingOp) |
116 | return failure(); |
117 | |
118 | rewriter.replaceOp(op, shardingOp.getResult()); |
119 | return success(); |
120 | } |
121 | }; |
122 | |
123 | /// Convert a sharding op to a tuple of tensors of its components |
124 | /// (SplitAxes, HaloSizes, ShardedDimsOffsets) |
/// as defined by the type converter.
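/// E.g., split_axes = [[0], [1, 2]] becomes the 2x2 i16 split-axes tensor
///   [[0, -1], [1, 2]]
/// where -1 pads rows with fewer axes than the widest split group.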
126 | struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { |
127 | using OpConversionPattern::OpConversionPattern; |
128 | |
129 | LogicalResult |
130 | matchAndRewrite(ShardingOp op, OpAdaptor adaptor, |
131 | ConversionPatternRewriter &rewriter) const override { |
132 | auto splitAxes = op.getSplitAxes().getAxes(); |
133 | int64_t maxNAxes = 0; |
134 | for (auto axes : splitAxes) |
135 | maxNAxes = std::max<int64_t>(maxNAxes, axes.size()); |
136 | |
137 | // To hold the split axes, create empty 2d tensor with shape |
138 | // {splitAxes.size(), max-size-of-split-groups}. |
139 | // Set trailing elements for smaller split-groups to -1. |
140 | Location loc = op.getLoc(); |
141 | auto i16 = rewriter.getI16Type(); |
142 | auto i64 = rewriter.getI64Type(); |
143 | std::array<int64_t, 2> shape = {static_cast<int64_t>(splitAxes.size()), |
144 | maxNAxes}; |
145 | Value resSplitAxes = rewriter.create<tensor::EmptyOp>(loc, shape, i16); |
146 | auto attr = IntegerAttr::get(i16, -1); |
147 | Value fillValue = rewriter.create<arith::ConstantOp>(loc, i16, attr); |
148 | resSplitAxes = rewriter.create<linalg::FillOp>(loc, fillValue, resSplitAxes) |
149 | .getResult(0); |
150 | |
    // Explicitly write values into the tensor row by row.
152 | std::array<int64_t, 2> strides = {1, 1}; |
153 | int64_t nSplits = 0; |
154 | ValueRange empty = {}; |
155 | for (auto [i, axes] : llvm::enumerate(splitAxes)) { |
156 | int64_t size = axes.size(); |
157 | if (size > 0) |
158 | ++nSplits; |
      std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0};
160 | std::array<int64_t, 2> sizes = {1, size}; |
161 | auto tensorType = RankedTensorType::get({size}, i16); |
162 | auto attrs = DenseIntElementsAttr::get(tensorType, axes.asArrayRef()); |
163 | auto vals = rewriter.create<arith::ConstantOp>(loc, tensorType, attrs); |
164 | resSplitAxes = rewriter.create<tensor::InsertSliceOp>( |
165 | loc, vals, resSplitAxes, empty, empty, empty, offs, sizes, strides); |
166 | } |
167 | |
    // To hold halo sizes, create a 2d Tensor with shape {nSplits, 2}.
    // Store the halo sizes in the tensor.
170 | SmallVector<Value> haloSizes = |
171 | getMixedAsValues(rewriter, loc, adaptor.getStaticHaloSizes(), |
172 | adaptor.getDynamicHaloSizes()); |
173 | auto type = RankedTensorType::get({nSplits, 2}, i64); |
174 | Value resHaloSizes = |
175 | haloSizes.empty() |
176 | ? rewriter |
177 | .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0}, |
178 | i64) |
179 | .getResult() |
180 | : rewriter.create<tensor::FromElementsOp>(loc, type, haloSizes) |
181 | .getResult(); |
182 | |
    // To hold sharded dims offsets, create a tensor with shape {nSplits,
    // maxSplitSize+1}. Store the offsets in the tensor but set trailing
    // elements for smaller split-groups to ShapedType::kDynamic. Computing
    // the max size of the split groups requires collectiveProcessGroupSize
    // (which needs the MeshOp).
188 | Value resOffsets; |
189 | if (adaptor.getStaticShardedDimsOffsets().empty()) { |
190 | resOffsets = rewriter.create<tensor::EmptyOp>( |
191 | loc, std::array<int64_t, 2>{0, 0}, i64); |
192 | } else { |
193 | SymbolTableCollection symbolTableCollection; |
194 | auto meshOp = getMesh(op, symbolTableCollection); |
195 | int64_t maxSplitSize = 0; |
196 | for (auto axes : splitAxes) { |
197 | int64_t splitSize = |
198 | collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); |
199 | assert(splitSize != ShapedType::kDynamic); |
200 | maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize); |
201 | } |
202 | assert(maxSplitSize); |
203 | ++maxSplitSize; // add one for the total size |
204 | |
205 | resOffsets = rewriter.create<tensor::EmptyOp>( |
206 | loc, std::array<int64_t, 2>{nSplits, maxSplitSize}, i64); |
      Value fillVal = rewriter.create<arith::ConstantOp>(
          loc, i64, rewriter.getI64IntegerAttr(ShapedType::kDynamic));
      resOffsets = rewriter.create<linalg::FillOp>(loc, fillVal, resOffsets)
                       .getResult(0);
211 | SmallVector<Value> offsets = |
212 | getMixedAsValues(rewriter, loc, adaptor.getStaticShardedDimsOffsets(), |
213 | adaptor.getDynamicShardedDimsOffsets()); |
214 | int64_t curr = 0; |
215 | for (auto [i, axes] : llvm::enumerate(splitAxes)) { |
216 | int64_t splitSize = |
217 | collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); |
218 | assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); |
219 | ++splitSize; // add one for the total size |
220 | ArrayRef<Value> values(&offsets[curr], splitSize); |
221 | Value vals = rewriter.create<tensor::FromElementsOp>(loc, values); |
222 | std::array<int64_t, 2> offs = {static_cast<int64_t>(i), 0}; |
223 | std::array<int64_t, 2> sizes = {1, splitSize}; |
224 | resOffsets = rewriter.create<tensor::InsertSliceOp>( |
225 | loc, vals, resOffsets, empty, empty, empty, offs, sizes, strides); |
226 | curr += splitSize; |
227 | } |
228 | } |
229 | |
230 | // return a tuple of tensors as defined by type converter |
231 | SmallVector<Type> resTypes; |
232 | if (failed(getTypeConverter()->convertType(op.getResult().getType(), |
233 | resTypes))) |
234 | return failure(); |
235 | |
236 | resSplitAxes = |
237 | rewriter.create<tensor::CastOp>(loc, resTypes[0], resSplitAxes); |
238 | resHaloSizes = |
239 | rewriter.create<tensor::CastOp>(loc, resTypes[1], resHaloSizes); |
240 | resOffsets = rewriter.create<tensor::CastOp>(loc, resTypes[2], resOffsets); |
241 | |
242 | rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>( |
243 | op, TupleType::get(op.getContext(), resTypes), |
244 | ValueRange{resSplitAxes, resHaloSizes, resOffsets}); |
245 | |
246 | return success(); |
247 | } |
248 | }; |
249 | |
250 | struct ConvertProcessMultiIndexOp |
251 | : public OpConversionPattern<ProcessMultiIndexOp> { |
252 | using OpConversionPattern::OpConversionPattern; |
253 | |
254 | LogicalResult |
255 | matchAndRewrite(ProcessMultiIndexOp op, OpAdaptor adaptor, |
256 | ConversionPatternRewriter &rewriter) const override { |
257 | |
258 | // Currently converts its linear index to a multi-dimensional index. |
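    // E.g., on a 2x4 mesh, the linear index 5 becomes the multi-index
    // [1, 1] since 5 = 1*4 + 1.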
259 | |
260 | SymbolTableCollection symbolTableCollection; |
261 | Location loc = op.getLoc(); |
262 | auto meshOp = getMesh(op, symbolTableCollection); |
263 | // For now we only support static mesh shapes |
264 | if (ShapedType::isDynamicShape(meshOp.getShape())) |
265 | return failure(); |
266 | |
267 | SmallVector<Value> dims; |
268 | llvm::transform( |
269 | meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { |
270 | return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult(); |
271 | }); |
272 | Value rank = rewriter.create<ProcessLinearIndexOp>(op.getLoc(), meshOp); |
273 | auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); |
274 | |
275 | // optionally extract subset of mesh axes |
276 | auto axes = adaptor.getAxes(); |
277 | if (!axes.empty()) { |
278 | SmallVector<Value> subIndex; |
279 | for (auto axis : axes) { |
280 | subIndex.emplace_back(mIdx[axis]); |
281 | } |
282 | mIdx = std::move(subIndex); |
283 | } |
284 | |
285 | rewriter.replaceOp(op, mIdx); |
286 | return success(); |
287 | } |
288 | }; |
289 | |
290 | class ConvertProcessLinearIndexOp |
291 | : public OpConversionPattern<ProcessLinearIndexOp> { |
292 | int64_t worldRank; // rank in MPI_COMM_WORLD if available, else < 0 |
293 | |
294 | public: |
295 | using OpConversionPattern::OpConversionPattern; |
296 | |
297 | // Constructor accepting worldRank |
298 | ConvertProcessLinearIndexOp(const TypeConverter &typeConverter, |
299 | MLIRContext *context, int64_t worldRank = -1) |
300 | : OpConversionPattern(typeConverter, context), worldRank(worldRank) {} |
301 | |
302 | LogicalResult |
303 | matchAndRewrite(ProcessLinearIndexOp op, OpAdaptor adaptor, |
304 | ConversionPatternRewriter &rewriter) const override { |
305 | |
306 | Location loc = op.getLoc(); |
307 | if (worldRank >= 0) { // if rank in MPI_COMM_WORLD is known -> use it |
308 | rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, worldRank); |
309 | return success(); |
310 | } |
311 | |
    // Otherwise create an mpi::CommRankOp to query the rank at runtime.
313 | auto ctx = op.getContext(); |
314 | Value commWorld = |
315 | rewriter.create<mpi::CommWorldOp>(loc, mpi::CommType::get(ctx)); |
316 | auto rank = |
317 | rewriter |
318 | .create<mpi::CommRankOp>( |
319 | loc, |
320 | TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, |
321 | commWorld) |
322 | .getRank(); |
323 | rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), |
324 | rank); |
325 | return success(); |
326 | } |
327 | }; |
328 | |
329 | struct ConvertNeighborsLinearIndicesOp |
330 | : public OpConversionPattern<NeighborsLinearIndicesOp> { |
331 | using OpConversionPattern::OpConversionPattern; |
332 | |
333 | LogicalResult |
334 | matchAndRewrite(NeighborsLinearIndicesOp op, OpAdaptor adaptor, |
335 | ConversionPatternRewriter &rewriter) const override { |
336 | |
    // Computes the neighbor indices along a split axis by simply
    // adding/subtracting 1 to/from the current index in that dimension.
    // Assigns -1 if the neighbor is out of bounds.
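    // E.g., on a 2x3 mesh, device [1, 1] split along mesh axis 1 has the
    // neighbors [1, 0] and [1, 2], i.e. linear indices 3 and 5.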
340 | |
341 | auto axes = adaptor.getSplitAxes(); |
342 | // For now only single axis sharding is supported |
343 | if (axes.size() != 1) |
344 | return failure(); |
345 | |
346 | Location loc = op.getLoc(); |
347 | SymbolTableCollection symbolTableCollection; |
348 | auto meshOp = getMesh(op, symbolTableCollection); |
349 | auto mIdx = adaptor.getDevice(); |
350 | auto orgIdx = mIdx[axes[0]]; |
351 | SmallVector<Value> dims; |
352 | llvm::transform( |
353 | meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { |
354 | return rewriter.create<arith::ConstantIndexOp>(loc, i).getResult(); |
355 | }); |
356 | Value dimSz = dims[axes[0]]; |
357 | Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
358 | Value minus1 = rewriter.create<arith::ConstantIndexOp>(loc, -1); |
359 | Value atBorder = rewriter.create<arith::CmpIOp>( |
360 | loc, arith::CmpIPredicate::sle, orgIdx, |
361 | rewriter.create<arith::ConstantIndexOp>(loc, 0)); |
362 | auto down = rewriter.create<scf::IfOp>( |
363 | loc, atBorder, |
364 | [&](OpBuilder &builder, Location loc) { |
365 | builder.create<scf::YieldOp>(loc, minus1); |
366 | }, |
367 | [&](OpBuilder &builder, Location loc) { |
368 | SmallVector<Value> tmp = mIdx; |
369 | tmp[axes[0]] = |
370 | rewriter.create<arith::SubIOp>(op.getLoc(), orgIdx, one) |
371 | .getResult(); |
372 | builder.create<scf::YieldOp>( |
373 | loc, multiToLinearIndex(loc, rewriter, tmp, dims)); |
374 | }); |
375 | atBorder = rewriter.create<arith::CmpIOp>( |
376 | loc, arith::CmpIPredicate::sge, orgIdx, |
377 | rewriter.create<arith::SubIOp>(loc, dimSz, one).getResult()); |
378 | auto up = rewriter.create<scf::IfOp>( |
379 | loc, atBorder, |
380 | [&](OpBuilder &builder, Location loc) { |
381 | builder.create<scf::YieldOp>(loc, minus1); |
382 | }, |
383 | [&](OpBuilder &builder, Location loc) { |
384 | SmallVector<Value> tmp = mIdx; |
385 | tmp[axes[0]] = |
386 | rewriter.create<arith::AddIOp>(op.getLoc(), orgIdx, one); |
387 | builder.create<scf::YieldOp>( |
388 | loc, multiToLinearIndex(loc, rewriter, tmp, dims)); |
389 | }); |
390 | rewriter.replaceOp(op, ValueRange{down.getResult(0), up.getResult(0)}); |
391 | return success(); |
392 | } |
393 | }; |
394 | |
395 | struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { |
396 | using OpConversionPattern::OpConversionPattern; |
397 | |
398 | LogicalResult |
399 | matchAndRewrite(ShardShapeOp op, OneToNOpAdaptor adaptor, |
400 | ConversionPatternRewriter &rewriter) const override { |
401 | auto sharding = op.getSharding().getDefiningOp<ShardingOp>(); |
402 | if (!sharding) { |
403 | return op->emitError() |
404 | << "Expected SharingOp as defining op for sharding" |
405 | << " but found " << adaptor.getSharding()[0].getDefiningOp(); |
406 | } |
407 | |
408 | // Compute the sharded shape by applying the sharding to the input shape. |
409 | // If shardedDimsOffsets is not defined in the sharding, the shard shape is |
410 | // computed by dividing the dimension size by the number of shards in that |
411 | // dimension (which is given by the size of the mesh axes provided in |
412 | // split-axes). Odd elements get distributed to trailing shards. If a |
413 | // shardedDimsOffsets is provided, the shard shape is computed by |
414 | // subtracting the offset of the current shard from the offset of the next |
415 | // shard. |
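    // E.g., a dimension of size 10 split over 4 shards yields the shard
    // sizes [2, 2, 3, 3]: 10/4 = 2 for all shards, and the 10%4 = 2
    // trailing shards get one extra element.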
416 | |
417 | Location loc = op.getLoc(); |
418 | Type index = rewriter.getIndexType(); |
419 | |
    // This is a 1:N conversion because the sharding op is a 1:3 conversion.
    // The operands in the adaptor are a vector<ValueRange>. For dims and
    // device we have a 1:1 conversion.
    // For simpler access, fill a vector with the dynamic dims.
424 | SmallVector<Value> dynDims, dynDevice; |
425 | for (auto dim : adaptor.getDimsDynamic()) { |
426 | // type conversion should be 1:1 for ints |
427 | dynDims.emplace_back(llvm::getSingleElement(dim)); |
428 | } |
429 | // same for device |
430 | for (auto device : adaptor.getDeviceDynamic()) { |
431 | dynDevice.emplace_back(llvm::getSingleElement(device)); |
432 | } |
433 | |
434 | // To keep the code simple, convert dims/device to values when they are |
435 | // attributes. Count on canonicalization to fold static values. |
436 | SmallVector<Value> shape = |
437 | getMixedAsValues(rewriter, loc, op.getDims(), dynDims, index); |
438 | SmallVector<Value> multiIdx = |
439 | getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); |
440 | |
441 | // Get the MeshOp, the mesh shape is needed to compute the sharded shape. |
442 | SymbolTableCollection symbolTableCollection; |
443 | auto meshOp = getMesh(sharding, symbolTableCollection); |
444 | // For now we only support static mesh shapes |
445 | if (ShapedType::isDynamicShape(meshOp.getShape())) |
446 | return failure(); |
447 | |
448 | auto splitAxes = sharding.getSplitAxes().getAxes(); |
449 | // shardedDimsOffsets are optional and might be Values (not attributes). |
450 | // Also, the shardId might be dynamic which means the position in the |
451 | // shardedDimsOffsets is not statically known. Create a tensor of the |
452 | // shardedDimsOffsets and later extract the offsets for computing the |
453 | // local shard-size. |
454 | Value shardedDimsOffs; |
455 | { |
456 | SmallVector<Value> tmp = getMixedAsValues( |
457 | rewriter, loc, sharding.getStaticShardedDimsOffsets(), |
458 | sharding.getDynamicShardedDimsOffsets(), index); |
459 | if (!tmp.empty()) |
460 | shardedDimsOffs = rewriter.create<tensor::FromElementsOp>( |
            loc,
            RankedTensorType::get({static_cast<int64_t>(tmp.size())}, index),
            tmp);
462 | } |
463 | |
464 | // With static mesh shape the sizes of the split axes are known. |
465 | // Hence the start/pos for each split axes in shardDimsOffsets can be |
466 | // computed statically. |
467 | int64_t pos = 0; |
468 | SmallVector<Value> shardShape; |
469 | Value zero = |
470 | rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(index)); |
471 | Value one = |
472 | rewriter.create<arith::ConstantOp>(loc, rewriter.getOneAttr(index)); |
473 | |
474 | // Iterate over the dimensions of the tensor shape, get their split Axes, |
475 | // and compute the sharded shape. |
476 | for (auto [i, dim] : llvm::enumerate(shape)) { |
477 | // Trailing dimensions might not be annotated. |
478 | if (i < splitAxes.size() && !splitAxes[i].empty()) { |
479 | auto axes = splitAxes[i]; |
480 | // The current dimension might not be sharded. |
481 | // Create a value from the static position in shardDimsOffsets. |
482 | Value posVal = |
483 | rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(pos)); |
484 | // Get the index of the local shard in the mesh axis. |
485 | Value idx = multiIdx[axes[0]]; |
486 | auto numShards = |
487 | collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); |
488 | if (shardedDimsOffs) { |
489 | // If sharded dims offsets are provided, use them to compute the |
490 | // sharded shape. |
491 | if (axes.size() > 1) { |
492 | return op->emitError() << "Only single axis sharding is " |
493 | << "supported for each dimension." ; |
494 | } |
495 | idx = rewriter.create<arith::AddIOp>(loc, posVal, idx); |
496 | // Compute size = shardedDimsOffs[idx+1] - shardedDimsOffs[idx]. |
497 | Value off = |
498 | rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx); |
499 | idx = rewriter.create<arith::AddIOp>(loc, idx, one); |
500 | Value nextOff = |
501 | rewriter.create<tensor::ExtractOp>(loc, shardedDimsOffs, idx); |
502 | Value sz = rewriter.create<arith::SubIOp>(loc, nextOff, off); |
503 | shardShape.emplace_back(sz); |
504 | } else { |
505 | Value numShardsVal = rewriter.create<arith::ConstantOp>( |
506 | loc, rewriter.getIndexAttr(numShards)); |
507 | // Compute shard dim size by distributing odd elements to trailing |
508 | // shards: |
509 | // sz = dim / numShards |
510 | // + (idx >= (numShards - (dim % numShards)) ? 1 : 0) |
511 | Value sz = rewriter.create<arith::DivSIOp>(loc, dim, numShardsVal); |
512 | Value sz1 = rewriter.create<arith::RemSIOp>(loc, dim, numShardsVal); |
513 | sz1 = rewriter.create<arith::SubIOp>(loc, numShardsVal, sz1); |
514 | auto cond = rewriter.create<arith::CmpIOp>( |
515 | loc, arith::CmpIPredicate::sge, idx, sz1); |
516 | Value odd = rewriter.create<arith::SelectOp>(loc, cond, one, zero); |
517 | sz = rewriter.create<arith::AddIOp>(loc, sz, odd); |
518 | shardShape.emplace_back(sz); |
519 | } |
520 | pos += numShards + 1; // add one for the total size. |
      } // else: no sharding if the split axes for this dim are empty or absent
522 | // If no size was added -> no sharding in this dimension. |
523 | if (shardShape.size() <= i) |
524 | shardShape.emplace_back(dim); |
525 | } |
526 | assert(shardShape.size() == shape.size()); |
527 | rewriter.replaceOp(op, shardShape); |
528 | return success(); |
529 | } |
530 | }; |
531 | |
532 | struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { |
533 | using OpConversionPattern::OpConversionPattern; |
534 | |
535 | LogicalResult |
536 | matchAndRewrite(UpdateHaloOp op, OpAdaptor adaptor, |
537 | ConversionPatternRewriter &rewriter) const override { |
538 | |
539 | // The input/output memref is assumed to be in C memory order. |
540 | // Halos are exchanged as 2 blocks per dimension (one for each side: down |
541 | // and up). For each haloed dimension `d`, the exchanged blocks are |
542 | // expressed as multi-dimensional subviews. The subviews include potential |
543 | // halos of higher dimensions `dh > d`, no halos for the lower dimensions |
544 | // `dl < d` and for dimension `d` the currently exchanged halo only. |
    // By iterating from higher to lower dimensions this also updates the halos
546 | // in the 'corners'. |
547 | // memref.subview is used to read and write the halo data from and to the |
548 | // local data. Because subviews and halos can have mixed dynamic and static |
549 | // shapes, OpFoldResults are used whenever possible. |
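    // E.g., for a 2d array with both dimensions split, the dim-1 halos are
    // exchanged first across the interior of dim 0; the dim-0 halos are then
    // exchanged across the full extent of dim 1, including its freshly
    // received halos, which also fills the corner regions.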
550 | |
551 | auto haloSizes = getMixedValues(adaptor.getStaticHaloSizes(), |
552 | adaptor.getHaloSizes(), rewriter); |
553 | if (haloSizes.empty()) { |
554 | // no halos -> nothing to do |
555 | rewriter.replaceOp(op, adaptor.getDestination()); |
556 | return success(); |
557 | } |
558 | |
559 | SymbolTableCollection symbolTableCollection; |
560 | Location loc = op.getLoc(); |
561 | |
    // Convert an OpFoldResult into a Value.
563 | auto toValue = [&rewriter, &loc](OpFoldResult &v) -> Value { |
564 | if (auto value = dyn_cast<Value>(v)) |
565 | return value; |
566 | return rewriter.create<arith::ConstantOp>( |
567 | loc, rewriter.getIndexAttr( |
568 | cast<IntegerAttr>(cast<Attribute>(v)).getInt())); |
569 | }; |
570 | |
571 | auto dest = adaptor.getDestination(); |
572 | auto dstShape = cast<ShapedType>(dest.getType()).getShape(); |
573 | Value array = dest; |
    if (isa<RankedTensorType>(array.getType())) {
      // If the destination is a tensor, convert it to a memref.
      auto memrefType = MemRefType::get(
          dstShape, cast<ShapedType>(array.getType()).getElementType());
      array =
          rewriter.create<bufferization::ToBufferOp>(loc, memrefType, array);
    }
581 | auto rank = cast<ShapedType>(array.getType()).getRank(); |
582 | auto opSplitAxes = adaptor.getSplitAxes().getAxes(); |
583 | auto mesh = adaptor.getMesh(); |
584 | auto meshOp = getMesh(op, symbolTableCollection); |
585 | // subviews need Index values |
586 | for (auto &sz : haloSizes) { |
587 | if (auto value = dyn_cast<Value>(sz)) |
588 | sz = |
589 | rewriter |
590 | .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value) |
591 | .getResult(); |
592 | } |
593 | |
594 | // most of the offset/size/stride data is the same for all dims |
595 | SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); |
596 | SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); |
597 | SmallVector<OpFoldResult> shape(rank), dimSizes(rank); |
598 | auto currHaloDim = -1; // halo sizes are provided for split dimensions only |
599 | // we need the actual shape to compute offsets and sizes |
    for (auto i = 0; i < rank; ++i) {
      auto s = dstShape[i];
      if (ShapedType::isDynamic(s))
        shape[i] = rewriter.create<memref::DimOp>(loc, array, i).getResult();
      else
        shape[i] = rewriter.getIndexAttr(s);
606 | |
607 | if ((size_t)i < opSplitAxes.size() && !opSplitAxes[i].empty()) { |
608 | ++currHaloDim; |
        // the offsets for lower dims start after their down halo
610 | offsets[i] = haloSizes[currHaloDim * 2]; |
611 | |
612 | // prepare shape and offsets of highest dim's halo exchange |
        Value totalHaloSz = rewriter.create<arith::AddIOp>(
            loc, toValue(haloSizes[currHaloDim * 2]),
            toValue(haloSizes[currHaloDim * 2 + 1]));
        // the halo shape of lower dims excludes the halos
        dimSizes[i] =
            rewriter.create<arith::SubIOp>(loc, toValue(shape[i]), totalHaloSz)
                .getResult();
620 | } else { |
621 | dimSizes[i] = shape[i]; |
622 | } |
623 | } |
624 | |
625 | auto tagAttr = rewriter.getI32IntegerAttr(91); // we just pick something |
626 | auto tag = rewriter.create<arith::ConstantOp>(loc, tagAttr); |
627 | auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 |
628 | auto zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); |
629 | |
630 | SmallVector<Type> indexResultTypes(meshOp.getShape().size(), |
631 | rewriter.getIndexType()); |
632 | auto myMultiIndex = |
633 | rewriter.create<ProcessMultiIndexOp>(loc, indexResultTypes, mesh) |
634 | .getResult(); |
635 | // traverse all split axes from high to low dim |
636 | for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { |
637 | auto splitAxes = opSplitAxes[dim]; |
638 | if (splitAxes.empty()) |
639 | continue; |
640 | assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); |
641 | // Get the linearized ids of the neighbors (down and up) for the |
642 | // given split |
643 | auto tmp = rewriter |
644 | .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex, |
645 | splitAxes) |
646 | .getResults(); |
647 | // MPI operates on i32... |
648 | Value neighbourIDs[2] = {rewriter.create<arith::IndexCastOp>( |
649 | loc, rewriter.getI32Type(), tmp[0]), |
650 | rewriter.create<arith::IndexCastOp>( |
651 | loc, rewriter.getI32Type(), tmp[1])}; |
652 | |
653 | auto lowerRecvOffset = rewriter.getIndexAttr(0); |
654 | auto lowerSendOffset = toValue(haloSizes[currHaloDim * 2]); |
655 | auto upperRecvOffset = rewriter.create<arith::SubIOp>( |
656 | loc, toValue(shape[dim]), toValue(haloSizes[currHaloDim * 2 + 1])); |
657 | auto upperSendOffset = rewriter.create<arith::SubIOp>( |
658 | loc, upperRecvOffset, toValue(haloSizes[currHaloDim * 2])); |
659 | |
660 | Value commWorld = rewriter.create<mpi::CommWorldOp>( |
661 | loc, mpi::CommType::get(op->getContext())); |
662 | |
      // Make sure we send/recv in a way that does not lead to a dead-lock.
      // The current approach is by far not optimal; it should at least use
      // a red-black pattern or MPI_sendrecv.
      // Also, buffers should be re-used.
      // Still using temporary contiguous buffers for MPI communication...
      // Still yielding a "serialized" communication pattern...
669 | auto genSendRecv = [&](bool upperHalo) { |
670 | auto orgOffset = offsets[dim]; |
671 | dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] |
672 | : haloSizes[currHaloDim * 2]; |
673 | // Check if we need to send and/or receive |
674 | // Processes on the mesh borders have only one neighbor |
675 | auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; |
676 | auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; |
677 | auto hasFrom = rewriter.create<arith::CmpIOp>( |
678 | loc, arith::CmpIPredicate::sge, from, zero); |
679 | auto hasTo = rewriter.create<arith::CmpIOp>( |
680 | loc, arith::CmpIPredicate::sge, to, zero); |
681 | auto buffer = rewriter.create<memref::AllocOp>( |
682 | loc, dimSizes, cast<ShapedType>(array.getType()).getElementType()); |
683 | // if has neighbor: copy halo data from array to buffer and send |
684 | rewriter.create<scf::IfOp>( |
685 | loc, hasTo, [&](OpBuilder &builder, Location loc) { |
686 | offsets[dim] = upperHalo ? OpFoldResult(lowerSendOffset) |
687 | : OpFoldResult(upperSendOffset); |
688 | auto subview = builder.create<memref::SubViewOp>( |
689 | loc, array, offsets, dimSizes, strides); |
690 | builder.create<memref::CopyOp>(loc, subview, buffer); |
691 | builder.create<mpi::SendOp>(loc, TypeRange{}, buffer, tag, to, |
692 | commWorld); |
693 | builder.create<scf::YieldOp>(loc); |
694 | }); |
695 | // if has neighbor: receive halo data into buffer and copy to array |
696 | rewriter.create<scf::IfOp>( |
697 | loc, hasFrom, [&](OpBuilder &builder, Location loc) { |
698 | offsets[dim] = upperHalo ? OpFoldResult(upperRecvOffset) |
699 | : OpFoldResult(lowerRecvOffset); |
700 | builder.create<mpi::RecvOp>(loc, TypeRange{}, buffer, tag, from, |
701 | commWorld); |
702 | auto subview = builder.create<memref::SubViewOp>( |
703 | loc, array, offsets, dimSizes, strides); |
704 | builder.create<memref::CopyOp>(loc, buffer, subview); |
705 | builder.create<scf::YieldOp>(loc); |
706 | }); |
707 | rewriter.create<memref::DeallocOp>(loc, buffer); |
708 | offsets[dim] = orgOffset; |
709 | }; |
710 | |
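      // Emit a guarded exchange per side: the send/recv for a side is
      // skipped entirely if its halo size is zero.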
711 | auto doSendRecv = [&](int upOrDown) { |
        OpFoldResult &v = haloSizes[currHaloDim * 2 + upOrDown];
        Value haloSz = dyn_cast<Value>(v);
        if (haloSz) {
          // Dynamic halo sizes were cast to index above; cast back to i32 so
          // the comparison against the i32-typed `zero` type-checks.
          haloSz = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(), haloSz);
        } else {
          haloSz = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(
                       cast<IntegerAttr>(cast<Attribute>(v)).getInt()));
        }
718 | auto hasSize = rewriter.create<arith::CmpIOp>( |
719 | loc, arith::CmpIPredicate::sgt, haloSz, zero); |
720 | rewriter.create<scf::IfOp>(loc, hasSize, |
721 | [&](OpBuilder &builder, Location loc) { |
722 | genSendRecv(upOrDown > 0); |
723 | builder.create<scf::YieldOp>(loc); |
724 | }); |
725 | }; |
726 | |
727 | doSendRecv(0); |
728 | doSendRecv(1); |
729 | |
730 | // the shape for lower dims include higher dims' halos |
731 | dimSizes[dim] = shape[dim]; |
732 | // -> the offset for higher dims is always 0 |
733 | offsets[dim] = rewriter.getIndexAttr(0); |
734 | // on to next halo |
735 | --currHaloDim; |
736 | } |
737 | |
738 | if (isa<MemRefType>(op.getResult().getType())) { |
739 | rewriter.replaceOp(op, array); |
740 | } else { |
741 | assert(isa<RankedTensorType>(op.getResult().getType())); |
742 | rewriter.replaceOp(op, rewriter.create<bufferization::ToTensorOp>( |
743 | loc, op.getResult().getType(), array, |
744 | /*restrict=*/true, /*writable=*/true)); |
745 | } |
746 | return success(); |
747 | } |
748 | }; |
749 | |
750 | struct ConvertMeshToMPIPass |
751 | : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> { |
752 | using Base::Base; |
753 | |
754 | /// Run the dialect converter on the module. |
755 | void runOnOperation() override { |
    int64_t worldRank = -1;
757 | // Try to get DLTI attribute for MPI:comm_world_rank |
758 | // If found, set worldRank to the value of the attribute. |
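    // A module might provide it like this (sketch; the exact DLTI syntax is
    // illustrative only):
    //   module attributes {dlti.map = #dlti.map<"MPI:comm_world_rank" = 1>}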
759 | { |
      auto dltiAttr =
          dlti::query(getOperation(), {"MPI:comm_world_rank"}, false);
762 | if (succeeded(dltiAttr)) { |
763 | if (!isa<IntegerAttr>(dltiAttr.value())) { |
764 | getOperation()->emitError() |
765 | << "Expected an integer attribute for MPI:comm_world_rank" ; |
766 | return signalPassFailure(); |
767 | } |
768 | worldRank = cast<IntegerAttr>(dltiAttr.value()).getInt(); |
769 | } |
770 | } |
771 | |
772 | auto *ctxt = &getContext(); |
773 | RewritePatternSet patterns(ctxt); |
774 | ConversionTarget target(getContext()); |
775 | |
776 | // Define a type converter to convert mesh::ShardingType, |
777 | // mostly for use in return operations. |
778 | TypeConverter typeConverter; |
779 | typeConverter.addConversion([](Type type) { return type; }); |
780 | |
781 | // convert mesh::ShardingType to a tuple of RankedTensorTypes |
782 | typeConverter.addConversion( |
783 | [](ShardingType type, |
784 | SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { |
785 | auto i16 = IntegerType::get(type.getContext(), 16); |
786 | auto i64 = IntegerType::get(type.getContext(), 64); |
787 | std::array<int64_t, 2> shp = {ShapedType::kDynamic, |
788 | ShapedType::kDynamic}; |
789 | results.emplace_back(RankedTensorType::get(shp, i16)); |
790 | results.emplace_back(RankedTensorType::get(shp, i64)); // actually ?x2 |
791 | results.emplace_back(RankedTensorType::get(shp, i64)); |
792 | return success(); |
793 | }); |
794 | |
    // To 'extract' components, an UnrealizedConversionCastOp is expected
    // to define the input.
797 | typeConverter.addTargetMaterialization( |
798 | [&](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, |
799 | Location loc) { |
800 | // Expecting a single input. |
801 | if (inputs.size() != 1 || !isa<TupleType>(inputs[0].getType())) |
802 | return SmallVector<Value>(); |
803 | auto castOp = inputs[0].getDefiningOp<UnrealizedConversionCastOp>(); |
804 | // Expecting an UnrealizedConversionCastOp. |
805 | if (!castOp) |
806 | return SmallVector<Value>(); |
807 | // Fill a vector with elements of the tuple/castOp. |
808 | SmallVector<Value> results; |
809 | for (auto oprnd : castOp.getInputs()) { |
810 | if (!isa<RankedTensorType>(oprnd.getType())) |
811 | return SmallVector<Value>(); |
812 | results.emplace_back(oprnd); |
813 | } |
814 | return results; |
815 | }); |
816 | |
    // No mesh dialect ops should be left after conversion...
818 | target.addIllegalDialect<mesh::MeshDialect>(); |
819 | // ...except the global MeshOp |
820 | target.addLegalOp<mesh::MeshOp>(); |
821 | // Allow all the stuff that our patterns will convert to |
822 | target.addLegalDialect<BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, |
823 | arith::ArithDialect, tensor::TensorDialect, |
824 | bufferization::BufferizationDialect, |
825 | linalg::LinalgDialect, memref::MemRefDialect>(); |
826 | // Make sure the function signature, calls etc. are legal |
827 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
828 | return typeConverter.isSignatureLegal(op.getFunctionType()); |
829 | }); |
830 | target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>( |
831 | [&](Operation *op) { return typeConverter.isLegal(op); }); |
832 | |
833 | patterns.add<ConvertUpdateHaloOp, ConvertNeighborsLinearIndicesOp, |
834 | ConvertProcessMultiIndexOp, ConvertGetShardingOp, |
835 | ConvertShardingOp, ConvertShardShapeOp>(typeConverter, ctxt); |
836 | // ConvertProcessLinearIndexOp accepts an optional worldRank |
837 | patterns.add<ConvertProcessLinearIndexOp>(typeConverter, ctxt, worldRank); |
838 | |
839 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>( |
840 | patterns, typeConverter); |
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);
843 | |
844 | (void)applyPartialConversion(getOperation(), target, std::move(patterns)); |
845 | } |
846 | }; |
847 | |
848 | } // namespace |
849 | |