1 | //===- ExpandStridedMetadata.cpp - Simplify this operation -------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
/// This pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, it rewrites these operations into explicit sequences of
/// operations that model their effect on the different metadata. The pass
/// uses affine constructs to materialize these effects.
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
17 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
19 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
20 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
21 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
22 | #include "mlir/IR/AffineMap.h" |
23 | #include "mlir/IR/BuiltinTypes.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | #include "llvm/ADT/STLExtras.h" |
26 | #include "llvm/ADT/SmallBitVector.h" |
27 | #include <optional> |
28 | |
29 | namespace mlir { |
30 | namespace memref { |
31 | #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA |
32 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
33 | } // namespace memref |
34 | } // namespace mlir |
35 | |
36 | using namespace mlir; |
37 | using namespace mlir::affine; |
38 | |
39 | namespace { |
40 | |
41 | struct StridedMetadata { |
42 | Value basePtr; |
43 | OpFoldResult offset; |
44 | SmallVector<OpFoldResult> sizes; |
45 | SmallVector<OpFoldResult> strides; |
46 | }; |
47 | |
/// From `subview(memref, subOffset, subSizes, subStrides)` compute
49 | /// |
50 | /// \verbatim |
51 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
52 | /// extract_strided_metadata(memref) |
53 | /// strides#i = baseStrides#i * subStrides#i |
54 | /// offset = baseOffset + sum(subOffset#i * baseStrides#i) |
55 | /// sizes = subSizes |
56 | /// \endverbatim |
57 | /// |
58 | /// and return {baseBuffer, offset, sizes, strides} |
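///
/// As a purely illustrative, hypothetical example (values chosen arbitrarily):
/// with baseOffset = 0, baseStrides = [64, 8, 1], subOffsets = [1, 2, 3], and
/// subStrides = [2, 1, 1], the formulas above give
/// strides = [2*64, 1*8, 1*1] = [128, 8, 1] and
/// offset = 0 + 1*64 + 2*8 + 3*1 = 83.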
59 | static FailureOr<StridedMetadata> |
60 | resolveSubviewStridedMetadata(RewriterBase &rewriter, |
61 | memref::SubViewOp subview) { |
62 | // Build a plain extract_strided_metadata(memref) from subview(memref). |
63 | Location origLoc = subview.getLoc(); |
64 | Value source = subview.getSource(); |
65 | auto sourceType = cast<MemRefType>(source.getType()); |
66 | unsigned sourceRank = sourceType.getRank(); |
67 | |
  auto newExtractStridedMetadata =
69 | rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source); |
70 | |
71 | auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType); |
72 | #ifndef NDEBUG |
73 | auto [resultStrides, resultOffset] = getStridesAndOffset(subview.getType()); |
74 | #endif // NDEBUG |
75 | |
  // Compute the new strides and offset from the base strides and offset:
  // newStride#i = baseStride#i * subStride#i
  // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
79 | SmallVector<OpFoldResult> strides; |
80 | SmallVector<OpFoldResult> subStrides = subview.getMixedStrides(); |
81 | auto origStrides = newExtractStridedMetadata.getStrides(); |
82 | |
83 | // Hold the affine symbols and values for the computation of the offset. |
84 | SmallVector<OpFoldResult> values(2 * sourceRank + 1); |
85 | SmallVector<AffineExpr> symbols(2 * sourceRank + 1); |
86 | |
87 | bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols}); |
88 | AffineExpr expr = symbols.front(); |
89 | values[0] = ShapedType::isDynamic(sourceOffset) |
90 | ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) |
91 | : rewriter.getIndexAttr(sourceOffset); |
92 | SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets(); |
93 | |
  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
96 | for (unsigned i = 0; i < sourceRank; ++i) { |
97 | // Compute the stride. |
98 | OpFoldResult origStride = |
99 | ShapedType::isDynamic(sourceStrides[i]) |
100 | ? origStrides[i] |
101 | : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i])); |
102 | strides.push_back(makeComposedFoldedAffineApply( |
103 | rewriter, origLoc, s0 * s1, {subStrides[i], origStride})); |
104 | |
105 | // Build up the computation of the offset. |
106 | unsigned baseIdxForDim = 1 + 2 * i; |
107 | unsigned subOffsetForDim = baseIdxForDim; |
108 | unsigned origStrideForDim = baseIdxForDim + 1; |
109 | expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim]; |
110 | values[subOffsetForDim] = subOffsets[i]; |
111 | values[origStrideForDim] = origStride; |
112 | } |
113 | |
114 | // Compute the offset. |
115 | OpFoldResult finalOffset = |
116 | makeComposedFoldedAffineApply(rewriter, origLoc, expr, values); |
117 | #ifndef NDEBUG |
118 | // Assert that the computed offset matches the offset of the result type of |
119 | // the subview op (if both are static). |
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
124 | #endif // NDEBUG |
125 | |
  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() values to hold
  // them all.
129 | auto subType = cast<MemRefType>(subview.getType()); |
130 | unsigned subRank = subType.getRank(); |
131 | |
  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover, since subviews can drop dimensions, some strides and sizes may
  // not end up in the final <base, offset, sizes, strides> value that we are
  // replacing.
  // Do the filtering here.
138 | SmallVector<OpFoldResult> subSizes = subview.getMixedSizes(); |
139 | llvm::SmallBitVector droppedDims = subview.getDroppedDims(); |
140 | |
141 | SmallVector<OpFoldResult> finalSizes; |
142 | finalSizes.reserve(subRank); |
143 | |
144 | SmallVector<OpFoldResult> finalStrides; |
145 | finalStrides.reserve(subRank); |
146 | |
147 | #ifndef NDEBUG |
148 | // Iteration variable for result dimensions of the subview op. |
149 | int64_t j = 0; |
150 | #endif // NDEBUG |
151 | for (unsigned i = 0; i < sourceRank; ++i) { |
    if (droppedDims.test(i))
153 | continue; |
154 | |
155 | finalSizes.push_back(subSizes[i]); |
156 | finalStrides.push_back(strides[i]); |
157 | #ifndef NDEBUG |
158 | // Assert that the computed stride matches the stride of the result type of |
159 | // the subview op (if both are static). |
160 | std::optional<int64_t> computedStride = getConstantIntValue(strides[i]); |
161 | if (computedStride && !ShapedType::isDynamic(resultStrides[j])) |
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
164 | ++j; |
165 | #endif // NDEBUG |
166 | } |
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
169 | return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset, |
170 | finalSizes, finalStrides}; |
171 | } |
172 | |
/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
174 | /// With |
175 | /// |
176 | /// \verbatim |
177 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
178 | /// extract_strided_metadata(memref) |
/// strides#i = baseStrides#i * subStrides#i
180 | /// offset = baseOffset + sum(subOffset#i * baseStrides#i) |
181 | /// sizes = subSizes |
182 | /// dst = reinterpret_cast baseBuffer, offset, sizes, strides |
183 | /// \endverbatim |
184 | /// |
/// In other words, get rid of the subview in that expression and materialize
/// its effects on the offset, the sizes, and the strides using affine.apply.
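///
/// For illustration only (hypothetical IR, with types and the exact affine
/// maps elided), a subview like
///
/// \verbatim
/// %dst = memref.subview %src[%o0, %o1] [4, 4] [1, 1] : memref<?x?xf32> ...
/// \endverbatim
///
/// becomes, roughly,
///
/// \verbatim
/// %base, %off, %sz:2, %st:2 = memref.extract_strided_metadata %src
/// %newOff = affine.apply <offset formula>(%o0, %o1, %st#0, %st#1)
/// %dst = memref.reinterpret_cast %base to offset: [%newOff],
///            sizes: [4, 4], strides: [%st#0, %st#1]
/// \endverbatim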
187 | struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> { |
188 | public: |
189 | using OpRewritePattern<memref::SubViewOp>::OpRewritePattern; |
190 | |
191 | LogicalResult matchAndRewrite(memref::SubViewOp subview, |
192 | PatternRewriter &rewriter) const override { |
193 | FailureOr<StridedMetadata> stridedMetadata = |
194 | resolveSubviewStridedMetadata(rewriter, subview); |
195 | if (failed(stridedMetadata)) { |
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
198 | } |
199 | |
200 | rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( |
201 | subview, subview.getType(), stridedMetadata->basePtr, |
202 | stridedMetadata->offset, stridedMetadata->sizes, |
203 | stridedMetadata->strides); |
204 | return success(); |
205 | } |
206 | }; |
207 | |
208 | /// Pattern to replace `extract_strided_metadata(subview)` |
209 | /// With |
210 | /// |
211 | /// \verbatim |
212 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
213 | /// extract_strided_metadata(memref) |
/// strides#i = baseStrides#i * subStrides#i
215 | /// offset = baseOffset + sum(subOffset#i * baseStrides#i) |
216 | /// sizes = subSizes |
/// \endverbatim
218 | /// |
219 | /// with `baseBuffer`, `offset`, `sizes` and `strides` being |
220 | /// the replacements for the original `extract_strided_metadata`. |
struct ExtractStridedMetadataOpSubviewFolder
222 | : OpRewritePattern<memref::ExtractStridedMetadataOp> { |
223 | using OpRewritePattern::OpRewritePattern; |
224 | |
225 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
226 | PatternRewriter &rewriter) const override { |
227 | auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>(); |
228 | if (!subviewOp) |
229 | return failure(); |
230 | |
231 | FailureOr<StridedMetadata> stridedMetadata = |
232 | resolveSubviewStridedMetadata(rewriter, subviewOp); |
233 | if (failed(stridedMetadata)) { |
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
236 | } |
237 | Location loc = subviewOp.getLoc(); |
238 | SmallVector<Value> results; |
239 | results.reserve(subviewOp.getType().getRank() * 2 + 2); |
240 | results.push_back(stridedMetadata->basePtr); |
241 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, |
242 | stridedMetadata->offset)); |
243 | results.append( |
244 | getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); |
245 | results.append(getValueOrCreateConstantIndexOp(rewriter, loc, |
246 | stridedMetadata->strides)); |
247 | rewriter.replaceOp(op, results); |
248 | |
249 | return success(); |
250 | } |
251 | }; |
252 | |
253 | /// Compute the expanded sizes of the given \p expandShape for the |
254 | /// \p groupId-th reassociation group. |
255 | /// \p origSizes hold the sizes of the source shape as values. |
256 | /// This is used to compute the new sizes in cases of dynamic shapes. |
257 | /// |
258 | /// sizes#i = |
259 | /// baseSizes#groupId / product(expandShapeSizes#j, |
260 | /// for j in group excluding reassIdx#i) |
261 | /// Where reassIdx#i is the reassociation index at index i in \p groupId. |
262 | /// |
263 | /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() |
264 | /// |
265 | /// TODO: Move this utility function directly within ExpandShapeOp. For now, |
266 | /// this is not possible because this function uses the Affine dialect and the |
267 | /// MemRef dialect cannot depend on the Affine dialect. |
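///
/// For instance (hypothetical values): if a group expands one source dimension
/// into a result shape of <2 x ? x 4> and the run-time source size is 24, the
/// dynamic expanded size evaluates to 24 floordiv (2 * 4) = 3, so the group
/// sizes are [2, 3, 4].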
268 | static SmallVector<OpFoldResult> |
269 | getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder, |
270 | ArrayRef<OpFoldResult> origSizes, unsigned groupId) { |
271 | SmallVector<int64_t, 2> reassocGroup = |
272 | expandShape.getReassociationIndices()[groupId]; |
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");
275 | |
276 | unsigned groupSize = reassocGroup.size(); |
277 | SmallVector<OpFoldResult> expandedSizes(groupSize); |
278 | |
279 | uint64_t productOfAllStaticSizes = 1; |
280 | std::optional<unsigned> dynSizeIdx; |
281 | MemRefType expandShapeType = expandShape.getResultType(); |
282 | |
283 | // Fill up all the statically known sizes. |
284 | for (unsigned i = 0; i < groupSize; ++i) { |
285 | uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); |
286 | if (ShapedType::isDynamic(dimSize)) { |
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
288 | dynSizeIdx = i; |
289 | continue; |
290 | } |
291 | productOfAllStaticSizes *= dimSize; |
292 | expandedSizes[i] = builder.getIndexAttr(dimSize); |
293 | } |
294 | |
295 | // Compute the dynamic size using the original size and all the other known |
296 | // static sizes: |
297 | // expandSize = origSize / productOfAllStaticSizes. |
298 | if (dynSizeIdx) { |
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
302 | origSizes[groupId]); |
303 | } |
304 | |
305 | return expandedSizes; |
306 | } |
307 | |
308 | /// Compute the expanded strides of the given \p expandShape for the |
309 | /// \p groupId-th reassociation group. |
310 | /// \p origStrides and \p origSizes hold respectively the strides and sizes |
311 | /// of the source shape as values. |
312 | /// This is used to compute the strides in cases of dynamic shapes and/or |
313 | /// dynamic stride for this reassociation group. |
314 | /// |
315 | /// strides#i = |
316 | /// origStrides#reassDim * product(expandShapeSizes#j, for j in |
317 | /// reassIdx#i+1..reassIdx#i+group.size-1) |
318 | /// |
/// Where reassIdx#i is the reassociation index at index i in \p groupId
320 | /// and expandShapeSizes#j is either: |
321 | /// - The constant size at dimension j, derived directly from the result type of |
322 | /// the expand_shape op, or |
323 | /// - An affine expression: baseSizes#reassDim / product of all constant sizes |
324 | /// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic |
325 | /// element.) |
326 | /// |
327 | /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() |
328 | /// |
329 | /// TODO: Move this utility function directly within ExpandShapeOp. For now, |
330 | /// this is not possible because this function uses the Affine dialect and the |
331 | /// MemRef dialect cannot depend on the Affine dialect. |
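///
/// For instance (hypothetical values): for a group expanded into <2 x ? x 4>
/// with an original stride of 2 and a run-time original group size of 24, the
/// dynamic size is 24 floordiv 8 = 3 and the expanded strides are
/// [3*4*2, 4*2, 1*2] = [24, 8, 2].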
332 | SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape, |
333 | OpBuilder &builder, |
334 | ArrayRef<OpFoldResult> origSizes, |
335 | ArrayRef<OpFoldResult> origStrides, |
336 | unsigned groupId) { |
337 | SmallVector<int64_t, 2> reassocGroup = |
338 | expandShape.getReassociationIndices()[groupId]; |
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");
341 | |
342 | unsigned groupSize = reassocGroup.size(); |
343 | MemRefType expandShapeType = expandShape.getResultType(); |
344 | |
345 | std::optional<int64_t> dynSizeIdx; |
346 | |
347 | // Fill up the expanded strides, with the information we can deduce from the |
348 | // resulting shape. |
349 | uint64_t currentStride = 1; |
350 | SmallVector<OpFoldResult> expandedStrides(groupSize); |
351 | for (int i = groupSize - 1; i >= 0; --i) { |
352 | expandedStrides[i] = builder.getIndexAttr(currentStride); |
353 | uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); |
354 | if (ShapedType::isDynamic(dimSize)) { |
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
356 | dynSizeIdx = i; |
357 | continue; |
358 | } |
359 | |
360 | currentStride *= dimSize; |
361 | } |
362 | |
363 | // Collect the statically known information about the original stride. |
364 | Value source = expandShape.getSrc(); |
365 | auto sourceType = cast<MemRefType>(source.getType()); |
366 | auto [strides, offset] = getStridesAndOffset(sourceType); |
367 | |
368 | OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) |
369 | ? origStrides[groupId] |
370 | : builder.getIndexAttr(strides[groupId]); |
371 | |
372 | // Apply the original stride to all the strides. |
373 | int64_t doneStrideIdx = 0; |
374 | // If we saw a dynamic dimension, we need to fix-up all the strides up to |
375 | // that dimension with the dynamic size. |
376 | if (dynSizeIdx) { |
377 | int64_t productOfAllStaticSizes = currentStride; |
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
380 | OpFoldResult origSize = origSizes[groupId]; |
381 | |
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
384 | for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) { |
385 | int64_t baseExpandedStride = |
386 | cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>()) |
387 | .getInt(); |
388 | expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( |
389 | builder, expandShape.getLoc(), |
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
391 | {origSize, origStride}); |
392 | } |
393 | } |
394 | |
395 | // Now apply the origStride to the remaining dimensions. |
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
397 | for (; doneStrideIdx < groupSize; ++doneStrideIdx) { |
398 | int64_t baseExpandedStride = |
399 | cast<IntegerAttr>(expandedStrides[doneStrideIdx].get<Attribute>()) |
400 | .getInt(); |
401 | expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( |
402 | builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride}); |
403 | } |
404 | |
405 | return expandedStrides; |
406 | } |
407 | |
408 | /// Produce an OpFoldResult object with \p builder at \p loc representing |
409 | /// `prod(valueOrConstant#i, for i in {indices})`, |
/// where valueOrConstant#i is maybeConstant[i] when \p isDynamic is false,
411 | /// values[i] otherwise. |
412 | /// |
413 | /// \pre for all index in indices: index < values.size() |
414 | /// \pre for all index in indices: index < maybeConstants.size() |
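///
/// For example (hypothetical values): with indices = {0, 2},
/// maybeConstants = [5, 7, dynamic], and values = [%a, %b, %c], the result is
/// the folded affine product 5 * %c.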
415 | static OpFoldResult |
416 | getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc, |
417 | ArrayRef<int64_t> maybeConstants, |
418 | ArrayRef<OpFoldResult> values, |
419 | llvm::function_ref<bool(int64_t)> isDynamic) { |
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
421 | SmallVector<OpFoldResult> inputValues; |
422 | unsigned numberOfSymbols = 0; |
423 | unsigned groupSize = indices.size(); |
424 | for (unsigned i = 0; i < groupSize; ++i) { |
425 | productOfValues = |
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
427 | unsigned srcIdx = indices[i]; |
428 | int64_t maybeConstant = maybeConstants[srcIdx]; |
429 | |
430 | inputValues.push_back(isDynamic(maybeConstant) |
431 | ? values[srcIdx] |
432 | : builder.getIndexAttr(maybeConstant)); |
433 | } |
434 | |
435 | return makeComposedFoldedAffineApply(builder, loc, productOfValues, |
436 | inputValues); |
437 | } |
438 | |
/// Compute the collapsed size of the given \p collapseShape for the
440 | /// \p groupId-th reassociation group. |
441 | /// \p origSizes hold the sizes of the source shape as values. |
442 | /// This is used to compute the new sizes in cases of dynamic shapes. |
443 | /// |
444 | /// Conceptually this helper function computes: |
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
446 | /// |
/// \post result.size() == 1; in other words, each group collapses to one
/// dimension.
449 | /// |
450 | /// TODO: Move this utility function directly within CollapseShapeOp. For now, |
451 | /// this is not possible because this function uses the Affine dialect and the |
452 | /// MemRef dialect cannot depend on the Affine dialect. |
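///
/// For example (hypothetical values): collapsing a group whose source sizes
/// are <4 x ? x 2>, with a run-time dynamic size of 5, yields 4 * 5 * 2 = 40.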
453 | static SmallVector<OpFoldResult> |
454 | getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, |
455 | ArrayRef<OpFoldResult> origSizes, unsigned groupId) { |
456 | SmallVector<OpFoldResult> collapsedSize; |
457 | |
458 | MemRefType collapseShapeType = collapseShape.getResultType(); |
459 | |
460 | uint64_t size = collapseShapeType.getDimSize(groupId); |
461 | if (!ShapedType::isDynamic(size)) { |
462 | collapsedSize.push_back(builder.getIndexAttr(size)); |
463 | return collapsedSize; |
464 | } |
465 | |
466 | // We are dealing with a dynamic size. |
467 | // Build the affine expr of the product of the original sizes involved in that |
468 | // group. |
469 | Value source = collapseShape.getSrc(); |
470 | auto sourceType = cast<MemRefType>(source.getType()); |
471 | |
472 | SmallVector<int64_t, 2> reassocGroup = |
473 | collapseShape.getReassociationIndices()[groupId]; |
474 | |
475 | collapsedSize.push_back(getProductOfValues( |
476 | reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(), |
477 | origSizes, ShapedType::isDynamic)); |
478 | |
479 | return collapsedSize; |
480 | } |
481 | |
/// Compute the collapsed stride of the given \p collapseShape for the
483 | /// \p groupId-th reassociation group. |
484 | /// \p origStrides and \p origSizes hold respectively the strides and sizes |
485 | /// of the source shape as values. |
486 | /// This is used to compute the strides in cases of dynamic shapes and/or |
487 | /// dynamic stride for this reassociation group. |
488 | /// |
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
491 | /// |
/// \post result.size() == 1; in other words, each group collapses to one
/// dimension.
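///
/// For example (hypothetical values): for a group covering two source
/// dimensions with strides [16, 2], the collapsed stride is min(16, 2) = 2,
/// i.e., the stride of the innermost dimension of the group.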
494 | static SmallVector<OpFoldResult> |
495 | getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, |
496 | ArrayRef<OpFoldResult> origSizes, |
497 | ArrayRef<OpFoldResult> origStrides, unsigned groupId) { |
498 | SmallVector<int64_t, 2> reassocGroup = |
499 | collapseShape.getReassociationIndices()[groupId]; |
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");
502 | |
503 | Value source = collapseShape.getSrc(); |
504 | auto sourceType = cast<MemRefType>(source.getType()); |
505 | |
506 | auto [strides, offset] = getStridesAndOffset(sourceType); |
507 | |
508 | SmallVector<OpFoldResult> groupStrides; |
509 | ArrayRef<int64_t> srcShape = sourceType.getShape(); |
510 | for (int64_t currentDim : reassocGroup) { |
511 | // Skip size-of-1 dimensions, since right now their strides may be |
512 | // meaningless. |
513 | // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless |
514 | // they are truly contiguous. When they are truly contiguous, we shouldn't |
515 | // need to skip them. |
516 | if (srcShape[currentDim] == 1) |
517 | continue; |
518 | |
519 | int64_t currentStride = strides[currentDim]; |
520 | groupStrides.push_back(ShapedType::isDynamic(currentStride) |
521 | ? origStrides[currentDim] |
522 | : builder.getIndexAttr(currentStride)); |
523 | } |
524 | if (groupStrides.empty()) { |
525 | // We're dealing with a 1x1x...x1 shape. The stride is meaningless, |
526 | // but we still have to make the type system happy. |
527 | MemRefType collapsedType = collapseShape.getResultType(); |
528 | auto [collapsedStrides, collapsedOffset] = |
529 | getStridesAndOffset(collapsedType); |
530 | int64_t finalStride = collapsedStrides[groupId]; |
531 | if (ShapedType::isDynamic(finalStride)) { |
532 | // Look for a dynamic stride. At this point we don't know which one is |
533 | // desired, but they are all equally good/bad. |
534 | for (int64_t currentDim : reassocGroup) { |
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");
537 | |
538 | if (ShapedType::isDynamic(strides[currentDim])) |
539 | return {origStrides[currentDim]}; |
540 | } |
      llvm_unreachable("We should have found a dynamic stride");
542 | } |
543 | return {builder.getIndexAttr(finalStride)}; |
544 | } |
545 | |
546 | // For the general case, we just want the minimum stride |
547 | // since the collapsed dimensions are contiguous. |
  auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
                                                  builder.getContext());
550 | return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap, |
551 | groupStrides)}; |
}

/// Replace `baseBuffer, offset, sizes, strides =
554 | /// extract_strided_metadata(reshapeLike(memref))` |
555 | /// With |
556 | /// |
557 | /// \verbatim |
558 | /// baseBuffer, offset, baseSizes, baseStrides = |
559 | /// extract_strided_metadata(memref) |
560 | /// sizes = getReshapedSizes(reshapeLike) |
561 | /// strides = getReshapedStrides(reshapeLike) |
562 | /// \endverbatim |
///
565 | /// Notice that `baseBuffer` and `offset` are unchanged. |
566 | /// |
/// In other words, get rid of the reshape-like op in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
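///
/// For instance (hypothetical values): expanding memref<12x4xf32>, whose
/// strides are [4, 1], with reassociation [[0, 1], [2]] into
/// memref<3x4x4xf32> produces sizes = [3, 4, 4] and strides = [16, 4, 1],
/// which feed the replacement reinterpret_cast.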
569 | template <typename ReassociativeReshapeLikeOp, |
570 | SmallVector<OpFoldResult> (*getReshapedSizes)( |
571 | ReassociativeReshapeLikeOp, OpBuilder &, |
572 | ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/), |
573 | SmallVector<OpFoldResult> (*getReshapedStrides)( |
574 | ReassociativeReshapeLikeOp, OpBuilder &, |
575 | ArrayRef<OpFoldResult> /*origSizes*/, |
576 | ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)> |
577 | struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> { |
578 | public: |
579 | using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern; |
580 | |
581 | LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape, |
582 | PatternRewriter &rewriter) const override { |
583 | // Build a plain extract_strided_metadata(memref) from |
584 | // extract_strided_metadata(reassociative_reshape_like(memref)). |
585 | Location origLoc = reshape.getLoc(); |
586 | Value source = reshape.getSrc(); |
587 | auto sourceType = cast<MemRefType>(source.getType()); |
588 | unsigned sourceRank = sourceType.getRank(); |
589 | |
    auto newExtractStridedMetadata =
591 | rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source); |
592 | |
593 | // Collect statically known information. |
594 | auto [strides, offset] = getStridesAndOffset(sourceType); |
595 | MemRefType reshapeType = reshape.getResultType(); |
596 | unsigned reshapeRank = reshapeType.getRank(); |
597 | |
598 | OpFoldResult offsetOfr = |
599 | ShapedType::isDynamic(offset) |
600 | ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) |
601 | : rewriter.getIndexAttr(offset); |
602 | |
603 | // Get the special case of 0-D out of the way. |
604 | if (sourceRank == 0) { |
605 | SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1)); |
606 | auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>( |
607 | origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), |
608 | offsetOfr, /*sizes=*/ones, /*strides=*/ones); |
609 | rewriter.replaceOp(reshape, memrefDesc.getResult()); |
610 | return success(); |
611 | } |
612 | |
613 | SmallVector<OpFoldResult> finalSizes; |
614 | finalSizes.reserve(reshapeRank); |
615 | SmallVector<OpFoldResult> finalStrides; |
616 | finalStrides.reserve(reshapeRank); |
617 | |
618 | // Compute the reshaped strides and sizes from the base strides and sizes. |
619 | SmallVector<OpFoldResult> origSizes = |
620 | getAsOpFoldResult(newExtractStridedMetadata.getSizes()); |
621 | SmallVector<OpFoldResult> origStrides = |
622 | getAsOpFoldResult(newExtractStridedMetadata.getStrides()); |
623 | unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); |
624 | for (; idx != endIdx; ++idx) { |
625 | SmallVector<OpFoldResult> reshapedSizes = |
626 | getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); |
627 | SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides( |
628 | reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); |
629 | |
630 | unsigned groupSize = reshapedSizes.size(); |
631 | for (unsigned i = 0; i < groupSize; ++i) { |
632 | finalSizes.push_back(reshapedSizes[i]); |
633 | finalStrides.push_back(reshapedStrides[i]); |
634 | } |
635 | } |
    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
           "We should have visited all the input dimensions");
    assert(finalSizes.size() == reshapeRank &&
           "We should have populated all the values");
641 | auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>( |
642 | origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), |
643 | offsetOfr, finalSizes, finalStrides); |
644 | rewriter.replaceOp(reshape, memrefDesc.getResult()); |
645 | return success(); |
646 | } |
647 | }; |
648 | |
649 | /// Replace `base, offset, sizes, strides = |
650 | /// extract_strided_metadata(allocLikeOp)` |
651 | /// |
652 | /// With |
653 | /// |
654 | /// ``` |
655 | /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy> |
656 | /// offset = 0 |
657 | /// sizes = allocSizes |
658 | /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1}) |
659 | /// ``` |
660 | /// |
661 | /// The transformation only applies if the allocLikeOp has been normalized. |
662 | /// In other words, the affine_map must be an identity. |
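///
/// For example (hypothetical values): for `memref.alloc() : memref<2x3x4xf32>`
/// the replacement metadata is offset = 0, sizes = [2, 3, 4], and
/// strides = [3*4, 4, 1] = [12, 4, 1].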
663 | template <typename AllocLikeOp> |
struct ExtractStridedMetadataOpAllocFolder
665 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
666 | public: |
667 | using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern; |
668 | |
669 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
670 | PatternRewriter &rewriter) const override { |
671 | auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>(); |
672 | if (!allocLikeOp) |
673 | return failure(); |
674 | |
675 | auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType()); |
676 | if (!memRefType.getLayout().isIdentity()) |
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");
679 | |
680 | Location loc = op.getLoc(); |
681 | int rank = memRefType.getRank(); |
682 | |
683 | // Collect the sizes. |
684 | ValueRange dynamic = allocLikeOp.getDynamicSizes(); |
685 | SmallVector<OpFoldResult> sizes; |
686 | sizes.reserve(rank); |
687 | unsigned dynamicPos = 0; |
688 | for (int64_t size : memRefType.getShape()) { |
689 | if (ShapedType::isDynamic(size)) |
690 | sizes.push_back(dynamic[dynamicPos++]); |
691 | else |
692 | sizes.push_back(rewriter.getIndexAttr(size)); |
693 | } |
694 | |
695 | // Strides (just creates identity strides). |
696 | SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); |
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
698 | unsigned symbolNumber = 0; |
699 | for (int i = rank - 2; i >= 0; --i) { |
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
703 | ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber); |
704 | strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr, |
705 | sizesInvolvedInStride); |
706 | } |
707 | |
708 | // Put all the values together to replace the results. |
709 | SmallVector<Value> results; |
710 | results.reserve(rank * 2 + 2); |
711 | |
712 | auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType()); |
713 | int64_t offset = 0; |
714 | if (op.getBaseBuffer().use_empty()) { |
715 | results.push_back(nullptr); |
716 | } else { |
717 | if (allocLikeOp.getType() == baseBufferType) |
718 | results.push_back(allocLikeOp); |
719 | else |
720 | results.push_back(rewriter.create<memref::ReinterpretCastOp>( |
721 | loc, baseBufferType, allocLikeOp, offset, |
722 | /*sizes=*/ArrayRef<int64_t>(), |
723 | /*strides=*/ArrayRef<int64_t>())); |
724 | } |
725 | |
726 | // Offset. |
727 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset)); |
728 | |
729 | for (OpFoldResult size : sizes) |
730 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size)); |
731 | |
732 | for (OpFoldResult stride : strides) |
733 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride)); |
734 | |
735 | rewriter.replaceOp(op, results); |
736 | return success(); |
737 | } |
738 | }; |
739 | |
740 | /// Replace `base, offset, sizes, strides = |
741 | /// extract_strided_metadata(get_global)` |
742 | /// |
743 | /// With |
744 | /// |
745 | /// ``` |
746 | /// base = reinterpret_cast get_global to a flat memref<eltTy> |
747 | /// offset = 0 |
/// sizes = globalSizes (i.e., the static shape of the global)
/// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
750 | /// ``` |
751 | /// |
752 | /// It is expected that the memref.get_global op has static shapes |
753 | /// and identity affine_map for the layout. |
struct ExtractStridedMetadataOpGetGlobalFolder
755 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
756 | public: |
757 | using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern; |
758 | |
759 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
760 | PatternRewriter &rewriter) const override { |
761 | auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>(); |
762 | if (!getGlobalOp) |
763 | return failure(); |
764 | |
765 | auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType()); |
766 | if (!memRefType.getLayout().isIdentity()) { |
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
770 | } |
771 | |
772 | Location loc = op.getLoc(); |
773 | int rank = memRefType.getRank(); |
774 | |
775 | // Collect the sizes. |
776 | ArrayRef<int64_t> sizes = memRefType.getShape(); |
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");
779 | |
780 | // Strides (just creates identity strides). |
781 | SmallVector<int64_t> strides = computeSuffixProduct(sizes); |
782 | |
783 | // Put all the values together to replace the results. |
784 | SmallVector<Value> results; |
785 | results.reserve(rank * 2 + 2); |
786 | |
787 | auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType()); |
788 | int64_t offset = 0; |
789 | if (getGlobalOp.getType() == baseBufferType) |
790 | results.push_back(getGlobalOp); |
791 | else |
792 | results.push_back(rewriter.create<memref::ReinterpretCastOp>( |
793 | loc, baseBufferType, getGlobalOp, offset, |
794 | /*sizes=*/ArrayRef<int64_t>(), |
795 | /*strides=*/ArrayRef<int64_t>())); |
796 | |
797 | // Offset. |
798 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset)); |
799 | |
800 | for (auto size : sizes) |
801 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size)); |
802 | |
803 | for (auto stride : strides) |
804 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride)); |
805 | |
806 | rewriter.replaceOp(op, results); |
807 | return success(); |
808 | } |
809 | }; |
810 | |
811 | /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the |
812 | /// source of the ViewLikeOp. |
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
814 | : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> { |
815 | using OpRewritePattern::OpRewritePattern; |
816 | |
817 | LogicalResult |
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
819 | PatternRewriter &rewriter) const override { |
820 | auto viewLikeOp = |
821 | extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); |
822 | if (!viewLikeOp) |
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
824 | rewriter.modifyOpInPlace(extractOp, [&]() { |
825 | extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); |
826 | }); |
827 | return success(); |
828 | } |
829 | }; |
830 | |
831 | /// Replace `base, offset, sizes, strides = |
832 | /// extract_strided_metadata( |
833 | /// reinterpret_cast(src, srcOffset, srcSizes, srcStrides))` |
834 | /// With |
835 | /// ``` |
836 | /// base, ... = extract_strided_metadata(src) |
837 | /// offset = srcOffset |
838 | /// sizes = srcSizes |
839 | /// strides = srcStrides |
840 | /// ``` |
841 | /// |
842 | /// In other words, consume the `reinterpret_cast` and apply its effects |
843 | /// on the offset, sizes, and strides. |
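///
/// For illustration only (hypothetical IR, types elided): given
///
/// \verbatim
/// %m = memref.reinterpret_cast %src to offset: [%o], sizes: [10, %d],
///          strides: [%s, 1]
/// %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %m
/// \endverbatim
///
/// only the base buffer still needs an extract_strided_metadata on %src;
/// %offset, %sizes, and %strides are replaced by %o, (10, %d), and (%s, 1)
/// respectively, with static values materialized as constants.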
class ExtractStridedMetadataOpReinterpretCastFolder
845 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
846 | using OpRewritePattern::OpRewritePattern; |
847 | |
848 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
850 | PatternRewriter &rewriter) const override { |
851 | auto reinterpretCastOp = extractStridedMetadataOp.getSource() |
852 | .getDefiningOp<memref::ReinterpretCastOp>(); |
853 | if (!reinterpretCastOp) |
854 | return failure(); |
855 | |
856 | Location loc = extractStridedMetadataOp.getLoc(); |
857 | // Check if the source is suitable for extract_strided_metadata. |
858 | SmallVector<Type> inferredReturnTypes; |
859 | if (failed(extractStridedMetadataOp.inferReturnTypes( |
860 | rewriter.getContext(), loc, {reinterpretCastOp.getSource()}, |
861 | /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, |
862 | inferredReturnTypes))) |
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");
865 | |
866 | auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType()); |
867 | unsigned rank = memrefType.getRank(); |
868 | SmallVector<OpFoldResult> results; |
869 | results.resize_for_overwrite(rank * 2 + 2); |
870 | |
    auto newExtractStridedMetadata =
872 | rewriter.create<memref::ExtractStridedMetadataOp>( |
873 | loc, reinterpretCastOp.getSource()); |
874 | |
875 | // Register the base_buffer. |
876 | results[0] = newExtractStridedMetadata.getBaseBuffer(); |
877 | |
878 | // Register the new offset. |
879 | results[1] = getValueOrCreateConstantIndexOp( |
880 | rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]); |
881 | |
882 | const unsigned sizeStartIdx = 2; |
883 | const unsigned strideStartIdx = sizeStartIdx + rank; |
884 | |
885 | SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes(); |
886 | SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides(); |
887 | for (unsigned i = 0; i < rank; ++i) { |
888 | results[sizeStartIdx + i] = sizes[i]; |
889 | results[strideStartIdx + i] = strides[i]; |
890 | } |
891 | rewriter.replaceOp(extractStridedMetadataOp, |
892 | getValueOrCreateConstantIndexOp(rewriter, loc, results)); |
893 | return success(); |
894 | } |
895 | }; |
896 | |
897 | /// Replace `base, offset, sizes, strides = |
898 | /// extract_strided_metadata( |
899 | /// cast(src) to dstTy)` |
900 | /// With |
901 | /// ``` |
902 | /// base, ... = extract_strided_metadata(src) |
903 | /// offset = !dstTy.srcOffset.isDynamic() |
904 | /// ? dstTy.srcOffset |
905 | /// : extract_strided_metadata(src).offset |
906 | /// sizes = for each srcSize in dstTy.srcSizes: |
907 | /// !srcSize.isDynamic() |
908 | /// ? srcSize |
///             : extract_strided_metadata(src).sizes[i]
910 | /// strides = for each srcStride in dstTy.srcStrides: |
911 | /// !srcStrides.isDynamic() |
912 | /// ? srcStrides |
913 | /// : extract_strided_metadata(src).strides[i] |
914 | /// ``` |
915 | /// |
916 | /// In other words, consume the `cast` and apply its effects |
917 | /// on the offset, sizes, and strides or compute them directly from `src`. |
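///
/// For example (hypothetical types): for a cast from
/// memref<?x?xf32, strided<[?, 1], offset: ?>> to
/// memref<4x?xf32, strided<[?, 1], offset: 8>>, the offset (8), the first
/// size (4), and the inner stride (1) come from the destination type, while
/// the remaining size and stride come from extract_strided_metadata(src).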
class ExtractStridedMetadataOpCastFolder
919 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
920 | using OpRewritePattern::OpRewritePattern; |
921 | |
922 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
924 | PatternRewriter &rewriter) const override { |
925 | Value source = extractStridedMetadataOp.getSource(); |
926 | auto castOp = source.getDefiningOp<memref::CastOp>(); |
927 | if (!castOp) |
928 | return failure(); |
929 | |
930 | Location loc = extractStridedMetadataOp.getLoc(); |
931 | // Check if the source is suitable for extract_strided_metadata. |
932 | SmallVector<Type> inferredReturnTypes; |
933 | if (failed(extractStridedMetadataOp.inferReturnTypes( |
934 | rewriter.getContext(), loc, {castOp.getSource()}, |
935 | /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, |
936 | inferredReturnTypes))) |
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");
939 | |
940 | auto memrefType = cast<MemRefType>(source.getType()); |
941 | unsigned rank = memrefType.getRank(); |
942 | SmallVector<OpFoldResult> results; |
943 | results.resize_for_overwrite(rank * 2 + 2); |
944 | |
    auto newExtractStridedMetadata =
946 | rewriter.create<memref::ExtractStridedMetadataOp>(loc, |
947 | castOp.getSource()); |
948 | |
949 | // Register the base_buffer. |
950 | results[0] = newExtractStridedMetadata.getBaseBuffer(); |
951 | |
952 | auto getConstantOrValue = [&rewriter](int64_t constant, |
953 | OpFoldResult ofr) -> OpFoldResult { |
954 | return !ShapedType::isDynamic(constant) |
955 | ? OpFoldResult(rewriter.getIndexAttr(constant)) |
956 | : ofr; |
957 | }; |
958 | |
959 | auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType); |
    assert(sourceStrides.size() == rank && "unexpected number of strides");
961 | |
962 | // Register the new offset. |
963 | results[1] = |
964 | getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset()); |
965 | |
966 | const unsigned sizeStartIdx = 2; |
967 | const unsigned strideStartIdx = sizeStartIdx + rank; |
968 | ArrayRef<int64_t> sourceSizes = memrefType.getShape(); |
969 | |
    SmallVector<OpFoldResult> sizes =
        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
    SmallVector<OpFoldResult> strides =
        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
972 | for (unsigned i = 0; i < rank; ++i) { |
973 | results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]); |
974 | results[strideStartIdx + i] = |
975 | getConstantOrValue(sourceStrides[i], strides[i]); |
976 | } |
977 | rewriter.replaceOp(extractStridedMetadataOp, |
978 | getValueOrCreateConstantIndexOp(rewriter, loc, results)); |
979 | return success(); |
980 | } |
981 | }; |
982 | |
983 | /// Replace `base, offset = |
984 | /// extract_strided_metadata(extract_strided_metadata(src)#0)` |
985 | /// With |
986 | /// ``` |
987 | /// base, ... = extract_strided_metadata(src) |
988 | /// offset = 0 |
989 | /// ``` |
class ExtractStridedMetadataOpExtractStridedMetadataFolder
991 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
992 | using OpRewritePattern::OpRewritePattern; |
993 | |
994 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
996 | PatternRewriter &rewriter) const override { |
    auto sourceExtractStridedMetadataOp =
998 | extractStridedMetadataOp.getSource() |
999 | .getDefiningOp<memref::ExtractStridedMetadataOp>(); |
1000 | if (!sourceExtractStridedMetadataOp) |
1001 | return failure(); |
1002 | Location loc = extractStridedMetadataOp.getLoc(); |
1003 | rewriter.replaceOp(extractStridedMetadataOp, |
1004 | {sourceExtractStridedMetadataOp.getBaseBuffer(), |
1005 | getValueOrCreateConstantIndexOp( |
1006 | rewriter, loc, rewriter.getIndexAttr(0))}); |
1007 | return success(); |
1008 | } |
1009 | }; |
1010 | } // namespace |
1011 | |
1012 | void memref::populateExpandStridedMetadataPatterns( |
1013 | RewritePatternSet &patterns) { |
1014 | patterns.add<SubviewFolder, |
1015 | ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes, |
1016 | getExpandedStrides>, |
1017 | ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize, |
1018 | getCollapsedStride>, |
1019 | ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, |
1020 | ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, |
1021 | ExtractStridedMetadataOpGetGlobalFolder, |
1022 | RewriteExtractAlignedPointerAsIndexOfViewLikeOp, |
1023 | ExtractStridedMetadataOpReinterpretCastFolder, |
1024 | ExtractStridedMetadataOpCastFolder, |
1025 | ExtractStridedMetadataOpExtractStridedMetadataFolder>( |
1026 | patterns.getContext()); |
1027 | } |
1028 | |
void memref::populateResolveExtractStridedMetadataPatterns(
1030 | RewritePatternSet &patterns) { |
1031 | patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, |
1032 | ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, |
1033 | ExtractStridedMetadataOpGetGlobalFolder, |
1034 | ExtractStridedMetadataOpSubviewFolder, |
1035 | RewriteExtractAlignedPointerAsIndexOfViewLikeOp, |
1036 | ExtractStridedMetadataOpReinterpretCastFolder, |
1037 | ExtractStridedMetadataOpCastFolder, |
1038 | ExtractStridedMetadataOpExtractStridedMetadataFolder>( |
      patterns.getContext());
1040 | } |
1041 | |
1042 | //===----------------------------------------------------------------------===// |
1043 | // Pass registration |
1044 | //===----------------------------------------------------------------------===// |
1045 | |
1046 | namespace { |
1047 | |
1048 | struct ExpandStridedMetadataPass final |
1049 | : public memref::impl::ExpandStridedMetadataBase< |
1050 | ExpandStridedMetadataPass> { |
1051 | void runOnOperation() override; |
1052 | }; |
1053 | |
1054 | } // namespace |
1055 | |
1056 | void ExpandStridedMetadataPass::runOnOperation() { |
1057 | RewritePatternSet patterns(&getContext()); |
1058 | memref::populateExpandStridedMetadataPatterns(patterns); |
1059 | (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); |
1060 | } |
1061 | |
1062 | std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() { |
1063 | return std::make_unique<ExpandStridedMetadataPass>(); |
1064 | } |
1065 | |