//===- Transforms.cpp ---------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
10#include "TransformsDetail.h"
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/Affine/Utils.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
18#include "mlir/Dialect/Mesh/IR/MeshOps.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/DialectRegistry.h"
23#include "mlir/IR/ImplicitLocOpBuilder.h"
24#include "mlir/IR/OpDefinition.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/Value.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/SmallVector.h"
29#include <iterator>
30#include <numeric>
31
32namespace mlir::mesh {
33
34namespace {
35
36/// Lower `mesh.process_multi_index` into expression using
37/// `mesh.process_linear_index` and `mesh.mesh_shape`.
38struct ProcessMultiIndexOpLowering
39 : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
40 using OpRewritePatternWithSymbolTableCollection::
41 OpRewritePatternWithSymbolTableCollection;
42
43 LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
44 PatternRewriter &rewriter) const override {
45 MeshOp mesh = getMesh(op, symbolTableCollection);
46 if (!mesh) {
47 return failure();
48 }
49
50 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
51 builder.setInsertionPointAfter(op.getOperation());
52 Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
53 ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
54 SmallVector<Value> completeMultiIndex =
55 builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
56 .getMultiIndex();
57 SmallVector<Value> multiIndex;
58 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
59 SmallVector<MeshAxis> opAxesIota;
60 if (opMeshAxes.empty()) {
61 opAxesIota.resize(mesh.getRank());
62 std::iota(first: opAxesIota.begin(), last: opAxesIota.end(), value: 0);
63 opMeshAxes = opAxesIota;
64 }
65 llvm::transform(Range&: opMeshAxes, d_first: std::back_inserter(x&: multiIndex),
66 F: [&completeMultiIndex](MeshAxis meshAxis) {
67 return completeMultiIndex[meshAxis];
68 });
69 rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
70 return success();
71 }
72};
73
74struct AllSliceOpLowering
75 : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
76 using OpRewritePatternWithSymbolTableCollection::
77 OpRewritePatternWithSymbolTableCollection;
78
79 LogicalResult matchAndRewrite(AllSliceOp op,
80 PatternRewriter &rewriter) const override {
81 // 1. Compute the process linear index inside the process group from its
82 // multi-index.
83 //
84 // 2. Extract a slice from the input tensor.
85 // All axes except the slicing axis are not interesting and take the full
86 // axis.
87 // The slice axis is split into equisized parts with count
88 // the number of processes in the collective process group induced by
89 // the mesh axes.
90 // The part for each process is determined by the corresponding
91 // linear-index in the process group.
92 //
93 // There are no collectives that require communication.
94 // Each process operates on its local tensor.
95
96 MeshOp mesh = getMesh(op, symbolTableCollection);
97 if (!mesh) {
98 return failure();
99 }
100
101 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102 builder.setInsertionPointAfter(op.getOperation());
103
104 Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
105
106 Operation::result_range processInGroupMultiIndex =
107 builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
108 .getResults();
109
110 Operation::result_range processGroupShape =
111 builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
112 .getResult();
113 Value processGroupSize =
114 createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
115
116 int64_t sliceAxis = op.getSliceAxis().getSExtValue();
117 Value operandSliceAxisSize =
118 builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
119 Value operandSliceAxisSizeModProcessGroupSize =
120 builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
121 Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
122 arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
123 zero);
124 builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
125 "Slicing a tensor with axis size that is "
126 "not exactly divisible by the "
127 "mesh process group size is not supported.");
128 Value resultSliceAxisSize =
129 builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
130 OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
131 multiIndex: llvm::to_vector_of<OpFoldResult>(Range&: processInGroupMultiIndex),
132 basis: llvm::to_vector_of<OpFoldResult>(Range&: processGroupShape), builder);
133
134 // insert tensor.extract_slice
135 RankedTensorType operandType =
136 cast<RankedTensorType>(op.getOperand().getType());
137 SmallVector<OpFoldResult> sizes;
138 for (int64_t i = 0; i < operandType.getRank(); ++i) {
139 if (i == sliceAxis) {
140 sizes.emplace_back(Args&: resultSliceAxisSize);
141 } else {
142 Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
143 sizes.emplace_back(Args&: dimSize);
144 }
145 }
146 SmallVector<OpFoldResult> offsets(
147 operandType.getRank(), getAsIndexOpFoldResult(ctx: builder.getContext(), val: 0));
148 offsets[sliceAxis] =
149 ArithBuilder(builder, builder.getLoc())
150 .mul(lhs: getValueOrCreateConstantIndexOp(b&: builder, loc: builder.getLoc(),
151 ofr: processInGroupLinearIndex),
152 rhs: resultSliceAxisSize);
153 SmallVector<OpFoldResult> strides(
154 operandType.getRank(), getAsIndexOpFoldResult(ctx: builder.getContext(), val: 1));
155 Value slice = builder.create<tensor::ExtractSliceOp>(
156 op.getOperand(), offsets, sizes, strides);
157 Value newResult =
158 builder.create<tensor::CastOp>(op.getResult().getType(), slice);
159 rewriter.replaceAllUsesWith(op.getResult(), newResult);
160
161 return success();
162 }
163};
164
165} // namespace
166
167void populateProcessMultiIndexOpLoweringPatterns(
168 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
169 patterns.add<ProcessMultiIndexOpLowering>(arg&: symbolTableCollection,
170 args: patterns.getContext());
171}
172
173void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
174 registry.insert<affine::AffineDialect, mesh::MeshDialect>();
175}
176
177void populateAllSliceOpLoweringPatterns(
178 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
179 patterns.add<AllSliceOpLowering>(arg&: symbolTableCollection,
180 args: patterns.getContext());
181}
182
183void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
184 registry.insert<affine::AffineDialect, arith::ArithDialect,
185 cf::ControlFlowDialect, mesh::MeshDialect,
186 tensor::TensorDialect>();
187}
188
189void populateAllOpLoweringPatterns(
190 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
191 populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
192 populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
193}
194
195void registerAllOpLoweringDialects(DialectRegistry &registry) {
196 registerProcessMultiIndexOpLoweringDialects(registry);
197 registerAllSliceOpLoweringDialects(registry);
198}
199
200TypedValue<IndexType>
201createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
202 ImplicitLocOpBuilder &builder) {
203 Operation::result_range meshShape =
204 builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
205 return cast<TypedValue<IndexType>>(Val: arith::createProduct(
206 builder, builder.getLoc(), llvm::to_vector_of<Value>(Range&: meshShape),
207 builder.getIndexType()));
208}
209
210TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
211 ArrayRef<MeshAxis> meshAxes,
212 ImplicitLocOpBuilder &builder) {
213 ResultRange processInGroupMultiIndex =
214 builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
215 Operation::result_range processGroupShape =
216 builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
217 OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
218 multiIndex: llvm::to_vector_of<OpFoldResult>(Range&: processInGroupMultiIndex),
219 basis: llvm::to_vector_of<OpFoldResult>(Range&: processGroupShape), builder);
220 return cast<TypedValue<IndexType>>(Val: processInGroupLinearIndex.get<Value>());
221}
222
223} // namespace mlir::mesh
224

// Source: mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp