1//===- Utils.cpp - Utilities to support the Tensor dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Tensor dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Tensor/Utils/Utils.h"
14
15#include "mlir/Dialect/Affine/IR/AffineOps.h"
16#include "mlir/Dialect/Arith/Utils/Utils.h"
17#include "mlir/Dialect/Utils/IndexingUtils.h"
18#include "mlir/Interfaces/ValueBoundsOpInterface.h"
19
20using namespace mlir;
21using namespace mlir::tensor;
22
23PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
24 Value pad, bool nofold, Location loc,
25 OpBuilder &b, ValueRange dynOutDims) {
26
27 // This assumption simplifies the following logic without limiting what's
28 // required _today_. If needed, we can relax it in the future.
29 assert(((resType.getNumDynamicDims() == dynOutDims.size()) ||
30 dynOutDims.empty()) &&
31 "Either none or all output dynamic dims must be specified!");
32
33 // Init "low" and "high" padding values ("low" is kept as is, "high" is
34 // computed below).
35 SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(value: 0));
36 SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(value: 0));
37
38 size_t outDimIdx = 0;
39
40 for (const auto [idx, val] : enumerate(First: resType.getShape())) {
41 bool isDimDynamic = ShapedType::isDynamic(dValue: val);
42 bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
43
44 // Keep the default padding width (i.e. "0") when the output dim is dynamic
45 // and no actual output sizes have been provided.
46 if (!updatePadHigh)
47 continue;
48
49 // Compute the padding width: resDim - sourceDim.
50 AffineExpr d0, d1;
51 bindDims(ctx: b.getContext(), exprs&: d0, exprs&: d1);
52 OpFoldResult sourceDim = tensor::getMixedSize(builder&: b, loc, value: source, dim: idx);
53 OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
54 : OpFoldResult(b.getIndexAttr(value: val));
55
56 high[idx] = affine::makeComposedFoldedAffineApply(b, loc, expr: d0 - d1,
57 operands: {outDim, sourceDim});
58 }
59 return b.create<PadOp>(location: loc, args&: resType, args&: source, args&: low, args&: high, args&: pad, args&: nofold);
60}
61
62SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
63 Location loc,
64 Value rankedTensor) {
65 auto tensorTy = cast<RankedTensorType>(Val: rankedTensor.getType());
66 SmallVector<Value> dynamicDims;
67 for (const auto &en : llvm::enumerate(First: tensorTy.getShape())) {
68 if (en.value() == ShapedType::kDynamic)
69 dynamicDims.push_back(
70 Elt: b.create<tensor::DimOp>(location: loc, args&: rankedTensor, args: en.index()));
71 }
72 return dynamicDims;
73}
74
75FailureOr<RankedTensorType>
76mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
77 ArrayRef<int64_t> transposeVector) {
78 if (transposeVector.empty())
79 return rankedTensorType;
80
81 if (!isPermutationVector(interchange: transposeVector) ||
82 transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
83 return failure();
84
85 SmallVector<int64_t> transposedShape(rankedTensorType.getShape());
86 applyPermutationToVector(inVec&: transposedShape, permutation: transposeVector);
87
88 using RTTBuilder = RankedTensorType::Builder;
89 RankedTensorType transposedTensorType =
90 RTTBuilder(rankedTensorType).setShape(transposedShape);
91 return transposedTensorType;
92}
93
94CollapseShapeOp
95mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
96 const llvm::SmallBitVector &dropDims) {
97 auto srcType = cast<ShapedType>(Val: src.getType());
98 int64_t rank = srcType.getRank();
99 assert(rank == static_cast<int64_t>(dropDims.size()) &&
100 "dropDims dimension does not match src tensor rank");
101 assert(llvm::all_of(
102 dropDims.set_bits(),
103 [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
104 "Dropping non unit dimension");
105 // Computed reassociation map for the corresponding tensor.collapse_shape.
106 SmallVector<ReassociationIndices, 2> reassocMaps;
107 // Current reassociation group to add dropped dimension to.
108
109 int64_t nextDimToGroup = 0;
110 llvm::SmallBitVector keptDims(dropDims);
111 keptDims.flip();
112 int64_t lastSetBit = keptDims.find_last();
113 for (int64_t setBit : keptDims.set_bits()) {
114 // Group consecutive dropped dimension with the next non-dropped dimension.
115 // If this is the last set dimension, also group all subsequent dropped
116 // dimension, if any.
117 int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
118 auto seq = llvm::seq_inclusive(Begin: nextDimToGroup, End: upTo);
119 reassocMaps.emplace_back(Args: llvm::make_range(x: seq.begin(), y: seq.end()));
120 nextDimToGroup = setBit + 1;
121 }
122 return b.create<tensor::CollapseShapeOp>(location: loc, args&: src, args&: reassocMaps);
123}
124
125bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
126 llvm::SmallBitVector droppedDims = op.getDroppedDims();
127 int64_t srcDim = 0;
128 RankedTensorType resultType = op.getDestType();
129 // Source dims and destination dims (apart from dropped dims) must have the
130 // same size.
131 for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
132 if (droppedDims.test(Idx: resultDim)) {
133 // InsertSlice may expand unit dimensions that result from inserting a
134 // size-1 slice into a non-size-1 result dimension.
135 if (resultType.getDimSize(idx: resultDim) != 1)
136 return false;
137 continue;
138 }
139 FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
140 var1: {op.getSource(), srcDim}, var2: {op.getResult(), resultDim});
141 if (failed(Result: equalDimSize) || !*equalDimSize)
142 return false;
143 ++srcDim;
144 }
145
146 return true;
147}
148
149bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
150 llvm::SmallBitVector droppedDims = op.getDroppedDims();
151 int64_t resultDim = 0;
152 // Source dims and result dims (apart from dropped dims) must have the same
153 // size.
154 RankedTensorType sourceType = op.getSourceType();
155 for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
156 if (droppedDims.test(Idx: dim)) {
157 // ExtractSlice may drop unit dimensions that result from taking a size-1
158 // slice from a non-size-1 source dimension.
159 if (sourceType.getDimSize(idx: dim) != 1)
160 return false;
161 continue;
162 }
163 FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
164 var1: {op.getSource(), dim}, var2: {op.getResult(), resultDim});
165 if (failed(Result: equalDimSize) || !*equalDimSize)
166 return false;
167 ++resultDim;
168 }
169
170 return true;
171}
172

source code of mlir/lib/Dialect/Tensor/Utils/Utils.cpp