1//===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewrites that replace slices of reshape results with
10// aggregated slices of the reshape source.
11//
12//===----------------------------------------------------------------------===//
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Arith/Utils/Utils.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
18#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
19#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
20#include "mlir/Dialect/Utils/StaticValueUtils.h"
21#include "mlir/IR/BuiltinTypes.h"
22#include "mlir/IR/OpDefinition.h"
23#include "llvm/ADT/STLExtras.h"
24
25using namespace mlir;
26using namespace mlir::affine;
27using namespace mlir::tensor;
28
/// A (dimension number, index value) tuple: the `unsigned` selects a dimension
/// of an index space and the `Value` is an SSA value giving a position along
/// that dimension (not the extent of the dimension). The helpers below feed it
/// to affine.apply / affine.delinearize_index as an operand.
using DimAndIndex = std::tuple<unsigned, Value>;
31
32/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
33/// slice described by `sliceParams` into the input index space.
34static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
35 ArrayRef<Range> sliceParams,
36 const DimAndIndex &dimAndIndex) {
37 AffineExpr d0, s0, s1;
38 bindDims(ctx: b.getContext(), exprs&: d0);
39 bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1);
40 auto [dim, indexValue] = dimAndIndex;
41 assert(dim < sliceParams.size() && "slice should be non rank-reducing");
42 return std::make_pair(
43 dim, affine::makeComposedAffineApply(
44 b, loc, s0 + d0 * s1,
45 {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
46}
47
48/// Transform `dimAndIndex` from the result tensor index space of a
49/// CollapseShapeOp to the source tensor index space.
50static ValueRange invertCollapseShapeIndexing(
51 OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
52 ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
53 const auto &[dim, indexValue] = dimAndIndex;
54 SmallVector<OpFoldResult> basis;
55 for (int64_t i : reassociation[dim])
56 basis.push_back(Elt: reshapeSourceShape[i]);
57 auto delinearized =
58 b.create<AffineDelinearizeIndexOp>(loc, indexValue, basis);
59 return delinearized->getResults();
60}
61
62FailureOr<ExtractSliceFromCollapseHelper>
63tensor::ExtractSliceFromCollapseHelper::create(
64 OpBuilder &b, tensor::CollapseShapeOp collapseOp,
65 tensor::ExtractSliceOp extractOp) {
66 if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
67 collapseOp)
68 return failure();
69 SmallVector<Range> ranges;
70 ranges.reserve(N: extractOp.getSourceType().getRank());
71 for (const auto &[o, s, st] :
72 llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
73 extractOp.getMixedStrides())) {
74 ranges.push_back({o, s, st});
75 }
76 return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
77}
78
79FailureOr<ExtractSliceFromCollapseHelper>
80tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
81 tensor::CollapseShapeOp op,
82 ArrayRef<Range> sliceParams) {
83 // Don't perform this pattern if the collapse op can be simplified by
84 // a rank-reducing extract slice.
85 if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
86 sourceType: op.getSrcType(), reassociationIndices: op.getReassociationIndices())))
87 return failure();
88
89 // Materialize the output shape of the collapse_shape operation. This will
90 // create IR describing the output shape in terms of the input shape.
91 ReifiedRankedShapedTypeDims reifiedShapes;
92 if (failed(reifyResultShapes(b, op, reifiedShapes)))
93 return failure();
94 SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
95 SmallVector<ReassociationIndices> reassociationIndices =
96 op.getReassociationIndices();
97
98 // Determine which of the CollapseShapeOp's result dimensions are sliced
99 // and/or linearized.
100 llvm::SmallBitVector linearizedDimensions =
101 getLinearizedDimensions(reassociationIndices);
102 llvm::SmallBitVector slicedDimensions =
103 getSlicedDimensions(sliceInputShape: collapseShapeOutputShape, sliceParams);
104
105 auto collapseShapeInputShape =
106 tensor::getMixedSizes(builder&: b, loc: op.getLoc(), value: op.getSrc());
107
108 SmallVector<Value> tileSizes;
109 for (unsigned i = 0; i < sliceParams.size(); i++) {
110 if (slicedDimensions[i] && linearizedDimensions[i])
111 tileSizes.push_back(
112 Elt: getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
113 }
114
115 return ExtractSliceFromCollapseHelper(
116 op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
117 linearizedDimensions, slicedDimensions, tileSizes);
118}
119
120std::pair<Value, SmallVector<Range>>
121tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
122 OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
123 // Create the helper class for forming the slice parameters.
124 const SmallVector<ReassociationIndices> reassociationIndices =
125 collapseShapeOp.getReassociationIndices();
126 SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
127 collapseShapeOutputShape, sliceParams);
128
129 // Get the indices of the tiled dims (linearized by the collapse_shape
130 // and sliced by the extract_slice) invert the index spaces
131 // transformations.
132 SmallVector<ValueRange> multiIndices;
133 unsigned loopIdx = 0;
134 for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
135 if (linearizedDimensions[i] && slicedDimensions[i]) {
136 DimAndIndex tb =
137 invertSliceIndexing(b&: builder, loc, sliceParams,
138 dimAndIndex: std::make_tuple(args&: i, args: tileInductionVars[loopIdx++]));
139 multiIndices.push_back(Elt: invertCollapseShapeIndexing(
140 b&: builder, loc, reassociation: reassociationIndices, reshapeSourceShape: collapseShapeInputShape, dimAndIndex: tb));
141 }
142 }
143
144 SmallVector<Range> extractParams =
145 helper.getExtractSliceParams(ctx: builder.getContext(), multiIndices);
146
147 Value subTileResult = builder.create<tensor::ExtractSliceOp>(
148 loc, collapseShapeOp.getSrc(), extractParams);
149
150 SmallVector<Range> insertParams =
151 helper.getInsertSliceParams(ctx: builder.getContext(), tileIndices: tileInductionVars);
152
153 // Collapse the dimensions of the source slice back down.
154 Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
155 loc, subTileResult, reassociationIndices);
156 return std::make_pair(x&: collapsedResult, y&: insertParams);
157}
158
159FailureOr<Operation *>
160tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
161 tensor::CollapseShapeOp op, RewriterBase &rewriter) {
162 SmallVector<ReassociationIndices> reassociationIndices =
163 op.getReassociationIndices();
164 RankedTensorType sourceType = op.getSrcType();
165 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
166 getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
167 reassociationIndices);
168 if (failed(info))
169 return failure();
170
171 // Create the rank-reducing extract slice op.
172 auto zero = rewriter.getIndexAttr(0);
173 auto one = rewriter.getIndexAttr(1);
174 SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
175 SmallVector<OpFoldResult> sizes =
176 tensor::getMixedSizes(builder&: rewriter, loc: op.getLoc(), value: op.getSrc());
177 SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
178 auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
179 op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides);
180
181 if (!info->newReassociationIndices.has_value()) {
182 rewriter.replaceOp(op, sliceOp.getResult());
183 return sliceOp.getOperation();
184 }
185
186 return rewriter
187 .replaceOpWithNewOp<tensor::CollapseShapeOp>(
188 op, sliceOp.getResult(), *info->newReassociationIndices)
189 .getOperation();
190}
191

// Source: mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp