1//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinTypeInterfaces.h"
14#include "llvm/ADT/ArrayRef.h"
15#include "llvm/ADT/SmallVector.h"
16
17#include <numeric>
18#include <optional>
19
20using namespace mlir;
21
22std::optional<SmallVector<ReassociationIndices>>
23mlir::getReassociationIndicesForReshape(ShapedType sourceType,
24 ShapedType targetType) {
25 if (sourceType.getRank() > targetType.getRank())
26 return getReassociationIndicesForCollapse(sourceShape: sourceType.getShape(),
27 targetShape: targetType.getShape());
28 if (sourceType.getRank() < targetType.getRank())
29 return getReassociationIndicesForCollapse(sourceShape: targetType.getShape(),
30 targetShape: sourceType.getShape());
31 return std::nullopt;
32}
33
34namespace {
35/// A simple struct to represent ReassociationIndices as an inclusive interval.
36/// It's designed to be feasibly minimal, so the call sites should manage the
37/// validity of the range manually.
38struct ReassociationIndexRange {
39 /// FIXME: Signed type is used for consistency with ReassociationIndices.
40 /// We should consider refactoring all reassociation utilities to use unsigned
41 /// types.
42 int64_t leftIdx = 0, rightIdx = 0;
43
44 /// Util for manual checks of the range's validity
45 LogicalResult verify() const {
46 return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
47 }
48
49 /// Checks range's containment within another range. Treats the edges
50 /// non-exclusively.
51 bool isInRange(const ReassociationIndexRange &outerRange) const {
52 return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
53 }
54
55 unsigned size() const {
56 assert(succeeded(verify()));
57 return rightIdx - leftIdx + 1;
58 }
59 bool containsSingleIndex() const { return size() == 1; }
60
61 /// Collects indices that do not overlap between this and another range.
62 ReassociationIndices
63 getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
64 if (rightIdx < rhs.leftIdx) {
65 // The intervals do not overlap - concatenate the indices from both.
66 auto jointFullIndices = getFullIndices();
67 jointFullIndices.append(RHS: rhs.getFullIndices());
68 return jointFullIndices;
69 }
70 ReassociationIndices result;
71 // Handle the chunk left of the overlapping range.
72 int64_t leftStart = std::min(a: leftIdx, b: rhs.leftIdx);
73 int64_t leftEnd = std::max(a: leftIdx, b: rhs.leftIdx);
74 llvm::append_range(C&: result, R: llvm::seq(Begin: leftStart, End: leftEnd));
75 // Handle the chunk right of the overlapping range. Symmetrically, we should
76 // skip the edge of the overlap AND include the rightmost index.
77 int64_t rightStart = std::min(a: rightIdx, b: rhs.rightIdx) + 1;
78 int64_t rightEnd = std::max(a: rightIdx, b: rhs.rightIdx);
79 if (rightStart < rightEnd)
80 llvm::append_range(C&: result, R: llvm::seq_inclusive(Begin: rightStart, End: rightEnd));
81 return result;
82 }
83
84 /// Converts the range into ReassociationIndices.
85 ReassociationIndices getFullIndices() const {
86 ReassociationIndices result;
87 for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
88 result.push_back(Elt: idx);
89 }
90 return result;
91 }
92};
93} // namespace
94
95/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
96/// sequence that can be collapsed into a dynamic dimension (at least one must
97/// be present in the source).
98/// By default, lazily returns once the first dynamic dimension has been found.
99/// Setting `matchGreedily` as `true` will also mark all subsequent
100/// source dimensions for collapsing into the target.
101static FailureOr<ReassociationIndexRange>
102findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
103 int64_t sourceStartIdx,
104 bool matchGreedily = false) {
105 const unsigned numSourceDims = sourceShape.size();
106 ReassociationIndexRange sourceShapeAsRange{.leftIdx: 0, .rightIdx: numSourceDims - 1};
107 std::optional<ReassociationIndexRange> resultRange = std::nullopt;
108
109 ReassociationIndexRange iterationRange{.leftIdx: sourceStartIdx, .rightIdx: sourceStartIdx};
110 for (; iterationRange.isInRange(outerRange: sourceShapeAsRange);
111 iterationRange.rightIdx++) {
112 int64_t sourceSize = sourceShape[iterationRange.rightIdx];
113 if (sourceSize == ShapedType::kDynamic) {
114 resultRange = iterationRange;
115 break;
116 }
117 }
118 if (!resultRange)
119 return failure();
120 if (matchGreedily)
121 resultRange->rightIdx = sourceShapeAsRange.rightIdx;
122 return *resultRange;
123}
124
125/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
126/// sequence of static dimensions such that their product matches `targetSize`.
127/// By default, lazily returns once the product matches the target size. Setting
128/// `matchGreedily` as `true` will append all neighboring unit dimensions
129/// (dimensions of 1) to the match.
130static FailureOr<ReassociationIndexRange>
131findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
132 int64_t sourceStartIdx, int64_t targetSize,
133 bool matchGreedily = false) {
134 const unsigned numSourceDims = sourceShape.size();
135 ReassociationIndexRange sourceShapeAsRange{.leftIdx: 0, .rightIdx: numSourceDims - 1};
136 std::optional<ReassociationIndexRange> resultRange = std::nullopt;
137
138 ReassociationIndexRange iterationRange{.leftIdx: sourceStartIdx, .rightIdx: sourceStartIdx};
139 int64_t prodOfCollapsedDims = 1;
140 while (iterationRange.isInRange(outerRange: sourceShapeAsRange)) {
141 int64_t sourceSize = sourceShape[iterationRange.rightIdx];
142 if (sourceSize == ShapedType::kDynamic) {
143 // Reassociation for a static dim cannot include a dynamic dim. Reset
144 // induction variables to essentially restart the loop from the next
145 // source dimension.
146 prodOfCollapsedDims = 1;
147 iterationRange = {.leftIdx: iterationRange.rightIdx + 1,
148 .rightIdx: iterationRange.rightIdx + 1};
149 continue;
150 }
151 prodOfCollapsedDims *= sourceSize;
152 // If the target size has been exceeded without matching, we need to shift
153 // the range start right. From the start of the range, roll back the
154 // multiplication until the target size exceeds the product again.
155 while (prodOfCollapsedDims > targetSize &&
156 !iterationRange.containsSingleIndex()) {
157 int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
158 prodOfCollapsedDims /= frontSourceSize;
159 // Shrink the range rightwards
160 iterationRange.leftIdx++;
161 }
162 // We could've reached the target size with the current dimension,
163 // also as a result of the above shift to right.
164 if (prodOfCollapsedDims == targetSize) {
165 resultRange = iterationRange;
166 break;
167 }
168 // Increment the iteration range
169 iterationRange.rightIdx++;
170 }
171 if (!resultRange)
172 return failure();
173 if (matchGreedily) {
174 // We now want to collect all unit dimensions directly after the target
175 // product match. Advance the iterator to avoid OOB when the product match
176 // happens at the last element.
177 iterationRange.rightIdx++;
178 while (iterationRange.isInRange(outerRange: sourceShapeAsRange) &&
179 sourceShape[iterationRange.rightIdx] == 1) {
180 resultRange = iterationRange;
181 iterationRange.rightIdx++;
182 }
183 }
184 return *resultRange;
185}
186
187/// Attempts to find a valid collapsing reassociation of `sourceShape` into
188/// `targetShape` through a simple traversal. If successful, an array of source
189/// index ranges is returned, correspondingly to each dimension in the target
190/// shape. The resulting indices shall fully cover the `sourceShape` without
191/// overlaps.
192///
193/// The algorithm is essentially a lazy one, searching for non-greedy matches -
194/// it will only yield a greedy match for the last target dimension.
195/// FIXME: The algorithm can only backtrack when it needs to append an offset
196/// for a static target dimension to the preceding dynamic one (this retains the
197/// linear complexity). As feasible, consider adding further backtracking
198/// routines to enable more reassociations, e.g.:
199/// - ?x2x?x2 into ?x2
200static FailureOr<SmallVector<ReassociationIndexRange>>
201findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
202 ArrayRef<int64_t> targetShape) {
203 unsigned numSourceDims = sourceShape.size(),
204 numTargetDims = targetShape.size();
205 assert(numSourceDims > numTargetDims);
206 ReassociationIndexRange sourceShapeAsRange{.leftIdx: 0, .rightIdx: numSourceDims - 1};
207
208 SmallVector<ReassociationIndexRange> reassocRanges;
209 reassocRanges.reserve(N: numTargetDims);
210 // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
211 // cases, e.g.:
212 // - ?x2x3x5 into ?x15
213 std::optional<int64_t> prevTargetSize = std::nullopt;
214 for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
215 targetDimIdx < numTargetDims; ++targetDimIdx) {
216 int64_t targetSize = targetShape[targetDimIdx];
217 // Simply check if there are any subsequent target dimensions left - if not,
218 // the match must be made greedily.
219 bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
220 FailureOr<ReassociationIndexRange> sourceRange;
221 if (targetSize == ShapedType::kDynamic) {
222 sourceRange = findReassociationRangeForDynamicDim(
223 sourceShape, sourceStartIdx: sourceDimIdx, matchGreedily: shouldMatchGreedily);
224 } else {
225 sourceRange = findReassociationRangeForSize(
226 sourceShape, sourceStartIdx: sourceDimIdx, targetSize, matchGreedily: shouldMatchGreedily);
227 }
228
229 // Run sanity checks on the returned index range.
230 if (failed(Result: sourceRange) || failed(Result: sourceRange->verify()) ||
231 !sourceRange->isInRange(outerRange: sourceShapeAsRange))
232 return failure();
233 if (sourceRange->leftIdx > sourceDimIdx) {
234 // If some source dimensions had to be skipped in order to find a match,
235 // they must be collapsed into the directly preceding dynamic dimension.
236 if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
237 return failure();
238 reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
239 }
240
241 // Store the gathered information as required for the next iteration.
242 prevTargetSize = targetSize;
243 sourceDimIdx = sourceRange->rightIdx + 1;
244 reassocRanges.push_back(Elt: *sourceRange);
245 }
246 // Fail if the source shape wasn't a full match for the target shape. We only
247 // need to check the last recorded index - any other gaps should have been
248 // mended by the main loop.
249 if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
250 return failure();
251 return reassocRanges;
252}
253
254/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
255/// the shapes right-to-left.
256static FailureOr<SmallVector<ReassociationIndexRange>>
257findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
258 ArrayRef<int64_t> targetShape,
259 bool iterateRightToLeft) {
260 if (!iterateRightToLeft)
261 return findReassociationRangesForCollapse(sourceShape, targetShape);
262 // NB: To iterate right-to-left, we currently reverse the shapes and then
263 // reverse the result back. The reversed shapes must not be temporary, as
264 // we're passing through an ArrayRef.
265 // FIXME: It would be preferable to avoid the expensive copies. At the moment,
266 // this approach is chosen for readability of the main implementation.
267 std::vector<int64_t> sourceToReverse = sourceShape.vec(),
268 targetToReverse = targetShape.vec();
269 std::reverse(first: sourceToReverse.begin(), last: sourceToReverse.end());
270 std::reverse(first: targetToReverse.begin(), last: targetToReverse.end());
271 auto invertedRanges =
272 findReassociationRangesForCollapse(sourceShape: sourceToReverse, targetShape: targetToReverse);
273 if (failed(Result: invertedRanges))
274 return failure();
275 SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
276 unsigned numSourceDims = sourceShape.size();
277 // We have received the ranges for inverted shapes. Now we have to invert
278 // the ranges back to correspond with the original source shape.
279 for (auto &range : rangesToInvert) {
280 int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
281 range.leftIdx = numSourceDims - 1 - invRightIdx;
282 range.rightIdx = numSourceDims - 1 - invLeftIdx;
283 }
284 // Also invert the ordering of the ranges to correspond with the original
285 // target shape.
286 std::reverse(first: rangesToInvert.begin(), last: rangesToInvert.end());
287 return rangesToInvert;
288}
289
290std::optional<SmallVector<ReassociationIndices>>
291mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
292 ArrayRef<int64_t> targetShape) {
293 unsigned numSourceDims = sourceShape.size(),
294 numTargetDims = targetShape.size();
295 // We're supposed to search for a collapsing reassociation. If the sizes
296 // match, there's no actual collapsing taking place - it's either a no-op or a
297 // `tensor.reshape`-style reassociation (that would be beyond the scope of
298 // this utility).
299 if (numSourceDims <= numTargetDims)
300 return std::nullopt;
301 // Early handling for scalar target types. We should report an invalid
302 // reassociation for non-unit static dimensions - no chance to collapse these
303 // into a scalar.
304 if (numTargetDims == 0) {
305 for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
306 ++sourceDimIdx) {
307 int64_t sourceSize = sourceShape[sourceDimIdx];
308 if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
309 return std::nullopt;
310 }
311 return SmallVector<ReassociationIndices>{};
312 }
313
314 // Collect source ranges by iterating over the target shape left-to-right.
315 FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
316 findReassociationRangesForCollapse(sourceShape, targetShape);
317 if (failed(Result: maybeForwardRanges))
318 return std::nullopt;
319 auto &ranges = *maybeForwardRanges;
320 // Now do the same in reverse. We need to get another valid reassociation
321 // through some other strategy, and then compare the results in order to
322 // disambiguate mixed subshapes, such as:
323 // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
324 // This leads us to lose some of the reassociation opportunities that can only
325 // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
326 // backtracking, the algorithm will fail right-to-left. However, this is the
327 // best way to preserve correctness.
328 FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
329 findReassociationRangesForCollapse(sourceShape, targetShape,
330 /*iterateRightToLeft=*/true);
331 if (failed(Result: maybeReverseRanges))
332 return std::nullopt;
333 auto &reverseRanges = *maybeReverseRanges;
334
335 if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
336 return std::nullopt;
337 // Now we can check for ambiguity of each target dimension's reassociation. If
338 // successful, we put the full indices into our result map for the target
339 // shape.
340 SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
341 for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
342 ++targetDimIdx) {
343 ReassociationIndexRange &range = ranges[targetDimIdx];
344 ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
345 // Get non-overlapping indices between the ranges
346 ReassociationIndices nonMatchingIndices =
347 range.getNonOverlappingIndicesWith(rhs&: reverseRange);
348 // Unit dimensions can be collapsed wherever - this is the only ambiguity
349 // that we allow.
350 for (int64_t sourceDimIdx : nonMatchingIndices) {
351 if (sourceShape[sourceDimIdx] != 1)
352 return std::nullopt;
353 }
354 reassociationMap[targetDimIdx] = range.getFullIndices();
355 }
356 return reassociationMap;
357}
358
359std::optional<SmallVector<ReassociationIndices>>
360mlir::composeReassociationIndices(
361 ArrayRef<ReassociationIndices> producerReassociations,
362 ArrayRef<ReassociationIndices> consumerReassociations,
363 MLIRContext *context) {
364 SmallVector<ReassociationIndices> composedIndices;
365 // Make the producer the larger sized vector. If they are of same size, the
366 // resulting reshape is not a supported reshape op.
367 if (producerReassociations.size() == consumerReassociations.size())
368 return std::nullopt;
369 if (producerReassociations.size() < consumerReassociations.size())
370 std::swap(a&: producerReassociations, b&: consumerReassociations);
371
372 // Handle the corner case of the result being a rank 0 shaped type. Return an
373 // empty reassociation.
374 if (consumerReassociations.empty())
375 return composedIndices;
376
377 size_t consumerDims = std::accumulate(
378 first: consumerReassociations.begin(), last: consumerReassociations.end(), init: 0,
379 binary_op: [](size_t all, ReassociationIndicesRef indices) {
380 return all + indices.size();
381 });
382 if (producerReassociations.size() != consumerDims)
383 return std::nullopt;
384
385 for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
386 ReassociationIndices reassociations;
387 for (int64_t consumerIndex : consumerIndices) {
388 llvm::append_range(C&: reassociations, R: producerReassociations[consumerIndex]);
389 }
390 composedIndices.push_back(Elt: std::move(reassociations));
391 }
392 return composedIndices;
393}
394
395SmallVector<SmallVector<AffineExpr, 2>, 2>
396mlir::convertReassociationIndicesToExprs(
397 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
398 SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
399 for (const auto &indices : reassociationIndices) {
400 SmallVector<AffineExpr, 2> reassociationMap;
401 reassociationMap.reserve(N: indices.size());
402 for (int64_t index : indices)
403 reassociationMap.push_back(Elt: mlir::getAffineDimExpr(position: index, context));
404 reassociationMaps.push_back(Elt: std::move(reassociationMap));
405 }
406 return reassociationMaps;
407}
408
409template <typename AffineExprTy>
410unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
411 unsigned pos = 0;
412 for (const auto &exprs : exprArrays) {
413 for (auto expr : exprs) {
414 expr.walk([&pos](AffineExpr e) {
415 if (auto d = dyn_cast<AffineExprTy>(e))
416 pos = std::max(pos, d.getPosition());
417 });
418 }
419 }
420 return pos;
421}
422
423ArrayAttr mlir::getReassociationIndicesAttribute(
424 Builder &b, ArrayRef<ReassociationIndices> reassociation) {
425 SmallVector<Attribute, 4> reassociationAttr =
426 llvm::to_vector<4>(Range: llvm::map_range(
427 C&: reassociation, F: [&](const ReassociationIndices &indices) -> Attribute {
428 return cast<Attribute>(Val: b.getI64ArrayAttr(values: indices));
429 }));
430 return b.getArrayAttr(value: reassociationAttr);
431}
432
433SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
434 ArrayRef<ReassociationExprs> reassociationExprs) {
435 SmallVector<ReassociationIndices, 2> reassociationIndices;
436 for (const auto &exprs : reassociationExprs) {
437 ReassociationIndices indices;
438 indices.reserve(N: exprs.size());
439 for (const auto &expr : exprs)
440 indices.push_back(Elt: cast<AffineDimExpr>(Val: expr).getPosition());
441 reassociationIndices.push_back(Elt: indices);
442 }
443 return reassociationIndices;
444}
445
446SmallVector<AffineMap, 4>
447mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
448 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(exprArrays: reassociation);
449 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
450 "Expected symbol-less expressions");
451 SmallVector<AffineMap, 4> maps;
452 maps.reserve(N: reassociation.size());
453 for (const auto &exprs : reassociation) {
454 assert(!exprs.empty());
455 maps.push_back(Elt: AffineMap::get(dimCount: maxDim + 1, symbolCount: 0, results: exprs, context: exprs[0].getContext()));
456 }
457 return maps;
458}
459
460bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
461 int *invalidIndex) {
462 if (reassociation.empty())
463 return true;
464 unsigned nDims = reassociation[0].getNumDims();
465 unsigned nextExpectedDim = 0;
466 for (const auto &it : llvm::enumerate(First&: reassociation)) {
467 auto m = it.value();
468 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
469 if (invalidIndex)
470 *invalidIndex = it.index();
471 return false;
472 }
473 for (auto e : m.getResults()) {
474 auto d = dyn_cast<AffineDimExpr>(Val&: e);
475 if (!d || d.getPosition() != nextExpectedDim++) {
476 if (invalidIndex)
477 *invalidIndex = it.index();
478 return false;
479 }
480 }
481 }
482 if (nextExpectedDim != nDims) {
483 if (invalidIndex)
484 *invalidIndex = reassociation.size() - 1;
485 return false;
486 }
487 return true;
488}
489
490LogicalResult mlir::reshapeLikeShapesAreCompatible(
491 function_ref<LogicalResult(const Twine &)> emitError,
492 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
493 ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
494 unsigned expandedDimStart = 0;
495 for (const auto &map : llvm::enumerate(First&: reassociationMaps)) {
496 bool foundDynamicShape = false;
497 int64_t linearizedStaticShape = 1;
498
499 for (const auto &dim : llvm::enumerate(
500 First: expandedShape.slice(N: expandedDimStart, M: map.value().size()))) {
501 if (ShapedType::isDynamic(dValue: dim.value()))
502 foundDynamicShape = true;
503 else
504 linearizedStaticShape *= dim.value();
505 }
506 if (foundDynamicShape) {
507 if (ShapedType::isStatic(dValue: collapsedShape[map.index()])) {
508 return emitError(
509 "expected dimension " + Twine(map.index()) +
510 " of collapsed type to be dynamic since one or more of the "
511 "corresponding dimensions in the expanded type is dynamic");
512 }
513 } else {
514 if (collapsedShape[map.index()] != linearizedStaticShape) {
515 return emitError("expected dimension " + Twine(map.index()) +
516 " of collapsed type to be static value of " +
517 Twine(linearizedStaticShape));
518 }
519 }
520 expandedDimStart += map.value().size();
521 }
522 return success();
523}
524
525bool mlir::hasNonIdentityLayout(Type type) {
526 if (auto memrefType = dyn_cast<MemRefType>(Val&: type))
527 return !memrefType.getLayout().isIdentity();
528 return false;
529}
530
531llvm::SmallBitVector
532mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
533 ArrayRef<Range> sliceParams) {
534 assert(sliceParams.size() == sliceInputShape.size() &&
535 "only supports non rank-reducing case");
536 llvm::SmallBitVector mask(sliceInputShape.size());
537 unsigned idx = 0;
538 for (const auto &[offset, size, stride] : sliceParams) {
539 std::optional<int64_t> offsetConst = getConstantIntValue(ofr: offset);
540 std::optional<int64_t> strideConst = getConstantIntValue(ofr: stride);
541 mask[idx] = !isEqualConstantIntOrValue(ofr1: size, ofr2: sliceInputShape[idx]) ||
542 (!strideConst || *strideConst != 1) ||
543 (!offsetConst || *offsetConst != 0);
544 idx++;
545 }
546 return mask;
547}
548
549llvm::SmallBitVector mlir::getLinearizedDimensions(
550 ArrayRef<ReassociationIndices> reassociationIndices) {
551 llvm::SmallBitVector result(reassociationIndices.size());
552 for (const auto &it : llvm::enumerate(First&: reassociationIndices))
553 result[it.index()] = it.value().size() > 1;
554 return result;
555}
556
557SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
558 MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
559 unsigned loopIdx = 0;
560 auto oneAttr = IntegerAttr::get(type: IndexType::get(context: ctx), value: 1);
561 auto zeroAttr = IntegerAttr::get(type: IndexType::get(context: ctx), value: 0);
562 SmallVector<Range> offsetsSizesAndStrides;
563 offsetsSizesAndStrides.reserve(N: collapseShapeInputShape.size());
564 for (const auto &it : llvm::enumerate(First&: reassociationIndices)) {
565 // Case 1: Linearized dimensions that have also been sliced. These
566 // are size of 1 because we are iterating over these dimensions. The
567 // offsets are exactly the de-linearized multi-indices.
568 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
569 llvm::append_range(
570 C&: offsetsSizesAndStrides,
571 R: llvm::map_range(C: multiIndices[loopIdx++], F: [&](Value v) -> Range {
572 return Range{.offset: getAsOpFoldResult(val: v), .size: oneAttr, .stride: oneAttr};
573 }));
574 continue;
575 }
576
577 // Case 2: One or possibly multiple combined input dimensions, but we
578 // have proven that these are not sliced. In this case we just take
579 // the full extent of each dimension in the reassociation list.
580 if (linearizedDimensions[it.index()]) {
581 llvm::append_range(C&: offsetsSizesAndStrides,
582 R: llvm::map_range(C&: it.value(), F: [&](int64_t idx) -> Range {
583 return {.offset: zeroAttr, .size: collapseShapeInputShape[idx],
584 .stride: oneAttr};
585 }));
586 continue;
587 }
588
589 // Case 3: A single index, but it may be sliced.
590 offsetsSizesAndStrides.push_back(Elt: sliceParams[it.index()]);
591 }
592 return offsetsSizesAndStrides;
593}
594
595SmallVector<Range>
596SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
597 ValueRange tileIndices) {
598 auto one = IntegerAttr::get(type: IndexType::get(context: ctx), value: 1);
599 auto zero = IntegerAttr::get(type: IndexType::get(context: ctx), value: 0);
600 SmallVector<Range> insertParams;
601 insertParams.reserve(N: linearizedDimensions.size());
602 unsigned loopIdx = 0;
603 for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
604 if (linearizedDimensions[i] && slicedDimensions[i]) {
605 insertParams.push_back(Elt: Range{.offset: tileIndices[loopIdx++], .size: one, .stride: one});
606 continue;
607 }
608 insertParams.push_back(Elt: Range{.offset: zero, .size: sliceParams[i].size, .stride: one});
609 }
610 return insertParams;
611}
612
613/// Returns the index of the only non-unit dimension among `indices` of `shape`,
614/// if such a dimension exists and `indices` has more than one element.
615/// Otherwise, return std::nullopt.
616static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
617 ArrayRef<int64_t> shape) {
618 // Return false if more than one of the dimensions in this group are not 1.
619 std::optional<int64_t> dimIndex;
620 if (indices.size() < 2)
621 return std::nullopt;
622 for (int64_t idx : indices) {
623 if (shape[idx] != 1) {
624 if (dimIndex != std::nullopt)
625 return std::nullopt;
626 dimIndex = idx;
627 }
628 }
629 return dimIndex;
630}
631
632// For each segment in the reassociation indices, check whether we can
633// simplify that segment with a rank-reducing extract slice. We can do this if
634// all but (exactly) one of the corresponding source dims is 1.
635static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
636 RankedTensorType sourceType,
637 ArrayRef<ReassociationIndices> reassociationIndices) {
638 SmallVector<std::optional<int64_t>> trivialSegments;
639 for (const auto &indices : reassociationIndices)
640 trivialSegments.push_back(
641 Elt: getUniqueNonUnitDim(indices, shape: sourceType.getShape()));
642 return trivialSegments;
643}
644
645/// Returns true if any of the segments of the reassociation indices for a
646/// collapsing reshape can be simplified using a rank-reducing slice.
647static FailureOr<SmallVector<std::optional<int64_t>>>
648canCollapseShapeBeSimplifiedByRankReducingSlice(
649 RankedTensorType sourceType,
650 ArrayRef<ReassociationIndices> reassociationIndices) {
651 SmallVector<std::optional<int64_t>> trivialSegments =
652 getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
653 if (!llvm::any_of(Range&: trivialSegments, P: [](const std::optional<int64_t> &idx) {
654 return idx.has_value();
655 }))
656 return failure();
657 return trivialSegments;
658}
659
660FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
661mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
662 RankedTensorType sourceType,
663 ArrayRef<ReassociationIndices> reassociationIndices) {
664 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
665 canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
666 reassociationIndices);
667 if (failed(Result: trivialSegments))
668 return failure();
669
670 // Create the expected result shape of the rank-reducing slice.
671 SmallVector<int64_t> sliceShape;
672 for (const auto &[nonUnitDim, indices] :
673 llvm::zip(t&: *trivialSegments, u&: reassociationIndices)) {
674 if (nonUnitDim) {
675 sliceShape.push_back(Elt: sourceType.getDimSize(idx: *nonUnitDim));
676 continue;
677 }
678 llvm::append_range(C&: sliceShape, R: llvm::map_range(C: indices, F: [&](int64_t idx) {
679 return sourceType.getDimSize(idx);
680 }));
681 }
682 auto sliceType =
683 RankedTensorType::get(shape: sliceShape, elementType: sourceType.getElementType());
684
685 // If the rank-reducing slice simplified every segment, then we are done.
686 if (sliceShape.size() == reassociationIndices.size())
687 return CollapseShapeRankReducingSliceSimplificationInfo{.sliceResultType: sliceType,
688 .newReassociationIndices: std::nullopt};
689
690 // Otherwise, we need to create a new collapse_shape op for the segments that
691 // weren't covered by the slice. By design, the new reassociation indices has
692 // the same number of groups as the old reassociation indices.
693 SmallVector<ReassociationIndices> newReassociationIndices;
694 SmallVector<int64_t, 2> reassociation;
695 int64_t groupIdx = 0;
696 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
697 reassociation.push_back(Elt: dimIdx);
698 if ((*trivialSegments)[groupIdx] ||
699 reassociation.size() == reassociationIndices[groupIdx].size()) {
700 newReassociationIndices.push_back(Elt: reassociation);
701 reassociation.clear();
702 groupIdx++;
703 }
704 }
705
706 return CollapseShapeRankReducingSliceSimplificationInfo{
707 .sliceResultType: sliceType, .newReassociationIndices: newReassociationIndices};
708}
709
710PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
711 ArrayRef<int64_t> innerDimPos) {
712 PackingMetadata res;
713 res.insertPositions.reserve(N: innerDimPos.size());
714 // The pack insert position is the position + the number of previously
715 // inserted positions + offset.
716 // The offset controls whether the packing dimension is the first or last.
717 //
718 // Example
719 // =======
720 // Consider packing from a hypothetical ABCD layout to ABCDba whose
721 // pack.inner_dims is [1, 0]. The first step consists in undoing the
722 // permutation and producing AaBbCD. This is achieved purely by computing the
723 // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
724 // possibility, is to produce insert positions [2, 0], this would result in an
725 // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
726 // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
727 // The latter is what we expect from packing.
728 int64_t offset = 1;
729 for (int64_t pos : innerDimPos) {
730 int64_t numInsertedBefore = llvm::count_if(
731 Range&: innerDimPos, P: [&pos](int64_t pos2) { return pos > pos2; });
732 res.insertPositions.push_back(Elt: pos + numInsertedBefore + offset);
733 }
734
735 DenseSet<int64_t> posSet(res.insertPositions.begin(),
736 res.insertPositions.end());
737 res.reassociations.reserve(N: packedRank);
738 for (int64_t i = 1; i <= packedRank; ++i) {
739 res.outerPositions.push_back(Elt: i - 1);
740 if (!posSet.contains(V: i)) {
741 res.reassociations.push_back(Elt: ReassociationIndices{i - 1});
742 continue;
743 }
744 res.reassociations.push_back(Elt: ReassociationIndices{i - 1, i});
745 ++i;
746 }
747 return res;
748}
749
750OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
751 TensorType result,
752 std::optional<Attribute> cst) {
753 if (source && source.isSplat() && result.hasStaticShape() &&
754 (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
755 return source.resizeSplat(newType: result);
756
757 return {};
758}
759

// source code of mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp