//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms such operations into an explicit
/// sequence of operations that models their effect on the different metadata.
/// This pass uses affine constructs to materialize these effects.
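///
/// For example (an illustrative sketch; SSA names and types are hypothetical),
/// a `memref.subview` such as
///   %view = memref.subview %src[%o0, %o1] [%s0, %s1] [%st0, %st1] ...
/// is rewritten into
///   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
///   ... affine.apply ops computing the new offset and strides ...
///   %view = memref.reinterpret_cast %base to <new offset, sizes, strides>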
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}.
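///
/// For instance (illustrative numbers, not taken from an existing test),
/// with baseStrides = [16, 1], baseOffset = 0, subOffsets = [2, 3], and
/// subStrides = [1, 2], this computes
///   strides = [16 * 1, 1 * 2] = [16, 2]
///   offset  = 0 + 2 * 16 + 3 * 1 = 35
/// while the sizes are simply the subview sizes.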
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
#endif // NDEBUG

  // Compute the new strides and offset from the base strides and offset:
  // newStride#i = baseStride#i * subStride#i
  // offset = baseOffset + sum(subOffset#i * baseStride#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
#ifndef NDEBUG
  // Assert that the computed offset matches the offset of the result type of
  // the subview op (if both are static).
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
#endif // NDEBUG

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // them all.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, since subviews can drop some dimensions, some strides and sizes
  // may not end up in the final <base, offset, sizes, strides> value that we
  // are replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

#ifndef NDEBUG
  // Iteration variable for result dimensions of the subview op.
  int64_t j = 0;
#endif // NDEBUG
  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
#ifndef NDEBUG
    // Assert that the computed stride matches the stride of the result type
    // of the subview op (if both are static).
    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
    ++j;
#endif // NDEBUG
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// with
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and materialize
/// its effects on the offset, the sizes, and the strides using affine.apply.
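///
/// For example (an illustrative sketch; SSA names and the printed form are
/// abbreviated and hypothetical):
///
/// \verbatim
/// %dst = memref.subview %src[2, 3] [4, 4] [1, 1]
///     : memref<8x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 35>>
/// \endverbatim
///
/// becomes roughly
///
/// \verbatim
/// %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
/// %dst = memref.reinterpret_cast %base to
///     offset: [35], sizes: [4, 4], strides: [16, 1]
/// \endverbatim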
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// with
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes`, and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);

    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                 for j in group excluding reassIdx#i)
/// where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
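///
/// For instance (illustrative numbers), expanding a dynamic dimension of size
/// %n into a ?x4x2 group yields, for the dynamic member of the group,
///   expandedSize = %n floordiv (4 * 2)
/// while the static members keep their constant sizes 4 and 2.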
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  // expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                    reassIdx#i+1..reassIdx#i+group.size-1)
///
/// where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type
///   of the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
///   element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
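///
/// For instance (illustrative numbers), expanding a dimension of stride
/// origStride into a fully static 2x3x4 group yields
///   strides = [origStride * 12, origStride * 4, origStride * 1]
/// i.e., each expanded stride is origStride times the product of the sizes of
/// the dimensions to its right within the group.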
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = sourceType.getStridesAndOffset();

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic is false,
/// values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes holds the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
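///
/// For instance (illustrative numbers), collapsing a ?x4x2 group whose dynamic
/// size is %n produces a single dimension of size %n * 4 * 2, materialized as
/// an affine.apply on %n.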
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (!ShapedType::isDynamic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
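///
/// For instance (illustrative numbers), collapsing a 4x2 group of a source
/// whose strides for those two dimensions are [8, 1] yields a single stride
/// of 1, i.e., the stride of the innermost non-unit dimension of the group.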
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = sourceType.getStridesAndOffset();

  ArrayRef<int64_t> srcShape = sourceType.getShape();

  OpFoldResult lastValidStride = nullptr;
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    lastValidStride = ShapedType::isDynamic(currentStride)
                          ? origStrides[currentDim]
                          : builder.getIndexAttr(currentStride);
  }
  if (!lastValidStride) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        collapsedType.getStridesAndOffset();
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  return {lastValidStride};
}

/// From `reshape_like(memref, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}.
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
        getReshapedSizes,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/,
        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
        getReshapedStrides) {
  // Build a plain extract_strided_metadata(memref) from
  // extract_strided_metadata(reassociative_reshape_like(memref)).
  Location origLoc = reshape.getLoc();
  Value source = reshape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  // Collect statically known information.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  MemRefType reshapeType = reshape.getResultType();
  unsigned reshapeRank = reshapeType.getRank();

  OpFoldResult offsetOfr =
      ShapedType::isDynamic(offset)
          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
          : rewriter.getIndexAttr(offset);

  // Get the special case of 0-D out of the way.
  if (sourceRank == 0) {
    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                           /*sizes=*/ones, /*strides=*/ones};
  }

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(reshapeRank);
  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(reshapeRank);

  // Compute the reshaped strides and sizes from the base strides and sizes.
  SmallVector<OpFoldResult> origSizes =
      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
  SmallVector<OpFoldResult> origStrides =
      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
  for (; idx != endIdx; ++idx) {
    SmallVector<OpFoldResult> reshapedSizes =
        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

    unsigned groupSize = reshapedSizes.size();
    for (unsigned i = 0; i < groupSize; ++i) {
      finalSizes.push_back(reshapedSizes[i]);
      finalStrides.push_back(reshapedStrides[i]);
    }
  }
  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
         "We should have visited all the input dimensions");
  assert(finalSizes.size() == reshapeRank &&
         "We should have populated all the values");

  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                         finalSizes, finalStrides};
}

/// Replace `baseBuffer, offset, sizes, strides =
///             extract_strided_metadata(reshapeLike(memref))`
/// with
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes = getReshapedSizes(reshapeLike)
/// strides = getReshapedStrides(reshapeLike)
/// \endverbatim
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the reshape-like op in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
            rewriter, reshape, getReshapedSizes, getReshapedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(reshape,
                                         "failed to resolve reshape metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        reshape, reshape.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// with
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes#group = getCollapsedSize(collapse_shape, group)
/// strides#group = getCollapsedStride(collapse_shape, group)
/// offset = baseOffset (unchanged)
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes`, and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpCollapseShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op,
          "failed to resolve metadata in terms of source collapse_shape op");
    }

    Location loc = collapseShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(expand_shape)`
/// with the results of computing the sizes and strides on the expanded shape
/// and dividing up dimensions into static and dynamic parts as needed.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    Location loc = expandShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(allocLikeOp)`
///
/// with
///
/// ```
/// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy>
/// offset = 0
/// sizes = allocSizes
/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
/// ```
///
/// The transformation only applies if the allocLikeOp has been normalized.
/// In other words, the affine_map must be an identity.
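///
/// For example (illustrative numbers), for `%alloc = memref.alloc() :
/// memref<2x3x4xf32>` the metadata becomes offset = 0, sizes = [2, 3, 4], and
/// strides = [3 * 4, 4, 1] = [12, 4, 1], i.e., the suffix products of the
/// allocation sizes.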
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(get_global)`
///
/// with
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = globalSizes
/// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and an identity affine_map for the layout.
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
          loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (auto size : sizes)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));

    for (auto stride : strides)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(assume_alignment)`
///
/// with
/// \verbatim
/// extract_strided_metadata(memref)
/// \endverbatim
///
/// Since `assume_alignment` is a view-like op that does not modify the
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
/// from its result is equivalent to extracting it from its source. This
/// canonicalization removes the unnecessary indirection.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto assumeAlignmentOp =
        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
    if (!assumeAlignmentOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
        op, assumeAlignmentOp.getViewSource());
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// with
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
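///
/// For example (an illustrative sketch, names hypothetical), given
///   %dst = memref.reinterpret_cast %src to
///       offset: [%o], sizes: [10, %n], strides: [%n, 1]
/// an `extract_strided_metadata %dst` is rewritten so that the offset is %o,
/// the sizes are 10 and %n, the strides are %n and 1, and only the base
/// buffer still comes from `extract_strided_metadata %src`.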
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///             extract_strided_metadata(
///                cast(src) to dstTy)`
/// with
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.offset.isDynamic()
///            ? dstTy.offset
///            : extract_strided_metadata(src).offset
/// sizes = for each size#i in dstTy.sizes:
///           !size#i.isDynamic()
///             ? size#i
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each stride#i in dstTy.strides:
///             !stride#i.isDynamic()
///               ? stride#i
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides, or compute them directly from `src`.
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;
    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
///             memory_space_cast(src) to dstTy)`
/// with
/// ```
/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
/// destBaseTy = type(oldBase) with memory space from destTy
/// base = memory_space_cast(oldBase) to destBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
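///
/// For example (an illustrative sketch; the memory space annotations are
/// hypothetical), for
///   %dst = memref.memory_space_cast %src : memref<4xf32> to memref<4xf32, 1>
/// the rewritten base buffer is
///   memref.memory_space_cast %base : memref<f32> to memref<f32, 1>
/// while the offset, sizes, and strides are taken directly from
/// `extract_strided_metadata %src`.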
class ExtractStridedMetadataOpMemorySpaceCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Location loc = extractStridedMetadataOp.getLoc();
    Value source = extractStridedMetadataOp.getSource();
    auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>();
    if (!memSpaceCastOp)
      return failure();
    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, memSpaceCastOp.getSource());
    SmallVector<Value> results(newExtractStridedMetadata.getResults());
    // As with most other strided metadata rewrite patterns, don't introduce
    // a use of the base pointer where none existed. This needs to happen here,
    // as opposed to in later dead-code elimination, because these patterns are
    // sometimes used during dialect conversion (see EmulateNarrowType, for
    // example), so adding spurious usages would cause a pre-legalization value
    // to be live that would be dead had this pattern not run.
    if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) {
      auto baseBuffer = results[0];
      auto baseBufferType = cast<MemRefType>(baseBuffer.getType());
      MemRefType::Builder newTypeBuilder(baseBufferType);
      newTypeBuilder.setMemorySpace(
          memSpaceCastOp.getResult().getType().getMemorySpace());
      results[0] = rewriter.create<memref::MemorySpaceCastOp>(
          loc, Type{newTypeBuilder}, baseBuffer);
    } else {
      results[0] = nullptr;
    }
    rewriter.replaceOp(extractStridedMetadataOp, results);
    return success();
  }
};

/// Replace `base, offset =
///             extract_strided_metadata(extract_strided_metadata(src)#0)`
/// with
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpSubviewFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataPassBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}