//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, it rewrites each such operation into an explicit sequence of
/// operations that models its effect on the different metadata. The pass uses
/// affine constructs to materialize these effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}.
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
#ifndef NDEBUG
  auto [resultStrides, resultOffset] = getStridesAndOffset(subview.getType());
#endif // NDEBUG

  // Compute the new strides and offset from the base strides and offset:
  //   newStride#i = baseStride#i * subStride#i
  //   offset = baseOffset + sum(subOffset#i * baseStride#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
#ifndef NDEBUG
  // Assert that the computed offset matches the offset of the result type of
  // the subview op (if both are static).
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
#endif // NDEBUG

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // all of them.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, subviews can drop some dimensions, so some strides and sizes may
  // not end up in the final <base, offset, sizes, strides> value that we are
  // replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

#ifndef NDEBUG
  // Iteration variable for result dimensions of the subview op.
  int64_t j = 0;
#endif // NDEBUG
  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
#ifndef NDEBUG
    // Assert that the computed stride matches the stride of the result type
    // of the subview op (if both are static).
    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
    ++j;
#endif // NDEBUG
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and canonicalize
/// its effects on the offset, the sizes, and the strides using affine.apply.
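///
/// For illustration, a hypothetical before/after pair (types, values, and the
/// exact folded affine maps are made up; the pass may fold them differently):
///
/// \verbatim
/// %dst = memref.subview %src[%off, 0] [4, 4] [1, 1]
///     : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
/// \endverbatim
///
/// becomes roughly:
///
/// \verbatim
/// %base, %offset, %sizes:2, %strides:2 =
///     memref.extract_strided_metadata %src
/// %newOff = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%off]
/// %dst = memref.reinterpret_cast %base to
///     offset: [%newOff], sizes: [4, 4], strides: [8, 1]
///     : memref<f32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
/// \endverbatim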
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);

    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                 for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
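///
/// For example (a hypothetical case, with made-up shapes): expanding
/// memref<?xf32> into memref<?x4xf32> with reassociation [[0, 1]] gives,
/// for group 0:
///   sizes#0 = baseSizes#0 floordiv 4   (materialized as an affine.apply)
///   sizes#1 = 4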
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  //   expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

307
308/// Compute the expanded strides of the given \p expandShape for the
309/// \p groupId-th reassociation group.
310/// \p origStrides and \p origSizes hold respectively the strides and sizes
311/// of the source shape as values.
312/// This is used to compute the strides in cases of dynamic shapes and/or
313/// dynamic stride for this reassociation group.
314///
315/// strides#i =
316/// origStrides#reassDim * product(expandShapeSizes#j, for j in
317/// reassIdx#i+1..reassIdx#i+group.size-1)
318///
319/// Where reassIdx#i is the reassociation index for at index i in \p groupId
320/// and expandShapeSizes#j is either:
321/// - The constant size at dimension j, derived directly from the result type of
322/// the expand_shape op, or
323/// - An affine expression: baseSizes#reassDim / product of all constant sizes
324/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325/// element.)
326///
327/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
328///
329/// TODO: Move this utility function directly within ExpandShapeOp. For now,
330/// this is not possible because this function uses the Affine dialect and the
331/// MemRef dialect cannot depend on the Affine dialect.
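///
/// For example (a hypothetical case, with made-up shapes): expanding
/// memref<?xf32> with stride s into memref<4x?xf32> with reassociation
/// [[0, 1]] gives, for group 0:
///   strides#1 = s
///   strides#0 = (baseSizes#0 floordiv 4) * s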
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = getStridesAndOffset(sourceType);

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>())
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic evaluates to
/// false on it, values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
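///
/// For example (hypothetical inputs): with indices = {0, 2},
/// maybeConstants = [4, ?, ?], and values = [%a, %b, %c], this returns the
/// folded affine.apply of `4 * %c`.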
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
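///
/// For example (a hypothetical case, with made-up shapes): collapsing
/// memref<2x?xf32> into memref<?xf32> with reassociation [[0, 1]] gives:
///   collapsedSize#0 = 2 * baseSizes#1   (materialized as an affine.apply)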
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (!ShapedType::isDynamic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}


/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold, respectively, the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
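///
/// For example (a hypothetical case, with made-up shapes): collapsing
/// memref<4x8xf32> (strides [8, 1]) into memref<32xf32> with reassociation
/// [[0, 1]] gives a single stride of 1, i.e., the minimum stride in the
/// group, which is the innermost one when the group is contiguous.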
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = getStridesAndOffset(sourceType);

  SmallVector<OpFoldResult> groupStrides;
  ArrayRef<int64_t> srcShape = sourceType.getShape();
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    groupStrides.push_back(ShapedType::isDynamic(currentStride)
                               ? origStrides[currentDim]
                               : builder.getIndexAttr(currentStride));
  }
  if (groupStrides.empty()) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        getStridesAndOffset(collapsedType);
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  // For the general case, we just want the minimum stride
  // since the collapsed dimensions are contiguous.
  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
                                                  builder.getContext());
  return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
                                      groupStrides)};
}

/// Replace `baseBuffer, offset, sizes, strides =
///             extract_strided_metadata(reshapeLike(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the reshape-like operation in that expression
/// and materialize its effects on the sizes and the strides using affine
/// apply.
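///
/// For illustration, a hypothetical before/after pair (shapes and the exact
/// folded affine maps are made up; the pass may fold them differently):
///
/// \verbatim
/// %r = memref.expand_shape %src [[0, 1]]
///     : memref<?xf32, strided<[?], offset: ?>>
///       into memref<?x4xf32, strided<[?, ?], offset: ?>>
/// \endverbatim
///
/// becomes roughly:
///
/// \verbatim
/// %base, %offset, %size, %stride = memref.extract_strided_metadata %src
/// %sz0 = affine.apply affine_map<()[s0] -> (s0 floordiv 4)>()[%size]
/// %st0 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%stride]
/// %r = memref.reinterpret_cast %base to
///     offset: [%offset], sizes: [%sz0, 4], strides: [%st0, %stride]
///     : memref<f32> to memref<?x4xf32, strided<[?, ?], offset: ?>>
/// \endverbatim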
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    // Build a plain extract_strided_metadata(memref) from
    // extract_strided_metadata(reassociative_reshape_like(memref)).
    Location origLoc = reshape.getLoc();
    Value source = reshape.getSrc();
    auto sourceType = cast<MemRefType>(source.getType());
    unsigned sourceRank = sourceType.getRank();

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

    // Collect statically known information.
    auto [strides, offset] = getStridesAndOffset(sourceType);
    MemRefType reshapeType = reshape.getResultType();
    unsigned reshapeRank = reshapeType.getRank();

    OpFoldResult offsetOfr =
        ShapedType::isDynamic(offset)
            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
            : rewriter.getIndexAttr(offset);

    // Get the special case of 0-D out of the way.
    if (sourceRank == 0) {
      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
      rewriter.replaceOp(reshape, memrefDesc.getResult());
      return success();
    }

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(reshapeRank);
    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(reshapeRank);

    // Compute the reshaped strides and sizes from the base strides and sizes.
    SmallVector<OpFoldResult> origSizes =
        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
    SmallVector<OpFoldResult> origStrides =
        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
    for (; idx != endIdx; ++idx) {
      SmallVector<OpFoldResult> reshapedSizes =
          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

      unsigned groupSize = reshapedSizes.size();
      for (unsigned i = 0; i < groupSize; ++i) {
        finalSizes.push_back(reshapedSizes[i]);
        finalStrides.push_back(reshapedStrides[i]);
      }
    }
    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
           "We should have visited all the input dimensions");
    assert(finalSizes.size() == reshapeRank &&
           "We should have populated all the values");
    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
        offsetOfr, finalSizes, finalStrides);
    rewriter.replaceOp(reshape, memrefDesc.getResult());
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(allocLikeOp)`
///
/// With
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
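///
/// For illustration, a hypothetical example (made-up shapes):
///
/// ```
/// %alloc = memref.alloc(%d) : memref<?x16xf32>
/// %base, %off, %sizes:2, %strides:2 =
///     memref.extract_strided_metadata %alloc
/// ```
///
/// becomes roughly:
///
/// ```
/// %base = memref.reinterpret_cast %alloc to offset: [0], sizes: [],
///     strides: [] : memref<?x16xf32> to memref<f32>
/// %off = arith.constant 0 : index
/// // sizes: %d and 16; strides: 16 and 1 (as constants/values).
/// ```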
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = globalSizes
/// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and identity affine_map for the layout.
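///
/// For illustration, a hypothetical example (made-up global):
///
/// ```
/// %g = memref.get_global @cst : memref<2x3xf32>
/// %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %g
/// ```
///
/// becomes roughly:
///
/// ```
/// %base = memref.reinterpret_cast %g to offset: [0], sizes: [],
///     strides: [] : memref<2x3xf32> to memref<f32>
/// // offset: 0; sizes: 2 and 3; strides: 3 and 1 (all index constants).
/// ```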
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
          loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (auto size : sizes)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));

    for (auto stride : strides)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
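///
/// For example (a hypothetical rewrite; the view-like op is any op
/// implementing ViewLikeOpInterface, such as memref.subview):
///
/// ```
/// %v = memref.subview %src[...] [...] [...] : ...
/// %p = memref.extract_aligned_pointer_as_index %v : ... -> index
/// ```
///
/// becomes
///
/// ```
/// %p = memref.extract_aligned_pointer_as_index %src : ... -> index
/// ```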
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
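///
/// For example (a hypothetical case): since the reinterpret_cast fully
/// overrides the metadata, querying it simply forwards its operands:
///
/// ```
/// %r = memref.reinterpret_cast %src to offset: [%o], sizes: [%s],
///     strides: [%st] : memref<f32> to memref<?xf32, strided<[?], offset: ?>>
/// %base, %off, %sz, %strd = memref.extract_strided_metadata %r
/// ```
///
/// becomes `%base, ... = memref.extract_strided_metadata %src` with
/// `%off = %o`, `%sz = %s`, and `%strd = %st`.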
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
///            ? dstTy.srcOffset
///            : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
///           !srcSize.isDynamic()
///             ? srcSize
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
///             !srcStride.isDynamic()
///               ? srcStride
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides, or compute them directly from `src`.
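///
/// For example (a hypothetical case, with made-up types): given
///
/// ```
/// %dst = memref.cast %src : memref<?x?xf32> to memref<4x?xf32>
/// %base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %dst
/// ```
///
/// the size of dimension 0 becomes the constant 4 taken from the result
/// type, while the size of dimension 1 still comes from
/// `memref.extract_strided_metadata %src`.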
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;
    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset =
///             extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
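///
/// The base buffer is a rank-0 memref with offset 0, so re-extracting its
/// metadata is trivial. For example (hypothetical):
///
/// ```
/// %base, %off, ... = memref.extract_strided_metadata %src
/// %base2, %off2 = memref.extract_strided_metadata %base
/// ```
///
/// folds to `%base2 = %base` and `%off2 = arith.constant 0 : index`.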
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
  return std::make_unique<ExpandStridedMetadataPass>();
}
