//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include <numeric>
#include <optional>

using namespace mlir;

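// Note (illustrative): the reassociation indices computed here always group
// dimensions of the higher-ranked shape, regardless of which operand is the
// source. E.g. for a reshape between tensor<2x3x4xf32> and tensor<6x4xf32> in
// either direction, the result is [[0, 1], [2]]: dims 0 and 1 of the 3-D type
// collapse into dim 0 of the 2-D type, and dim 2 maps to dim 1.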
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  if (sourceType.getRank() > targetType.getRank())
    return getReassociationIndicesForCollapse(sourceType.getShape(),
                                              targetType.getShape());
  if (sourceType.getRank() < targetType.getRank())
    return getReassociationIndicesForCollapse(targetType.getShape(),
                                              sourceType.getShape());
  return std::nullopt;
}

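// For illustration: collapsing sourceShape [2, 3, 4] into targetShape [6, 4]
// yields [[0, 1], [2]]; trailing unit dimensions are folded into the last
// group, so [2, 3, 4, 1, 1] -> [6, 4] yields [[0, 1], [2, 3, 4]]. If no
// grouping reproduces the target shape, std::nullopt is returned.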
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                         ArrayRef<int64_t> targetShape) {
  if (sourceShape.size() <= targetShape.size())
    return std::nullopt;
  unsigned sourceDim = 0;
  SmallVector<ReassociationIndices> reassociationMap;
  reassociationMap.reserve(targetShape.size());

  ReassociationIndices currIndices;
  int64_t prodOfCollapsedDims = 1;
  while (sourceDim < sourceShape.size()) {
    unsigned targetDim = reassociationMap.size();
    // If we have mapped all the target dimensions, stop and handle the
    // remaining tail of size-1 dimensions explicitly.
    if (targetDim == targetShape.size())
      break;

    int64_t currTargetShape = targetShape[targetDim];
    while (sourceDim < (sourceShape.size() - 1) &&
           sourceShape[sourceDim] != ShapedType::kDynamic &&
           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }

    // If the current expanded dimension is dynamic, then the collapsed
    // dimension should also be dynamic and the product of the expanded
    // dimensions grouped so far should be 1.
    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
      return std::nullopt;

    // If the collapsed dim is dynamic, the current expanded dim should also
    // be dynamic.
    if (currTargetShape == ShapedType::kDynamic &&
        sourceShape[sourceDim] != ShapedType::kDynamic)
      return std::nullopt;

    // For static shapes, the product of the grouped dimensions of the expanded
    // shape should match the collapsed dimension size.
    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return std::nullopt;

    currIndices.push_back(sourceDim++);
    reassociationMap.emplace_back(ReassociationIndices{});
    std::swap(reassociationMap.back(), currIndices);
    prodOfCollapsedDims = 1;
  }
  // All the dimensions in the target must have been processed.
  if (reassociationMap.size() != targetShape.size())
    return std::nullopt;
  // Process any remaining entries in the source shape. They all need to be
  // 1 or dynamic.
  for (; sourceDim < sourceShape.size(); sourceDim++) {
    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
        sourceShape[sourceDim] != 1)
      return std::nullopt;
    // The map is empty when the target type is a scalar.
    if (!reassociationMap.empty())
      reassociationMap.back().push_back(sourceDim);
  }
  return reassociationMap;
}

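// For illustration: composing a producer reassociation [[0, 1], [2], [3, 4]]
// with a consumer reassociation [[0, 1], [2]] yields [[0, 1, 2], [3, 4]],
// e.g. two collapses, rank 5 -> 3 -> 2, compose into a single rank 5 -> 2
// collapse. If the two reshapes do not compose into a single collapse/expand,
// std::nullopt is returned.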
std::optional<SmallVector<ReassociationIndices>>
mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of the same size,
  // the resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return std::nullopt;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  size_t consumerDims = std::accumulate(
      consumerReassociations.begin(), consumerReassociations.end(), 0,
      [](size_t all, ReassociationIndicesRef indices) {
        return all + indices.size();
      });
  if (producerReassociations.size() != consumerDims)
    return std::nullopt;

  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      llvm::append_range(reassociations, producerReassociations[consumerIndex]);
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}

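// For illustration: the reassociation indices [[0, 1], [2]] convert to the
// affine expression groups [[d0, d1], [d2]] in the given context.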
SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
  for (const auto &indices : reassociationIndices) {
    SmallVector<AffineExpr, 2> reassociationMap;
    reassociationMap.reserve(indices.size());
    for (int64_t index : indices)
      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
    reassociationMaps.push_back(std::move(reassociationMap));
  }
  return reassociationMaps;
}

template <typename AffineExprTy>
unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned pos = 0;
  for (const auto &exprs : exprArrays) {
    for (auto expr : exprs) {
      expr.walk([&pos](AffineExpr e) {
        if (auto d = dyn_cast<AffineExprTy>(e))
          pos = std::max(pos, d.getPosition());
      });
    }
  }
  return pos;
}

ArrayAttr mlir::getReassociationIndicesAttribute(
    Builder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr =
      llvm::to_vector<4>(llvm::map_range(
          reassociation, [&](const ReassociationIndices &indices) -> Attribute {
            return cast<Attribute>(b.getI64ArrayAttr(indices));
          }));
  return b.getArrayAttr(reassociationAttr);
}

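// For illustration: the inverse of convertReassociationIndicesToExprs, e.g.
// the expression groups [[d0, d1], [d2]] convert back to the indices
// [[0, 1], [2]]. Every expression is expected to be an AffineDimExpr.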
SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> reassociationIndices;
  for (const auto &exprs : reassociationExprs) {
    ReassociationIndices indices;
    indices.reserve(exprs.size());
    for (const auto &expr : exprs)
      indices.push_back(cast<AffineDimExpr>(expr).getPosition());
    reassociationIndices.push_back(indices);
  }
  return reassociationIndices;
}

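// For illustration: for the reassociation [[d0, d1], [d2]] the total number of
// dimensions is 3, so the resulting maps are
//   (d0, d1, d2) -> (d0, d1) and (d0, d1, d2) -> (d2).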
SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");
  SmallVector<AffineMap, 4> maps;
  maps.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    assert(!exprs.empty());
    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
  }
  return maps;
}

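// For illustration: the reassociation
//   [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)]
// is valid, while
//   [(d0, d1, d2) -> (d1), (d0, d1, d2) -> (d0, d2)]
// is not, because the dimensions must appear exactly once, in order and
// contiguously across the maps.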
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;
  unsigned nDims = reassociation[0].getNumDims();
  unsigned nextExpectedDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    auto m = it.value();
    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
      if (invalidIndex)
        *invalidIndex = it.index();
      return false;
    }
    for (auto e : m.getResults()) {
      auto d = dyn_cast<AffineDimExpr>(e);
      if (!d || d.getPosition() != nextExpectedDim++) {
        if (invalidIndex)
          *invalidIndex = it.index();
        return false;
      }
    }
  }
  if (nextExpectedDim != nDims) {
    if (invalidIndex)
      *invalidIndex = reassociation.size() - 1;
    return false;
  }
  return true;
}

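// For illustration: collapsedShape [6, ?] and expandedShape [2, 3, ?, 4] are
// compatible with the reassociation [[0, 1], [2, 3]]: the static group {2, 3}
// multiplies to 6 and the group containing the dynamic dimension maps to a
// dynamic collapsed dimension. collapsedShape [5, ?] would be rejected.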
LogicalResult mlir::reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
  unsigned expandedDimStart = 0;
  for (const auto &map : llvm::enumerate(reassociationMaps)) {
    bool foundDynamicShape = false;
    int64_t linearizedStaticShape = 1;

    for (const auto &dim : llvm::enumerate(
             expandedShape.slice(expandedDimStart, map.value().size()))) {
      if (ShapedType::isDynamic(dim.value()))
        foundDynamicShape = true;
      else
        linearizedStaticShape *= dim.value();
    }
    if (foundDynamicShape) {
      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
        return emitError(
            "expected dimension " + Twine(map.index()) +
            " of collapsed type to be dynamic since one or more of the "
            "corresponding dimensions in the expanded type is dynamic");
      }
    } else {
      if (collapsedShape[map.index()] != linearizedStaticShape) {
        return emitError("expected dimension " + Twine(map.index()) +
                         " of collapsed type to be static value of " +
                         Twine(linearizedStaticShape));
      }
    }
    expandedDimStart += map.value().size();
  }
  return success();
}

bool mlir::hasNonIdentityLayout(Type type) {
  if (auto memrefType = dyn_cast<MemRefType>(type))
    return !memrefType.getLayout().isIdentity();
  return false;
}

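// For illustration: for sliceInputShape [10, 20] and sliceParams
// [{offset = 0, size = 10, stride = 1}, {offset = 5, size = 5, stride = 1}],
// only the second dimension is actually sliced, so the mask is [0, 1].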
llvm::SmallBitVector
mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                          ArrayRef<Range> sliceParams) {
  assert(sliceParams.size() == sliceInputShape.size() &&
         "only supports non rank-reducing case");
  llvm::SmallBitVector mask(sliceInputShape.size());
  unsigned idx = 0;
  for (const auto &[offset, size, stride] : sliceParams) {
    std::optional<int64_t> offsetConst = getConstantIntValue(offset);
    std::optional<int64_t> strideConst = getConstantIntValue(stride);
    mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
                (!strideConst || *strideConst != 1) ||
                (!offsetConst || *offsetConst != 0);
    idx++;
  }
  return mask;
}

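// For illustration: the reassociation [[0], [1, 2], [3]] produces the mask
// [0, 1, 0]: only the second result dimension is formed by collapsing more
// than one source dimension.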
llvm::SmallBitVector mlir::getLinearizedDimensions(
    ArrayRef<ReassociationIndices> reassociationIndices) {
  llvm::SmallBitVector result(reassociationIndices.size());
  for (const auto &it : llvm::enumerate(reassociationIndices))
    result[it.index()] = it.value().size() > 1;
  return result;
}

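// Illustrative sketch of the three cases below: for a collapse of
// tensor<?x10x4xf32> with reassociation [[0, 1], [2]] whose collapsed dim 0 is
// both linearized and sliced, the group {0, 1} contributes one
// (multi-index, 1, 1) triple per source dimension (case 1); a linearized but
// unsliced group would contribute the full extent of each of its source
// dimensions (case 2); and the remaining group {2} simply reuses its original
// slice parameters (case 3).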
SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
    MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
  unsigned loopIdx = 0;
  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> offsetsSizesAndStrides;
  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
  for (const auto &it : llvm::enumerate(reassociationIndices)) {
    // Case 1: Linearized dimensions that have also been sliced. These
    // have size 1 because we are iterating over these dimensions. The
    // offsets are exactly the de-linearized multi-indices.
    if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
      llvm::append_range(
          offsetsSizesAndStrides,
          llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
            return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
          }));
      continue;
    }

    // Case 2: One or possibly multiple combined input dimensions, but we
    // have proven that these are not sliced. In this case we just take
    // the full extent of each dimension in the reassociation list.
    if (linearizedDimensions[it.index()]) {
      llvm::append_range(offsetsSizesAndStrides,
                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
                           return {zeroAttr, collapseShapeInputShape[idx],
                                   oneAttr};
                         }));
      continue;
    }

    // Case 3: A single index, but it may be sliced.
    offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
  }
  return offsetsSizesAndStrides;
}

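// Illustrative note: the parameters returned below describe where each tile
// produced by the loop nest is written back into the destination. Dimensions
// that are both linearized and sliced use the current tile index as offset
// with size 1; all other dimensions are written starting at offset 0 with the
// original slice size.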
SmallVector<Range>
SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
                                              ValueRange tileIndices) {
  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> insertParams;
  insertParams.reserve(linearizedDimensions.size());
  unsigned loopIdx = 0;
  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
    if (linearizedDimensions[i] && slicedDimensions[i]) {
      insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
      continue;
    }
    insertParams.push_back(Range{zero, sliceParams[i].size, one});
  }
  return insertParams;
}

/// Returns the index of the only non-unit dimension among `indices` of
/// `shape`, if such a dimension exists and `indices` has more than one
/// element. Otherwise, returns std::nullopt.
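/// For illustration: for shape [1, 5, 1, 7] and indices [0, 1, 2] the result
/// is 1; if more than one of the indexed dimensions is non-unit, or `indices`
/// has fewer than two elements, the result is std::nullopt.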
static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
                                                  ArrayRef<int64_t> shape) {
  // Return std::nullopt if more than one of the dimensions in this group are
  // not 1.
  std::optional<int64_t> dimIndex;
  if (indices.size() < 2)
    return std::nullopt;
  for (int64_t idx : indices) {
    if (shape[idx] != 1) {
      if (dimIndex != std::nullopt)
        return std::nullopt;
      dimIndex = idx;
    }
  }
  return dimIndex;
}

// For each segment in the reassociation indices, check whether we can
// simplify that segment with a rank-reducing extract slice. We can do this if
// all but (exactly) one of the corresponding source dims is 1.
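// For illustration: with source shape [1, 5, 1, 7, 3] and reassociation
// [[0, 1, 2], [3], [4]], the result is [1, std::nullopt, std::nullopt]: the
// first segment has dim 1 as its unique non-unit dim, while single-index
// segments are never considered trivial.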
static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments;
  for (const auto &indices : reassociationIndices)
    trivialSegments.push_back(
        getUniqueNonUnitDim(indices, sourceType.getShape()));
  return trivialSegments;
}

/// Returns the trivial segments of the reassociation indices if at least one
/// segment of a collapsing reshape can be simplified using a rank-reducing
/// slice; otherwise returns failure.
static FailureOr<SmallVector<std::optional<int64_t>>>
canCollapseShapeBeSimplifiedByRankReducingSlice(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments =
      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
  if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
        return idx.has_value();
      }))
    return failure();
  return trivialSegments;
}

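// For illustration: collapsing tensor<1x5x1x7x3xf32> with reassociation
// [[0, 1, 2], [3, 4]] into tensor<5x21xf32> can be simplified by first
// rank-reducing to the slice type tensor<5x7x3xf32>; the second segment still
// needs a collapse, so the returned new reassociation is [[0], [1, 2]]. When
// every segment is simplified by the slice, no new reassociation is returned.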
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
                                                      reassociationIndices);
  if (failed(trivialSegments))
    return failure();

  // Create the expected result shape of the rank-reducing slice.
  SmallVector<int64_t> sliceShape;
  for (const auto &[nonUnitDim, indices] :
       llvm::zip(*trivialSegments, reassociationIndices)) {
    if (nonUnitDim) {
      sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
      continue;
    }
    llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
                         return sourceType.getDimSize(idx);
                       }));
  }
  auto sliceType =
      RankedTensorType::get(sliceShape, sourceType.getElementType());

  // If the rank-reducing slice simplified every segment, then we are done.
  if (sliceShape.size() == reassociationIndices.size())
    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
                                                            std::nullopt};

  // Otherwise, we need to create a new collapse_shape op for the segments that
  // weren't covered by the slice. By design, the new reassociation indices have
  // the same number of groups as the old reassociation indices.
  SmallVector<ReassociationIndices> newReassociationIndices;
  SmallVector<int64_t, 2> reassociation;
  int64_t groupIdx = 0;
  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
    reassociation.push_back(dimIdx);
    if ((*trivialSegments)[groupIdx] ||
        reassociation.size() == reassociationIndices[groupIdx].size()) {
      newReassociationIndices.push_back(reassociation);
      reassociation.clear();
      groupIdx++;
    }
  }

  return CollapseShapeRankReducingSliceSimplificationInfo{
      sliceType, newReassociationIndices};
}

PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
                                             ArrayRef<int64_t> innerDimPos) {
  PackingMetadata res;
  res.insertPositions.reserve(innerDimPos.size());
  // The pack insert position is the position + the number of previously
  // inserted positions + offset.
  // The offset controls whether the packing dimension is the first or last.
  //
  // Example
  // =======
  // Consider packing from a hypothetical ABCD layout to ABCDba whose
  // pack.inner_dims is [1, 0]. The first step consists in undoing the
  // permutation and producing AaBbCD. This is achieved purely by computing the
  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
  // possibility is to produce insert positions [2, 0], which would result in
  // an aAbBCD layout (i.e. offset 0). The other possibility is to produce
  // insert positions [3, 1], which would result in an AaBbCD layout (i.e.
  // offset 1). The latter is what we expect from packing.
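  //
  // Concretely (illustrative): for packedRank = 6 and innerDimPos = [1, 0],
  // the code below produces insertPositions = [3, 1], outerPositions =
  // [0, 2, 4, 5] and reassociations = [[0, 1], [2, 3], [4], [5]], i.e. the
  // groups {A, a}, {B, b}, {C} and {D} of the AaBbCD layout.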
  int64_t offset = 1;
  for (int64_t pos : innerDimPos) {
    int64_t numInsertedBefore = llvm::count_if(
        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
    res.insertPositions.push_back(pos + numInsertedBefore + offset);
  }

  DenseSet<int64_t> posSet(res.insertPositions.begin(),
                           res.insertPositions.end());
  res.reassociations.reserve(packedRank);
  for (int64_t i = 1; i <= packedRank; ++i) {
    res.outerPositions.push_back(i - 1);
    if (!posSet.contains(i)) {
      res.reassociations.push_back(ReassociationIndices{i - 1});
      continue;
    }
    res.reassociations.push_back(ReassociationIndices{i - 1, i});
    ++i;
  }
  return res;
}

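// For illustration: a splat constant such as dense<1.0> : tensor<2x4xf32>
// reshaped to the static type tensor<8xf32> folds to
// dense<1.0> : tensor<8xf32>; non-splat sources, dynamic result shapes, or a
// mismatch with the optional expected splat value `cst` do not fold.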
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
                                         TensorType result,
                                         std::optional<Attribute> cst) {
  if (source && source.isSplat() && result.hasStaticShape() &&
      (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
    return source.resizeSplat(result);

  return {};
}