//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include <numeric>
#include <optional>
using namespace mlir;

std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  if (sourceType.getRank() > targetType.getRank())
    return getReassociationIndicesForCollapse(sourceType.getShape(),
                                              targetType.getShape());
  if (sourceType.getRank() < targetType.getRank())
    return getReassociationIndicesForCollapse(targetType.getShape(),
                                              sourceType.getShape());
  return std::nullopt;
}

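// Worked example: collapsing sourceShape = [2, 3, 4, 5] into targetShape =
// [6, 20] greedily folds source dims until the running product matches each
// target dim, yielding [[0, 1], [2, 3]]. A trailing run of unit source dims
// is appended to the last group, so sourceShape = [2, 3, 1, 1] with
// targetShape = [6] yields [[0, 1, 2, 3]].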
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                         ArrayRef<int64_t> targetShape) {
  if (sourceShape.size() <= targetShape.size())
    return std::nullopt;
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetShape.size());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    unsigned targetDim = reassociationMap.size();
    // If we have mapped all the target dimensions, stop and handle the
    // remaining tail of size-1 dimensions explicitly.
    if (targetDim == targetShape.size())
      break;

    int64_t currTargetShape = targetShape[targetDim];
    while (sourceDim < sourceShape.size() &&
           sourceShape[sourceDim] != ShapedType::kDynamic &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // If the current expanded dimension is dynamic, then the collapsed
    // dimension should also be dynamic, and the product of the expanded
    // dimensions already folded into the current group should be 1.
    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
      return std::nullopt;

    // If the collapsed dim is dynamic, the current expanded dim should also
    // be dynamic.
    if (currTargetShape == ShapedType::kDynamic &&
        sourceShape[sourceDim] != ShapedType::kDynamic)
      return std::nullopt;

    // For static shapes, the product of the dimensions of the expanded shape
    // should match the collapsed dimension shape.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return std::nullopt;

    currIndices.push_back(sourceDim++);
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // All the dimensions in the target must have been processed.
  if (reassociationMap.size() != targetShape.size())
    return std::nullopt;
  // Process any remaining entries in the source shape. They all need to be
  // 1 or dynamic.
  for (; sourceDim < sourceShape.size(); sourceDim++) {
    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
        sourceShape[sourceDim] != 1)
      return std::nullopt;
    // The map is empty when the target type is a scalar.
    if (!reassociationMap.empty())
      reassociationMap.back().push_back(sourceDim);
  }
  return reassociationMap;
}

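// For example, composing a producer reassociation [[0, 1], [2]]
// (rank 3 -> rank 2) with a consumer reassociation [[0, 1]]
// (rank 2 -> rank 1) concatenates the producer groups selected by each
// consumer group, yielding [[0, 1, 2]].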
std::optional<SmallVector<ReassociationIndices>>
mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of same size, the
  // resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return std::nullopt;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  size_t consumerDims = std::accumulate(
      consumerReassociations.begin(), consumerReassociations.end(), 0,
      [](size_t all, ReassociationIndicesRef indices) {
        return all + indices.size();
      });
  if (producerReassociations.size() != consumerDims)
    return std::nullopt;

  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      llvm::append_range(reassociations, producerReassociations[consumerIndex]);
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}

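// For example, the reassociation indices [[0, 1], [2]] become the affine
// expressions [[d0, d1], [d2]].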
SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
  for (const auto &indices : reassociationIndices) {
    SmallVector<AffineExpr, 2> reassociationMap;
    reassociationMap.reserve(indices.size());
    for (int64_t index : indices)
      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
    reassociationMaps.push_back(std::move(reassociationMap));
  }
  return reassociationMaps;
}

template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned pos = 0;
  for (const auto &exprs : exprArrays) {
    for (auto expr : exprs) {
      expr.walk([&pos](AffineExpr e) {
        if (auto d = dyn_cast<AffineExprTy>(e))
          pos = std::max(pos, d.getPosition());
      });
    }
  }
  return pos;
}

ArrayAttr mlir::getReassociationIndicesAttribute(
    OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr =
      llvm::to_vector<4>(llvm::map_range(
          reassociation, [&](const ReassociationIndices &indices) -> Attribute {
            return cast<Attribute>(b.getI64ArrayAttr(indices));
          }));
  return b.getArrayAttr(reassociationAttr);
}

SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> reassociationIndices;
  for (const auto &exprs : reassociationExprs) {
    ReassociationIndices indices;
    indices.reserve(exprs.size());
    for (const auto &expr : exprs)
      indices.push_back(cast<AffineDimExpr>(expr).getPosition());
    reassociationIndices.push_back(indices);
  }
  return reassociationIndices;
}

SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");
  SmallVector<AffineMap, 4> maps;
  maps.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    assert(!exprs.empty());
    maps.push_back(
        AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
  }
  return maps;
}

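// A reassociation is valid iff, across all maps in order, the results are
// exactly d0, d1, ..., d_{nDims-1} with no symbols. For example, with
// nDims = 3, the maps [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)]
// form a valid reassociation.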
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;
  unsigned nDims = reassociation[0].getNumDims();
  unsigned nextExpectedDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    auto m = it.value();
    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
      if (invalidIndex)
        *invalidIndex = it.index();
      return false;
    }
    for (auto e : m.getResults()) {
      auto d = dyn_cast<AffineDimExpr>(e);
      if (!d || d.getPosition() != nextExpectedDim++) {
        if (invalidIndex)
          *invalidIndex = it.index();
        return false;
      }
    }
  }
  if (nextExpectedDim != nDims) {
    if (invalidIndex)
      *invalidIndex = reassociation.size() - 1;
    return false;
  }
  return true;
}

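// For example, collapsedShape = [6, ?] is compatible with expandedShape =
// [2, 3, 4, ?] under reassociation [[0, 1], [2, 3]]: the static group
// {2, 3} multiplies out to 6, and the group holding the dynamic expanded
// dim maps to a dynamic collapsed dim. When expanding, two dynamic dims in
// a single group are rejected because the split would be ambiguous.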
LogicalResult mlir::reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
  unsigned expandedDimStart = 0;
  for (const auto &map : llvm::enumerate(reassociationMaps)) {
    std::optional<int64_t> dynamicShape;
    int64_t linearizedStaticShape = 1;
    for (const auto &dim : llvm::enumerate(
             expandedShape.slice(expandedDimStart, map.value().size()))) {
      if (ShapedType::isDynamic(dim.value())) {
        if (isExpandingReshape && dynamicShape) {
          return emitError("invalid to have a single dimension (" +
                           Twine(map.index()) +
                           ") expanded into multiple dynamic dims (" +
                           Twine(expandedDimStart + dynamicShape.value()) +
                           "," + Twine(expandedDimStart + dim.index()) + ")");
        }
        dynamicShape = dim.index();
      } else {
        linearizedStaticShape *= dim.value();
      }
    }
    if (dynamicShape) {
      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
        return emitError(
            "expected dimension " + Twine(map.index()) +
            " of collapsed type to be dynamic since one or more of the "
            "corresponding dimensions in the expanded type is dynamic");
      }
    } else {
      if (collapsedShape[map.index()] != linearizedStaticShape) {
        return emitError("expected dimension " + Twine(map.index()) +
                         " of collapsed type to be static value of " +
                         Twine(linearizedStaticShape));
      }
    }
    expandedDimStart += map.value().size();
  }
  return success();
}

bool mlir::hasNonIdentityLayout(Type type) {
  if (auto memrefType = dyn_cast<MemRefType>(type))
    return !memrefType.getLayout().isIdentity();
  return false;
}

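// A dimension counts as sliced unless its offset is statically 0, its
// stride is statically 1, and its size equals the full input extent. For
// example, with sliceInputShape = [10, 20] and sliceParams
// [(0, 10, 1), (2, 5, 1)], only the second dimension is marked as sliced.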
llvm::SmallBitVector
mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                          ArrayRef<Range> sliceParams) {
  assert(sliceParams.size() == sliceInputShape.size() &&
         "only supports non rank-reducing case");
  llvm::SmallBitVector mask(sliceInputShape.size());
  unsigned idx = 0;
  for (const auto &[offset, size, stride] : sliceParams) {
    std::optional<int64_t> offsetConst = getConstantIntValue(offset);
    std::optional<int64_t> strideConst = getConstantIntValue(stride);
    mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
                (!strideConst || *strideConst != 1) ||
                (!offsetConst || *offsetConst != 0);
    idx++;
  }
  return mask;
}

llvm::SmallBitVector mlir::getLinearizedDimensions(
    ArrayRef<ReassociationIndices> reassociationIndices) {
  llvm::SmallBitVector result(reassociationIndices.size());
  for (const auto &it : llvm::enumerate(reassociationIndices))
    result[it.index()] = it.value().size() > 1;
  return result;
}

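// Computes the (offset, size, stride) triples, per source dimension, for
// extracting a tile through a collapse_shape. For example, with
// reassociation [[0, 1], [2]] where the linearized group [0, 1] is also
// sliced, case 1 below emits one (multi-index, 1, 1) range per source dim
// of that group, while the remaining group falls through to case 2 or 3.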
SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
    MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
  unsigned loopIdx = 0;
  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> offsetsSizesAndStrides;
  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
  for (const auto &it : llvm::enumerate(reassociationIndices)) {
    // Case 1: Linearized dimensions that have also been sliced. These have
    // size 1 because we are iterating over these dimensions. The offsets are
    // exactly the de-linearized multi-indices.
    if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
      llvm::append_range(
          offsetsSizesAndStrides,
          llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
            return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
          }));
      continue;
    }

    // Case 2: One or possibly multiple combined input dimensions, but we
    // have proven that these are not sliced. In this case we just take
    // the full extent of each dimension in the reassociation list.
    if (linearizedDimensions[it.index()]) {
      llvm::append_range(
          offsetsSizesAndStrides,
          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
          }));
      continue;
    }

    // Case 3: A single index, but it may be sliced.
    offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
  }
  return offsetsSizesAndStrides;
}

SmallVector<Range>
SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
                                              ValueRange tileIndices) {
  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> insertParams;
  insertParams.reserve(linearizedDimensions.size());
  unsigned loopIdx = 0;
  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
    if (linearizedDimensions[i] && slicedDimensions[i]) {
      insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
      continue;
    }
    insertParams.push_back(Range{zero, sliceParams[i].size, one});
  }
  return insertParams;
}

/// Returns the index of the only non-unit dimension among `indices` of
/// `shape`, if such a dimension exists and `indices` has more than one
/// element. Otherwise, returns std::nullopt.
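/// For example, with shape = [1, 5, 1] and indices = [0, 1, 2], this
/// returns 1; with shape = [4, 5, 1] and the same indices, it returns
/// std::nullopt.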
static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
                                                  ArrayRef<int64_t> shape) {
  // Return std::nullopt if more than one of the dimensions in this group
  // is not 1.
  std::optional<int64_t> dimIndex;
  if (indices.size() < 2)
    return std::nullopt;
  for (int64_t idx : indices) {
    if (shape[idx] != 1) {
      if (dimIndex != std::nullopt)
        return std::nullopt;
      dimIndex = idx;
    }
  }
  return dimIndex;
}

// For each segment in the reassociation indices, check whether we can
// simplify that segment with a rank-reducing extract slice. We can do this if
// all but (exactly) one of the corresponding source dims is 1.
static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments;
  for (const auto &indices : reassociationIndices)
    trivialSegments.push_back(
        getUniqueNonUnitDim(indices, sourceType.getShape()));
  return trivialSegments;
}

/// Returns the trivial segments if any of the segments of the reassociation
/// indices for a collapsing reshape can be simplified using a rank-reducing
/// slice; otherwise, returns failure.
static FailureOr<SmallVector<std::optional<int64_t>>>
canCollapseShapeBeSimplifiedByRankReducingSlice(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments =
      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
  if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
        return idx.has_value();
      }))
    return failure();
  return trivialSegments;
}

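// Worked example: collapsing tensor<1x5x1x4x3xf32> with reassociation
// [[0, 1, 2], [3, 4]] has one trivial segment (the first group, whose only
// non-unit dim has size 5). The rank-reducing slice produces
// tensor<5x4x3xf32>, and the residual collapse_shape uses the new
// reassociation [[0], [1, 2]].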
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
                                                      reassociationIndices);
  if (failed(trivialSegments))
    return failure();

  // Create the expected result shape of the rank-reducing slice.
  SmallVector<int64_t> sliceShape;
  for (const auto &[nonUnitDim, indices] :
       llvm::zip(*trivialSegments, reassociationIndices)) {
    if (nonUnitDim) {
      sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
      continue;
    }
    llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
                         return sourceType.getDimSize(idx);
                       }));
  }
  auto sliceType =
      RankedTensorType::get(sliceShape, sourceType.getElementType());

  // If the rank-reducing slice simplified every segment, then we are done.
  if (sliceShape.size() == reassociationIndices.size())
    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
                                                            std::nullopt};

  // Otherwise, we need to create a new collapse_shape op for the segments that
  // weren't covered by the slice. By design, the new reassociation indices have
  // the same number of groups as the old reassociation indices.
  SmallVector<ReassociationIndices> newReassociationIndices;
  SmallVector<int64_t, 2> reassociation;
  int64_t groupIdx = 0;
  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
    reassociation.push_back(dimIdx);
    if ((*trivialSegments)[groupIdx] ||
        reassociation.size() == reassociationIndices[groupIdx].size()) {
      newReassociationIndices.push_back(reassociation);
      reassociation.clear();
      groupIdx++;
    }
  }

  return CollapseShapeRankReducingSliceSimplificationInfo{
      sliceType, newReassociationIndices};
}

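// For the ABCD -> ABCDba example spelled out in the comment below
// (packedRank = 6, innerDimPos = [1, 0]), this computes insertPositions =
// [3, 1], outerPositions = [0, 2, 4, 5], and reassociations =
// [[0, 1], [2, 3], [4], [5]].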
PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
                                             ArrayRef<int64_t> innerDimPos) {
  PackingMetadata res;
  res.insertPositions.reserve(innerDimPos.size());
  // The pack insert position is the position + the number of previously
  // inserted positions + offset.
  // The offset controls whether the packing dimension is the first or last.
  //
  // Example
  // =======
  // Consider packing from a hypothetical ABCD layout to ABCDba whose
  // pack.inner_dims is [1, 0]. The first step consists in undoing the
  // permutation and producing AaBbCD. This is achieved purely by computing the
  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
  // possibility is to produce insert positions [2, 0], which would result in
  // an aAbBCD layout (i.e. offset 0). The other possibility is to produce
  // insert positions [3, 1], which would result in an AaBbCD layout (i.e.
  // offset 1). The latter is what we expect from packing.
  int64_t offset = 1;
  for (int64_t pos : innerDimPos) {
    int64_t numInsertedBefore = llvm::count_if(
        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
    res.insertPositions.push_back(pos + numInsertedBefore + offset);
  }

  DenseSet<int64_t> posSet(res.insertPositions.begin(),
                           res.insertPositions.end());
  res.reassociations.reserve(packedRank);
  for (int64_t i = 1; i <= packedRank; ++i) {
    res.outerPositions.push_back(i - 1);
    if (!posSet.contains(i)) {
      res.reassociations.push_back(ReassociationIndices{i - 1});
      continue;
    }
    res.reassociations.push_back(ReassociationIndices{i - 1, i});
    ++i;
  }
  return res;
}