1//===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements rewrites that replace slices of reshape results with
10// aggregated slices of the reshape source.
11//
12//===----------------------------------------------------------------------===//
13#include "mlir/Dialect/Affine/IR/AffineOps.h"
14#include "mlir/Dialect/Arith/Utils/Utils.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
17#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
18#include "mlir/Dialect/Utils/StaticValueUtils.h"
19#include "mlir/IR/BuiltinTypes.h"
20#include "mlir/IR/OpDefinition.h"
21#include "llvm/ADT/STLExtras.h"
22
23using namespace mlir;
24using namespace mlir::affine;
25using namespace mlir::tensor;
26
27/// A tuple that represents (dimension number, dimension value).
28using DimAndIndex = std::tuple<unsigned, Value>;
29
30/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing)
31/// slice described by `sliceParams` into the input index space.
32static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
33 ArrayRef<Range> sliceParams,
34 const DimAndIndex &dimAndIndex) {
35 AffineExpr d0, s0, s1;
36 bindDims(ctx: b.getContext(), exprs&: d0);
37 bindSymbols(ctx: b.getContext(), exprs&: s0, exprs&: s1);
38 auto [dim, indexValue] = dimAndIndex;
39 assert(dim < sliceParams.size() && "slice should be non rank-reducing");
40 return std::make_pair(
41 x&: dim, y: affine::makeComposedAffineApply(
42 b, loc, e: s0 + d0 * s1,
43 operands: {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
44}
45
46/// Transform `dimAndIndex` from the result tensor index space of a
47/// CollapseShapeOp to the source tensor index space.
48static ValueRange invertCollapseShapeIndexing(
49 OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
50 ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
51 const auto &[dim, indexValue] = dimAndIndex;
52 SmallVector<OpFoldResult> basis;
53 for (int64_t i : reassociation[dim])
54 basis.push_back(Elt: reshapeSourceShape[i]);
55 auto delinearized =
56 b.create<AffineDelinearizeIndexOp>(location: loc, args: indexValue, args&: basis);
57 return delinearized->getResults();
58}
59
60FailureOr<ExtractSliceFromCollapseHelper>
61tensor::ExtractSliceFromCollapseHelper::create(
62 OpBuilder &b, tensor::CollapseShapeOp collapseOp,
63 tensor::ExtractSliceOp extractOp) {
64 if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
65 collapseOp)
66 return failure();
67 SmallVector<Range> ranges;
68 ranges.reserve(N: extractOp.getSourceType().getRank());
69 for (const auto &[o, s, st] :
70 llvm::zip(t: extractOp.getMixedOffsets(), u: extractOp.getMixedSizes(),
71 args: extractOp.getMixedStrides())) {
72 ranges.push_back(Elt: {.offset: o, .size: s, .stride: st});
73 }
74 return ExtractSliceFromCollapseHelper::create(b, op: collapseOp, sliceParams: ranges);
75}
76
77FailureOr<ExtractSliceFromCollapseHelper>
78tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
79 tensor::CollapseShapeOp op,
80 ArrayRef<Range> sliceParams) {
81 // Don't perform this pattern if the collapse op can be simplified by
82 // a rank-reducing extract slice.
83 if (succeeded(Result: mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
84 sourceType: op.getSrcType(), reassociationIndices: op.getReassociationIndices())))
85 return failure();
86
87 // Materialize the output shape of the collapse_shape operation. This will
88 // create IR describing the output shape in terms of the input shape.
89 ReifiedRankedShapedTypeDims reifiedShapes;
90 if (failed(Result: reifyResultShapes(b, op, reifiedReturnShapes&: reifiedShapes)))
91 return failure();
92 SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
93 SmallVector<ReassociationIndices> reassociationIndices =
94 op.getReassociationIndices();
95
96 // Determine which of the CollapseShapeOp's result dimensions are sliced
97 // and/or linearized.
98 llvm::SmallBitVector linearizedDimensions =
99 getLinearizedDimensions(reassociationIndices);
100 llvm::SmallBitVector slicedDimensions =
101 getSlicedDimensions(sliceInputShape: collapseShapeOutputShape, sliceParams);
102
103 auto collapseShapeInputShape =
104 tensor::getMixedSizes(builder&: b, loc: op.getLoc(), value: op.getSrc());
105
106 SmallVector<Value> tileSizes;
107 for (unsigned i = 0; i < sliceParams.size(); i++) {
108 if (slicedDimensions[i] && linearizedDimensions[i])
109 tileSizes.push_back(
110 Elt: getValueOrCreateConstantIndexOp(b, loc: op.getLoc(), ofr: sliceParams[i].size));
111 }
112
113 return ExtractSliceFromCollapseHelper(
114 op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
115 linearizedDimensions, slicedDimensions, tileSizes);
116}
117
118std::pair<Value, SmallVector<Range>>
119tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
120 OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
121 // Create the helper class for forming the slice parameters.
122 const SmallVector<ReassociationIndices> reassociationIndices =
123 collapseShapeOp.getReassociationIndices();
124 SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
125 collapseShapeOutputShape, sliceParams);
126
127 // Get the indices of the tiled dims (linearized by the collapse_shape
128 // and sliced by the extract_slice) invert the index spaces
129 // transformations.
130 SmallVector<ValueRange> multiIndices;
131 unsigned loopIdx = 0;
132 for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
133 if (linearizedDimensions[i] && slicedDimensions[i]) {
134 DimAndIndex tb =
135 invertSliceIndexing(b&: builder, loc, sliceParams,
136 dimAndIndex: std::make_tuple(args&: i, args: tileInductionVars[loopIdx++]));
137 multiIndices.push_back(Elt: invertCollapseShapeIndexing(
138 b&: builder, loc, reassociation: reassociationIndices, reshapeSourceShape: collapseShapeInputShape, dimAndIndex: tb));
139 }
140 }
141
142 SmallVector<Range> extractParams =
143 helper.getExtractSliceParams(ctx: builder.getContext(), multiIndices);
144
145 Value subTileResult = builder.create<tensor::ExtractSliceOp>(
146 location: loc, args: collapseShapeOp.getSrc(), args&: extractParams);
147
148 SmallVector<Range> insertParams =
149 helper.getInsertSliceParams(ctx: builder.getContext(), tileIndices: tileInductionVars);
150
151 // Collapse the dimensions of the source slice back down.
152 Value collapsedResult = builder.create<tensor::CollapseShapeOp>(
153 location: loc, args&: subTileResult, args: reassociationIndices);
154 return std::make_pair(x&: collapsedResult, y&: insertParams);
155}
156
157FailureOr<Operation *>
158tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
159 tensor::CollapseShapeOp op, RewriterBase &rewriter) {
160 SmallVector<ReassociationIndices> reassociationIndices =
161 op.getReassociationIndices();
162 RankedTensorType sourceType = op.getSrcType();
163 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
164 getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
165 reassociationIndices);
166 if (failed(Result: info))
167 return failure();
168
169 // Create the rank-reducing extract slice op.
170 auto zero = rewriter.getIndexAttr(value: 0);
171 auto one = rewriter.getIndexAttr(value: 1);
172 SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
173 SmallVector<OpFoldResult> sizes =
174 tensor::getMixedSizes(builder&: rewriter, loc: op.getLoc(), value: op.getSrc());
175 SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
176 auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(
177 location: op.getLoc(), args&: info->sliceResultType, args: op.getSrc(), args&: offsets, args&: sizes, args&: strides);
178
179 if (!info->newReassociationIndices.has_value()) {
180 rewriter.replaceOp(op, newValues: sliceOp.getResult());
181 return sliceOp.getOperation();
182 }
183
184 return rewriter
185 .replaceOpWithNewOp<tensor::CollapseShapeOp>(
186 op, args: sliceOp.getResult(), args&: *info->newReassociationIndices)
187 .getOperation();
188}
189

source code of mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp