//===- Utils.cpp - Utilities to support the Tensor dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Tensor dialect.
//
//===----------------------------------------------------------------------===//
12
#include "mlir/Dialect/Tensor/Utils/Utils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;
using namespace mlir::tensor;
23
24PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
25 Value pad, bool nofold, Location loc,
26 OpBuilder &b) {
27 SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
28 SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
29 for (const auto &en : enumerate(type.getShape())) {
30 // Pad only the static dimensions of the result tensor type.
31 if (ShapedType::isDynamic(en.value()))
32 continue;
33 // Compute the padding width.
34 AffineExpr d0;
35 bindDims(b.getContext(), d0);
36 OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
37 high[en.index()] =
38 affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
39 }
40 return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
41}
42
43SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
44 Location loc,
45 Value rankedTensor) {
46 auto tensorTy = cast<RankedTensorType>(rankedTensor.getType());
47 SmallVector<Value> dynamicDims;
48 for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
49 if (en.value() == ShapedType::kDynamic)
50 dynamicDims.push_back(
51 b.create<tensor::DimOp>(loc, rankedTensor, en.index()));
52 }
53 return dynamicDims;
54}
55
56FailureOr<RankedTensorType>
57mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
58 ArrayRef<int64_t> transposeVector) {
59 if (transposeVector.empty())
60 return rankedTensorType;
61
62 if (!isPermutationVector(interchange: transposeVector) ||
63 transposeVector.size() != static_cast<size_t>(rankedTensorType.getRank()))
64 return failure();
65
66 SmallVector<int64_t> transposedShape(rankedTensorType.getShape().begin(),
67 rankedTensorType.getShape().end());
68 applyPermutationToVector(inVec&: transposedShape, permutation: transposeVector);
69
70 using RTTBuilder = RankedTensorType::Builder;
71 RankedTensorType transposedTensorType =
72 RTTBuilder(rankedTensorType).setShape(transposedShape);
73 return transposedTensorType;
74}
75/// The permutation can be obtained from two permutations:
76/// a) Compute the permutation vector to move the last `numPackedDims` into
77/// the `innerPosDims` of a shape of rank `rank`.
78/// b) Compute the permutation vector to move outer dims if the
79/// `outerPerm` parameter is not empty.
80/// Apply (b) permutation on (a) permutation to get the final permutation.
81static SmallVector<int64_t>
82computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
83 ArrayRef<int64_t> &outerPerm,
84 PackingMetadata &packingMetadata) {
85 int64_t numPackedDims = innerDimsPos.size();
86 auto lastDims =
87 llvm::to_vector(Range: llvm::seq<int64_t>(Begin: rank - numPackedDims, End: rank));
88 packingMetadata = computePackingMetadata(packedRank: rank, innerDimPos: innerDimsPos);
89 SmallVector<int64_t> innerPositionsPerm =
90 computePermutationVector(permSize: rank, positions: lastDims, desiredPositions: packingMetadata.insertPositions);
91
92 SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
93 if (!outerPerm.empty())
94 applyPermutationToVector(inVec&: outerPos, permutation: outerPerm);
95 SmallVector<int64_t> outerPositionPerm =
96 computePermutationVector(permSize: rank, positions: packingMetadata.outerPositions, desiredPositions: outerPos);
97
98 SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
99 applyPermutationToVector(inVec&: packInverseDestPermutation, permutation: outerPositionPerm);
100 return packInverseDestPermutation;
101}
102
103/// Shell function to compute the Destination Permutation of PackOp
104/// This function uses the helper function `computePackUnPackPerm` to get
105/// the permutation vector. Only major difference between UnPack and Pack is
106/// that packOp uses destination rank whereas unpack Uses source rank.
107SmallVector<int64_t> mlir::tensor::getPackInverseDestPerm(PackOp packOp) {
108
109 PackingMetadata pMetadata;
110 int64_t packedRank = packOp.getDestType().getRank();
111 ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
112 ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
113 SmallVector<int64_t> packInvDestPerm =
114 computePackUnPackPerm(rank: packedRank, innerDimsPos&: innerDimPos, outerPerm, packingMetadata&: pMetadata);
115 return packInvDestPerm;
116}
117
118/// Shell function to compute the Source Permutation of unPackOp.
119/// This function, like the getPackInverseDestPerm uses the helper function
120/// computePackUnPackPerm` to get the permutation vector.
121/// Only major difference between UnPack and Pack is that packOp uses
122/// destination rank whereas unpack Uses source rank.
123SmallVector<int64_t> mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp) {
124 PackingMetadata metadata;
125 return mlir::tensor::getUnPackInverseSrcPerm(unpackOp, metadata);
126}
127
128/// Shell function to compute the Source rank permutation for unpackOp
129/// Unpack requires some packing metadata data information, so created
130/// another function where this value is passed by reference.
131SmallVector<int64_t>
132mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
133 PackingMetadata &metadata) {
134 int64_t unpackRank = unpackOp.getSourceType().getRank();
135 ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
136 ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
137 SmallVector<int64_t> unpackInvSrcPerm =
138 computePackUnPackPerm(rank: unpackRank, innerDimsPos&: innerDimPos, outerPerm, packingMetadata&: metadata);
139 return unpackInvSrcPerm;
140}
141
142bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
143 llvm::SmallBitVector droppedDims = op.getDroppedDims();
144 int64_t srcDim = 0;
145 RankedTensorType resultType = op.getDestType();
146 // Source dims and destination dims (apart from dropped dims) must have the
147 // same size.
148 for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
149 if (droppedDims.test(Idx: resultDim)) {
150 // InsertSlice may expand unit dimensions that result from inserting a
151 // size-1 slice into a non-size-1 result dimension.
152 if (resultType.getDimSize(resultDim) != 1)
153 return false;
154 continue;
155 }
156 FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
157 var1: {op.getSource(), srcDim}, var2: {op.getResult(), resultDim});
158 if (failed(result: equalDimSize) || !*equalDimSize)
159 return false;
160 ++srcDim;
161 }
162
163 return true;
164}
165
166bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
167 llvm::SmallBitVector droppedDims = op.getDroppedDims();
168 int64_t resultDim = 0;
169 // Source dims and result dims (apart from dropped dims) must have the same
170 // size.
171 RankedTensorType sourceType = op.getSourceType();
172 for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
173 if (droppedDims.test(Idx: dim)) {
174 // ExtractSlice may drop unit dimensions that result from taking a size-1
175 // slice from a non-size-1 source dimension.
176 if (sourceType.getDimSize(dim) != 1)
177 return false;
178 continue;
179 }
180 FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
181 var1: {op.getSource(), dim}, var2: {op.getResult(), resultDim});
182 if (failed(result: equalDimSize) || !*equalDimSize)
183 return false;
184 ++resultDim;
185 }
186
187 return true;
188}
189

source code of mlir/lib/Dialect/Tensor/Utils/Utils.cpp