//===- ExpandStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// The pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms each such operation into an explicit
/// sequence of operations that model its effect on the different metadata.
/// This pass uses affine constructs to materialize these effects.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <optional>

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

namespace {

struct StridedMetadata {
  Value basePtr;
  OpFoldResult offset;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};

/// From `subview(memref, subOffset, subSizes, subStrides)` compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// and return {baseBuffer, offset, sizes, strides}
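///
/// For example (illustrative, not taken from a test): for
/// `subview %src[2, 4] [8, 2] [1, 1]` where `%src` is a `memref<16x8xf32>`
/// (strides [8, 1], offset 0), this computes
/// offset = 0 + 2 * 8 + 4 * 1 = 20, strides = [8 * 1, 1 * 1] = [8, 1], and
/// sizes = [8, 2].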
static FailureOr<StridedMetadata>
resolveSubviewStridedMetadata(RewriterBase &rewriter,
                              memref::SubViewOp subview) {
  // Build a plain extract_strided_metadata(memref) from subview(memref).
  Location origLoc = subview.getLoc();
  Value source = subview.getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
#ifndef NDEBUG
  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
#endif // NDEBUG

  // Compute the new strides and offset from the base strides and offset:
  // newStride#i = baseStride#i * subStride#i
  // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
  SmallVector<OpFoldResult> strides;
  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
  auto origStrides = newExtractStridedMetadata.getStrides();

  // Hold the affine symbols and values for the computation of the offset.
  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);

  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
  AffineExpr expr = symbols.front();
  values[0] = ShapedType::isDynamic(sourceOffset)
                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                  : rewriter.getIndexAttr(sourceOffset);
  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (unsigned i = 0; i < sourceRank; ++i) {
    // Compute the stride.
    OpFoldResult origStride =
        ShapedType::isDynamic(sourceStrides[i])
            ? origStrides[i]
            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
    strides.push_back(makeComposedFoldedAffineApply(
        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));

    // Build up the computation of the offset.
    unsigned baseIdxForDim = 1 + 2 * i;
    unsigned subOffsetForDim = baseIdxForDim;
    unsigned origStrideForDim = baseIdxForDim + 1;
    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
    values[subOffsetForDim] = subOffsets[i];
    values[origStrideForDim] = origStride;
  }

  // Compute the offset.
  OpFoldResult finalOffset =
      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
#ifndef NDEBUG
  // Assert that the computed offset matches the offset of the result type of
  // the subview op (if both are static).
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
#endif // NDEBUG

  // The final result is <baseBuffer, offset, sizes, strides>.
  // Thus we need 1 + 1 + subview.getRank() + subview.getRank() to hold all
  // the values.
  auto subType = cast<MemRefType>(subview.getType());
  unsigned subRank = subType.getRank();

  // The sizes of the final type are defined directly by the input sizes of
  // the subview.
  // Moreover subviews can drop some dimensions, so some strides and sizes may
  // not end up in the final <base, offset, sizes, strides> value that we are
  // replacing.
  // Do the filtering here.
  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
  llvm::SmallBitVector droppedDims = subview.getDroppedDims();

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(subRank);

  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(subRank);

#ifndef NDEBUG
  // Iteration variable for result dimensions of the subview op.
  int64_t j = 0;
#endif // NDEBUG
  for (unsigned i = 0; i < sourceRank; ++i) {
    if (droppedDims.test(i))
      continue;

    finalSizes.push_back(subSizes[i]);
    finalStrides.push_back(strides[i]);
#ifndef NDEBUG
    // Assert that the computed stride matches the stride of the result type of
    // the subview op (if both are static).
    std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
    if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
      assert(*computedStride == resultStrides[j] &&
             "mismatch between computed stride and result type stride");
    ++j;
#endif // NDEBUG
  }
  assert(finalSizes.size() == subRank &&
         "Should have populated all the values at this point");
  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
                         finalSizes, finalStrides};
}

/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// dst = reinterpret_cast baseBuffer, offset, sizes, strides
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and materialize
/// its effects on the offset, the sizes, and the strides using affine.apply.
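///
/// A minimal sketch of the rewrite (illustrative IR; value names and the
/// exact composed affine map are invented, the real ones are produced by
/// makeComposedFoldedAffineApply):
///
/// \verbatim
/// %sv = memref.subview %src[%o0, %o1] [4, 4] [1, 1] : ...
/// \endverbatim
///
/// becomes, roughly:
///
/// \verbatim
/// %base, %off, %szs:2, %strs:2 = memref.extract_strided_metadata %src : ...
/// %newOff = affine.apply affine_map<()[s0, s1, s2, s3, s4]
///     -> (s0 + s1 * s2 + s3 * s4)>()[%off, %o0, %strs#0, %o1, %strs#1]
/// %res = memref.reinterpret_cast %base to offset: [%newOff],
///     sizes: [4, 4], strides: [%strs#0, %strs#1] : ...
/// \endverbatim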
struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subview);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(subview,
                                         "failed to resolve subview metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        subview, subview.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(subview)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpSubviewFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveSubviewStridedMetadata(rewriter, subviewOp);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source subview op");
    }
    Location loc = subviewOp.getLoc();
    SmallVector<Value> results;
    results.reserve(subviewOp.getType().getRank() * 2 + 2);
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);

    return success();
  }
};

/// Compute the expanded sizes of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// sizes#i =
///     baseSizes#groupId / product(expandShapeSizes#j,
///                                 for j in group excluding reassIdx#i)
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
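///
/// For example (illustrative): expanding a source dimension of size %s into a
/// group with result shape [?, 4] yields sizes = [%s floordiv 4, 4]; the
/// single dynamic size is recovered by dividing the original size by the
/// product of the static sizes in the group.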
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  SmallVector<OpFoldResult> expandedSizes(groupSize);

  uint64_t productOfAllStaticSizes = 1;
  std::optional<unsigned> dynSizeIdx;
  MemRefType expandShapeType = expandShape.getResultType();

  // Fill up all the statically known sizes.
  for (unsigned i = 0; i < groupSize; ++i) {
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }
    productOfAllStaticSizes *= dimSize;
    expandedSizes[i] = builder.getIndexAttr(dimSize);
  }

  // Compute the dynamic size using the original size and all the other known
  // static sizes:
  // expandSize = origSize / productOfAllStaticSizes.
  if (dynSizeIdx) {
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
  }

  return expandedSizes;
}

/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// strides#i =
///     origStrides#reassDim * product(expandShapeSizes#j, for j in
///                                    reassIdx#i+1..reassIdx#i+group.size-1)
///
/// Where reassIdx#i is the reassociation index at index i in \p groupId
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type
///   of the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
///   element.)
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
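///
/// For example (illustrative): expanding a source dimension of size %s and
/// stride %st into a group with result shape [2, ?, 4] yields
/// strides = [(%s floordiv 2) * %st, 4 * %st, %st]; the stride of the
/// dimension to the left of the dynamic one has to be fixed up with the
/// dynamic size (%s floordiv 8) recovered from the original size.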
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                             OpBuilder &builder,
                                             ArrayRef<OpFoldResult> origSizes,
                                             ArrayRef<OpFoldResult> origStrides,
                                             unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  std::optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  auto [strides, offset] = sourceType.getStridesAndOffset();

  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                ? origStrides[groupId]
                                : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride =
          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
              .getInt();
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride =
        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
            .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}

/// Produce an OpFoldResult object with \p builder at \p loc representing
/// `prod(valueOrConstant#i, for i in {indices})`,
/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic evaluates to
/// false, values[i] otherwise.
///
/// \pre for all index in indices: index < values.size()
/// \pre for all index in indices: index < maybeConstants.size()
static OpFoldResult
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
                   ArrayRef<int64_t> maybeConstants,
                   ArrayRef<OpFoldResult> values,
                   llvm::function_ref<bool(int64_t)> isDynamic) {
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
  SmallVector<OpFoldResult> inputValues;
  unsigned numberOfSymbols = 0;
  unsigned groupSize = indices.size();
  for (unsigned i = 0; i < groupSize; ++i) {
    productOfValues =
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
    unsigned srcIdx = indices[i];
    int64_t maybeConstant = maybeConstants[srcIdx];

    inputValues.push_back(isDynamic(maybeConstant)
                              ? values[srcIdx]
                              : builder.getIndexAttr(maybeConstant));
  }

  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
                                       inputValues);
}

/// Compute the collapsed size of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origSizes hold the sizes of the source shape as values.
/// This is used to compute the new sizes in cases of dynamic shapes.
///
/// Conceptually this helper function computes:
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
///
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
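///
/// For example (illustrative): collapsing a group of source sizes [%s, 4, 2]
/// produces the single size %s * 8; when the collapsed dimension is static in
/// the result type, that static size is used directly instead.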
static SmallVector<OpFoldResult>
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<OpFoldResult> collapsedSize;

  MemRefType collapseShapeType = collapseShape.getResultType();

  uint64_t size = collapseShapeType.getDimSize(groupId);
  if (!ShapedType::isDynamic(size)) {
    collapsedSize.push_back(builder.getIndexAttr(size));
    return collapsedSize;
  }

  // We are dealing with a dynamic size.
  // Build the affine expr of the product of the original sizes involved in
  // that group.
  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];

  collapsedSize.push_back(getProductOfValues(
      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
      origSizes, ShapedType::isDynamic));

  return collapsedSize;
}

/// Compute the collapsed stride of the given \p collapseShape for the
/// \p groupId-th reassociation group.
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values.
/// This is used to compute the strides in cases of dynamic shapes and/or
/// dynamic stride for this reassociation group.
///
/// Conceptually this helper function returns the stride of the innermost
/// dimension of that group in the original shape.
///
/// \post result.size() == 1, in other words, each group collapses to one
/// dimension.
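///
/// For example (illustrative): for a group covering source dimensions with
/// sizes [%s, 4, 1] and strides [%st0, %st1, %st2], the collapsed stride is
/// %st1, i.e., the stride of the innermost dimension whose size is not 1.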
static SmallVector<OpFoldResult>
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      collapseShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  Value source = collapseShape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());

  auto [strides, offset] = sourceType.getStridesAndOffset();

  ArrayRef<int64_t> srcShape = sourceType.getShape();

  OpFoldResult lastValidStride = nullptr;
  for (int64_t currentDim : reassocGroup) {
    // Skip size-of-1 dimensions, since right now their strides may be
    // meaningless.
    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
    // they are truly contiguous. When they are truly contiguous, we shouldn't
    // need to skip them.
    if (srcShape[currentDim] == 1)
      continue;

    int64_t currentStride = strides[currentDim];
    lastValidStride = ShapedType::isDynamic(currentStride)
                          ? origStrides[currentDim]
                          : builder.getIndexAttr(currentStride);
  }
  if (!lastValidStride) {
    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
    // but we still have to make the type system happy.
    MemRefType collapsedType = collapseShape.getResultType();
    auto [collapsedStrides, collapsedOffset] =
        collapsedType.getStridesAndOffset();
    int64_t finalStride = collapsedStrides[groupId];
    if (ShapedType::isDynamic(finalStride)) {
      // Look for a dynamic stride. At this point we don't know which one is
      // desired, but they are all equally good/bad.
      for (int64_t currentDim : reassocGroup) {
        assert(srcShape[currentDim] == 1 &&
               "We should be dealing with 1x1x...x1");

        if (ShapedType::isDynamic(strides[currentDim]))
          return {origStrides[currentDim]};
      }
      llvm_unreachable("We should have found a dynamic stride");
    }
    return {builder.getIndexAttr(finalStride)};
  }

  return {lastValidStride};
}

/// From `reshape_like(memref)` (i.e., expand_shape or collapse_shape) compute
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// sizes#i = getReshapedSizes(reshape_like, group#i)
/// strides#i = getReshapedStrides(reshape_like, group#i)
/// \endverbatim
///
/// and return {baseBuffer, baseOffset, sizes, strides}
template <typename ReassociativeReshapeLikeOp>
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
    RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
        getReshapedSizes,
    function_ref<SmallVector<OpFoldResult>(
        ReassociativeReshapeLikeOp, OpBuilder &,
        ArrayRef<OpFoldResult> /*origSizes*/,
        ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
        getReshapedStrides) {
  // Build a plain extract_strided_metadata(memref) from
  // extract_strided_metadata(reassociative_reshape_like(memref)).
  Location origLoc = reshape.getLoc();
  Value source = reshape.getSrc();
  auto sourceType = cast<MemRefType>(source.getType());
  unsigned sourceRank = sourceType.getRank();

  auto newExtractStridedMetadata =
      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

  // Collect statically known information.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  MemRefType reshapeType = reshape.getResultType();
  unsigned reshapeRank = reshapeType.getRank();

  OpFoldResult offsetOfr =
      ShapedType::isDynamic(offset)
          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
          : rewriter.getIndexAttr(offset);

  // Get the special case of 0-D out of the way.
  if (sourceRank == 0) {
    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                           /*sizes=*/ones, /*strides=*/ones};
  }

  SmallVector<OpFoldResult> finalSizes;
  finalSizes.reserve(reshapeRank);
  SmallVector<OpFoldResult> finalStrides;
  finalStrides.reserve(reshapeRank);

  // Compute the reshaped strides and sizes from the base strides and sizes.
  SmallVector<OpFoldResult> origSizes =
      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
  SmallVector<OpFoldResult> origStrides =
      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
  for (; idx != endIdx; ++idx) {
    SmallVector<OpFoldResult> reshapedSizes =
        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

    unsigned groupSize = reshapedSizes.size();
    for (unsigned i = 0; i < groupSize; ++i) {
      finalSizes.push_back(reshapedSizes[i]);
      finalStrides.push_back(reshapedStrides[i]);
    }
  }
  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
         "We should have visited all the input dimensions");
  assert(finalSizes.size() == reshapeRank &&
         "We should have populated all the values");

  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
                         finalSizes, finalStrides};
}

631 | |
632 | /// Replace `baseBuffer, offset, sizes, strides = |
633 | /// extract_strided_metadata(reshapeLike(memref))` |
634 | /// With |
635 | /// |
636 | /// \verbatim |
637 | /// baseBuffer, offset, baseSizes, baseStrides = |
638 | /// extract_strided_metadata(memref) |
639 | /// sizes = getReshapedSizes(reshapeLike) |
640 | /// strides = getReshapedStrides(reshapeLike) |
641 | /// \endverbatim |
642 | /// |
643 | /// |
644 | /// Notice that `baseBuffer` and `offset` are unchanged. |
645 | /// |
646 | /// In other words, get rid of the expand_shape in that expression and |
647 | /// materialize its effects on the sizes and the strides using affine apply. |
template <typename ReassociativeReshapeLikeOp,
          SmallVector<OpFoldResult> (*getReshapedSizes)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
          SmallVector<OpFoldResult> (*getReshapedStrides)(
              ReassociativeReshapeLikeOp, OpBuilder &,
              ArrayRef<OpFoldResult> /*origSizes*/,
              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
public:
  using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                PatternRewriter &rewriter) const override {
    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
            rewriter, reshape, getReshapedSizes, getReshapedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(reshape,
                                         "failed to resolve reshape metadata");
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        reshape, reshape.getType(), stridedMetadata->basePtr,
        stridedMetadata->offset, stridedMetadata->sizes,
        stridedMetadata->strides);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
///     extract_strided_metadata(memref)
/// offset = baseOffset
/// sizes#i = product(baseSizes#j, for j in reassociation group #i)
/// strides#i = baseStrides#k, with k the innermost non-unit-size dimension
///             of reassociation group #i
/// \endverbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpCollapseShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
            rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op,
          "failed to resolve metadata in terms of source collapse_shape op");
    }

    Location loc = collapseShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(expand_shape)`
/// with the results of computing the sizes and strides on the expanded shape
/// and dividing up dimensions into static and dynamic parts as needed.
struct ExtractStridedMetadataOpExpandShapeFolder
    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    FailureOr<StridedMetadata> stridedMetadata =
        resolveReshapeStridedMetadata<memref::ExpandShapeOp>(
            rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides);
    if (failed(stridedMetadata)) {
      return rewriter.notifyMatchFailure(
          op, "failed to resolve metadata in terms of source expand_shape op");
    }

    Location loc = expandShapeOp.getLoc();
    SmallVector<Value> results;
    results.push_back(stridedMetadata->basePtr);
    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      stridedMetadata->offset));
    results.append(
        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
                                                   stridedMetadata->strides));
    rewriter.replaceOp(op, results);
    return success();
  }
};

759 | |
760 | /// Replace `base, offset, sizes, strides = |
761 | /// extract_strided_metadata(allocLikeOp)` |
762 | /// |
763 | /// With |
764 | /// |
765 | /// ``` |
766 | /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy> |
767 | /// offset = 0 |
768 | /// sizes = allocSizes |
769 | /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1}) |
770 | /// ``` |
771 | /// |
772 | /// The transformation only applies if the allocLikeOp has been normalized. |
773 | /// In other words, the affine_map must be an identity. |
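///
/// For example (illustrative): for `%a = memref.alloc(%d) : memref<4x?x8xf32>`
/// this yields offset = 0, sizes = [4, %d, 8], and strides = [%d * 8, 8, 1],
/// with the dynamic stride materialized through an affine.apply on %d.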
template <typename AllocLikeOp>
struct ExtractStridedMetadataOpAllocFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>();
    if (!allocLikeOp)
      return failure();

    auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity())
      return rewriter.notifyMatchFailure(
          allocLikeOp, "alloc-like operations should have been normalized");

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ValueRange dynamic = allocLikeOp.getDynamicSizes();
    SmallVector<OpFoldResult> sizes;
    sizes.reserve(rank);
    unsigned dynamicPos = 0;
    for (int64_t size : memRefType.getShape()) {
      if (ShapedType::isDynamic(size))
        sizes.push_back(dynamic[dynamicPos++]);
      else
        sizes.push_back(rewriter.getIndexAttr(size));
    }

    // Strides (just creates identity strides).
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
    unsigned symbolNumber = 0;
    for (int i = rank - 2; i >= 0; --i) {
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
      assert(i + 1 + symbolNumber == sizes.size() &&
             "The ArrayRef should encompass the last #symbolNumber sizes");
      ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber);
      strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr,
                                                 sizesInvolvedInStride);
    }

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (op.getBaseBuffer().use_empty()) {
      results.push_back(nullptr);
    } else {
      if (allocLikeOp.getType() == baseBufferType)
        results.push_back(allocLikeOp);
      else
        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
            loc, baseBufferType, allocLikeOp, offset,
            /*sizes=*/ArrayRef<int64_t>(),
            /*strides=*/ArrayRef<int64_t>()));
    }

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (OpFoldResult size : sizes)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size));

    for (OpFoldResult stride : strides)
      results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///            extract_strided_metadata(get_global)`
///
/// With
///
/// ```
/// base = reinterpret_cast get_global to a flat memref<eltTy>
/// offset = 0
/// sizes = globalSizes
/// strides#i = prod(globalSizes#j, for j in {i+1..rank-1})
/// ```
///
/// It is expected that the memref.get_global op has static shapes
/// and identity affine_map for the layout.
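///
/// For example (illustrative): for a global of type `memref<16x32xf32>`, this
/// yields offset = 0, sizes = [16, 32], and strides = [32, 1], all emitted as
/// constant index values.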
struct ExtractStridedMetadataOpGetGlobalFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp)
      return failure();

    auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType());
    if (!memRefType.getLayout().isIdentity()) {
      return rewriter.notifyMatchFailure(
          getGlobalOp,
          "get-global operation result should have been normalized");
    }

    Location loc = op.getLoc();
    int rank = memRefType.getRank();

    // Collect the sizes.
    ArrayRef<int64_t> sizes = memRefType.getShape();
    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
           "unexpected dynamic shape for result of `memref.get_global` op");

    // Strides (just creates identity strides).
    SmallVector<int64_t> strides = computeSuffixProduct(sizes);

    // Put all the values together to replace the results.
    SmallVector<Value> results;
    results.reserve(rank * 2 + 2);

    auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
    int64_t offset = 0;
    if (getGlobalOp.getType() == baseBufferType)
      results.push_back(getGlobalOp);
    else
      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
          loc, baseBufferType, getGlobalOp, offset,
          /*sizes=*/ArrayRef<int64_t>(),
          /*strides=*/ArrayRef<int64_t>()));

    // Offset.
    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

    for (auto size : sizes)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));

    for (auto stride : strides)
      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Pattern to replace `extract_strided_metadata(assume_alignment)`
///
/// With
/// \verbatim
/// extract_strided_metadata(memref)
/// \endverbatim
///
/// Since `assume_alignment` is a view-like op that does not modify the
/// underlying buffer, offset, sizes, or strides, extracting strided metadata
/// from its result is equivalent to extracting it from its source. This
/// canonicalization removes the unnecessary indirection.
struct ExtractStridedMetadataOpAssumeAlignmentFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                PatternRewriter &rewriter) const override {
    auto assumeAlignmentOp =
        op.getSource().getDefiningOp<memref::AssumeAlignmentOp>();
    if (!assumeAlignmentOp)
      return failure();

    rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>(
        op, assumeAlignmentOp.getViewSource());
    return success();
  }
};

/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
/// source of the ViewLikeOp.
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
    : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
                  PatternRewriter &rewriter) const override {
    auto viewLikeOp =
        extractOp.getSource().getDefiningOp<ViewLikeOpInterface>();
    if (!viewLikeOp)
      return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source");
    rewriter.modifyOpInPlace(extractOp, [&]() {
      extractOp.getSourceMutable().assign(viewLikeOp.getViewSource());
    });
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///            extract_strided_metadata(
///              reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = srcOffset
/// sizes = srcSizes
/// strides = srcStrides
/// ```
///
/// In other words, consume the `reinterpret_cast` and apply its effects
/// on the offset, sizes, and strides.
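///
/// A minimal sketch (illustrative IR, value names invented):
///
/// ```
/// %rc = memref.reinterpret_cast %buf to offset: [%o], sizes: [%n, 4],
///     strides: [%s, 1] : memref<f32> to memref<?x4xf32, strided<[?, 1], offset: ?>>
/// %base, %off, %szs:2, %strs:2 = memref.extract_strided_metadata %rc : ...
/// ```
///
/// is rewritten so that `%base` comes from `extract_strided_metadata %buf`,
/// `%off` is `%o`, the sizes are `[%n, 4]`, and the strides are `[%s, 1]`.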
class ExtractStridedMetadataOpReinterpretCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
                                 .getDefiningOp<memref::ReinterpretCastOp>();
    if (!reinterpretCastOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {reinterpretCastOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(
          reinterpretCastOp, "reinterpret_cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(
            loc, reinterpretCastOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    // Register the new offset.
    results[1] = getValueOrCreateConstantIndexOp(
        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;

    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = sizes[i];
      results[strideStartIdx + i] = strides[i];
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides =
///            extract_strided_metadata(
///              cast(src) to dstTy)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = !dstTy.srcOffset.isDynamic()
///            ? dstTy.srcOffset
///            : extract_strided_metadata(src).offset
/// sizes = for each srcSize in dstTy.srcSizes:
///           !srcSize.isDynamic()
///             ? srcSize
///             : extract_strided_metadata(src).sizes[i]
/// strides = for each srcStride in dstTy.srcStrides:
///             !srcStride.isDynamic()
///               ? srcStride
///               : extract_strided_metadata(src).strides[i]
/// ```
///
/// In other words, consume the `cast` and apply its effects
/// on the offset, sizes, and strides or compute them directly from `src`.
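///
/// For example (illustrative): casting a
/// `memref<?x?xf32, strided<[?, 1], offset: ?>>` to
/// `memref<4x?xf32, strided<[8, 1], offset: 0>>` lets the rewrite use the
/// constants 4, 8, and 0 for the first size, the first stride, and the
/// offset; only the remaining dynamic size is forwarded from
/// `extract_strided_metadata(src)`.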
class ExtractStridedMetadataOpCastFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    Value source = extractStridedMetadataOp.getSource();
    auto castOp = source.getDefiningOp<memref::CastOp>();
    if (!castOp)
      return failure();

    Location loc = extractStridedMetadataOp.getLoc();
    // Check if the source is suitable for extract_strided_metadata.
    SmallVector<Type> inferredReturnTypes;
    if (failed(extractStridedMetadataOp.inferReturnTypes(
            rewriter.getContext(), loc, {castOp.getSource()},
            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
            inferredReturnTypes)))
      return rewriter.notifyMatchFailure(castOp,
                                         "cast source's type is incompatible");

    auto memrefType = cast<MemRefType>(source.getType());
    unsigned rank = memrefType.getRank();
    SmallVector<OpFoldResult> results;
    results.resize_for_overwrite(rank * 2 + 2);

    auto newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc,
                                                          castOp.getSource());

    // Register the base_buffer.
    results[0] = newExtractStridedMetadata.getBaseBuffer();

    auto getConstantOrValue = [&rewriter](int64_t constant,
                                          OpFoldResult ofr) -> OpFoldResult {
      return !ShapedType::isDynamic(constant)
                 ? OpFoldResult(rewriter.getIndexAttr(constant))
                 : ofr;
    };

    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
    assert(sourceStrides.size() == rank && "unexpected number of strides");

    // Register the new offset.
    results[1] =
        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());

    const unsigned sizeStartIdx = 2;
    const unsigned strideStartIdx = sizeStartIdx + rank;
    ArrayRef<int64_t> sourceSizes = memrefType.getShape();

    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
    for (unsigned i = 0; i < rank; ++i) {
      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
      results[strideStartIdx + i] =
          getConstantOrValue(sourceStrides[i], strides[i]);
    }
    rewriter.replaceOp(extractStridedMetadataOp,
                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
    return success();
  }
};

/// Replace `base, offset, sizes, strides = extract_strided_metadata(
///            memory_space_cast(src) to dstTy)`
/// with
/// ```
/// oldBase, offset, sizes, strides = extract_strided_metadata(src)
/// destBaseTy = type(oldBase) with memory space from destTy
/// base = memory_space_cast(oldBase) to destBaseTy
/// ```
///
/// In other words, propagate metadata extraction across memory space casts.
1132 | /// In other words, propagate metadata extraction accross memory space casts. |
1133 | class ExtractStridedMetadataOpMemorySpaceCastFolder |
1134 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
1135 | using OpRewritePattern::OpRewritePattern; |
1136 | |
1137 | LogicalResult |
1138 | matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp, |
1139 | PatternRewriter &rewriter) const override { |
1140 | Location loc = extractStridedMetadataOp.getLoc(); |
1141 | Value source = extractStridedMetadataOp.getSource(); |
1142 | auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>(); |
1143 | if (!memSpaceCastOp) |
1144 | return failure(); |
1145 | auto newExtractStridedMetadata = |
1146 | rewriter.create<memref::ExtractStridedMetadataOp>( |
1147 | loc, memSpaceCastOp.getSource()); |
1148 | SmallVector<Value> results(newExtractStridedMetadata.getResults()); |
1149 | // As with most other strided metadata rewrite patterns, don't introduce |
1150 | // a use of the base pointer where non existed. This needs to happen here, |
1151 | // as opposed to in later dead-code elimination, because these patterns are |
1152 | // sometimes used during dialect conversion (see EmulateNarrowType, for |
1153 | // example), so adding spurious usages would cause a pre-legalization value |
1154 | // to be live that would be dead had this pattern not run. |
1155 | if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) { |
1156 | auto baseBuffer = results[0]; |
1157 | auto baseBufferType = cast<MemRefType>(baseBuffer.getType()); |
1158 | MemRefType::Builder newTypeBuilder(baseBufferType); |
1159 | newTypeBuilder.setMemorySpace( |
1160 | memSpaceCastOp.getResult().getType().getMemorySpace()); |
1161 | results[0] = rewriter.create<memref::MemorySpaceCastOp>( |
1162 | loc, Type{newTypeBuilder}, baseBuffer); |
1163 | } else { |
1164 | results[0] = nullptr; |
1165 | } |
1166 | rewriter.replaceOp(extractStridedMetadataOp, results); |
1167 | return success(); |
1168 | } |
1169 | }; |
1170 | |
/// Replace `base, offset =
///            extract_strided_metadata(extract_strided_metadata(src)#0)`
/// With
/// ```
/// base, ... = extract_strided_metadata(src)
/// offset = 0
/// ```
class ExtractStridedMetadataOpExtractStridedMetadataFolder
    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
                  PatternRewriter &rewriter) const override {
    auto sourceExtractStridedMetadataOp =
        extractStridedMetadataOp.getSource()
            .getDefiningOp<memref::ExtractStridedMetadataOp>();
    if (!sourceExtractStridedMetadataOp)
      return failure();
    Location loc = extractStridedMetadataOp.getLoc();
    rewriter.replaceOp(extractStridedMetadataOp,
                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
                        getValueOrCreateConstantIndexOp(
                            rewriter, loc, rewriter.getIndexAttr(0))});
    return success();
  }
};
} // namespace

void memref::populateExpandStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SubviewFolder,
               ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes,
                             getExpandedStrides>,
               ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize,
                             getCollapsedStride>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpSubviewFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

void memref::populateResolveExtractStridedMetadataPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
               ExtractStridedMetadataOpCollapseShapeFolder,
               ExtractStridedMetadataOpExpandShapeFolder,
               ExtractStridedMetadataOpGetGlobalFolder,
               ExtractStridedMetadataOpSubviewFolder,
               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
               ExtractStridedMetadataOpReinterpretCastFolder,
               ExtractStridedMetadataOpCastFolder,
               ExtractStridedMetadataOpMemorySpaceCastFolder,
               ExtractStridedMetadataOpAssumeAlignmentFolder,
               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct ExpandStridedMetadataPass final
    : public memref::impl::ExpandStridedMetadataPassBase<
          ExpandStridedMetadataPass> {
  void runOnOperation() override;
};

} // namespace

void ExpandStridedMetadataPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateExpandStridedMetadataPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}