| 1 | //===- ExpandStridedMetadata.cpp - Simplify this operation -------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
/// This pass expands memref operations that modify the metadata of a memref
/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs.
/// In particular, this pass transforms operations into an explicit sequence of
/// operations that model the effect of each operation on the different
/// metadata. This pass uses affine constructs to materialize these effects.
| 14 | //===----------------------------------------------------------------------===// |
| 15 | |
| 16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 17 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 19 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
| 20 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
| 21 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 22 | #include "mlir/IR/AffineMap.h" |
| 23 | #include "mlir/IR/BuiltinTypes.h" |
| 24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 25 | #include "llvm/ADT/STLExtras.h" |
| 26 | #include "llvm/ADT/SmallBitVector.h" |
| 27 | #include <optional> |
| 28 | |
| 29 | namespace mlir { |
| 30 | namespace memref { |
| 31 | #define GEN_PASS_DEF_EXPANDSTRIDEDMETADATAPASS |
| 32 | #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" |
| 33 | } // namespace memref |
| 34 | } // namespace mlir |
| 35 | |
| 36 | using namespace mlir; |
| 37 | using namespace mlir::affine; |
| 38 | |
| 39 | namespace { |
| 40 | |
| 41 | struct StridedMetadata { |
| 42 | Value basePtr; |
| 43 | OpFoldResult offset; |
| 44 | SmallVector<OpFoldResult> sizes; |
| 45 | SmallVector<OpFoldResult> strides; |
| 46 | }; |
| 47 | |
/// From `subview(memref, subOffset, subSizes, subStrides)` compute
| 49 | /// |
| 50 | /// \verbatim |
| 51 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
| 52 | /// extract_strided_metadata(memref) |
| 53 | /// strides#i = baseStrides#i * subStrides#i |
| 54 | /// offset = baseOffset + sum(subOffset#i * baseStrides#i) |
| 55 | /// sizes = subSizes |
| 56 | /// \endverbatim |
| 57 | /// |
| 58 | /// and return {baseBuffer, offset, sizes, strides} |
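///
/// As an illustration (a hedged numeric sketch, not taken from an actual
/// test): with baseOffset = 10, baseStrides = [16, 1], subOffsets = [2, 3],
/// subStrides = [2, 1], and subSizes = [4, 4], the formulas above yield
///   strides = [16 * 2, 1 * 1] = [32, 1]
///   offset  = 10 + 2 * 16 + 3 * 1 = 45
///   sizes   = [4, 4]
/// with any dynamic parts materialized through affine.apply.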
| 59 | static FailureOr<StridedMetadata> |
| 60 | resolveSubviewStridedMetadata(RewriterBase &rewriter, |
| 61 | memref::SubViewOp subview) { |
| 62 | // Build a plain extract_strided_metadata(memref) from subview(memref). |
| 63 | Location origLoc = subview.getLoc(); |
| 64 | Value source = subview.getSource(); |
| 65 | auto sourceType = cast<MemRefType>(source.getType()); |
| 66 | unsigned sourceRank = sourceType.getRank(); |
| 67 | |
  auto newExtractStridedMetadata =
| 69 | rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source); |
| 70 | |
| 71 | auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); |
| 72 | #ifndef NDEBUG |
| 73 | auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset(); |
| 74 | #endif // NDEBUG |
| 75 | |
| 76 | // Compute the new strides and offset from the base strides and offset: |
| 77 | // newStride#i = baseStride#i * subStride#i |
  // offset = baseOffset + sum(subOffset#i * baseStride#i)
| 79 | SmallVector<OpFoldResult> strides; |
| 80 | SmallVector<OpFoldResult> subStrides = subview.getMixedStrides(); |
| 81 | auto origStrides = newExtractStridedMetadata.getStrides(); |
| 82 | |
| 83 | // Hold the affine symbols and values for the computation of the offset. |
| 84 | SmallVector<OpFoldResult> values(2 * sourceRank + 1); |
| 85 | SmallVector<AffineExpr> symbols(2 * sourceRank + 1); |
| 86 | |
| 87 | bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols}); |
| 88 | AffineExpr expr = symbols.front(); |
| 89 | values[0] = ShapedType::isDynamic(sourceOffset) |
| 90 | ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) |
| 91 | : rewriter.getIndexAttr(sourceOffset); |
| 92 | SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets(); |
| 93 | |
  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
| 96 | for (unsigned i = 0; i < sourceRank; ++i) { |
| 97 | // Compute the stride. |
| 98 | OpFoldResult origStride = |
| 99 | ShapedType::isDynamic(sourceStrides[i]) |
| 100 | ? origStrides[i] |
| 101 | : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i])); |
| 102 | strides.push_back(makeComposedFoldedAffineApply( |
| 103 | rewriter, origLoc, s0 * s1, {subStrides[i], origStride})); |
| 104 | |
| 105 | // Build up the computation of the offset. |
| 106 | unsigned baseIdxForDim = 1 + 2 * i; |
| 107 | unsigned subOffsetForDim = baseIdxForDim; |
| 108 | unsigned origStrideForDim = baseIdxForDim + 1; |
| 109 | expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim]; |
| 110 | values[subOffsetForDim] = subOffsets[i]; |
| 111 | values[origStrideForDim] = origStride; |
| 112 | } |
| 113 | |
| 114 | // Compute the offset. |
| 115 | OpFoldResult finalOffset = |
| 116 | makeComposedFoldedAffineApply(rewriter, origLoc, expr, values); |
| 117 | #ifndef NDEBUG |
| 118 | // Assert that the computed offset matches the offset of the result type of |
| 119 | // the subview op (if both are static). |
  std::optional<int64_t> computedOffset = getConstantIntValue(finalOffset);
  if (computedOffset && !ShapedType::isDynamic(resultOffset))
    assert(*computedOffset == resultOffset &&
           "mismatch between computed offset and result type offset");
| 124 | #endif // NDEBUG |
| 125 | |
| 126 | // The final result is <baseBuffer, offset, sizes, strides>. |
| 127 | // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all |
| 128 | // the values. |
| 129 | auto subType = cast<MemRefType>(subview.getType()); |
| 130 | unsigned subRank = subType.getRank(); |
| 131 | |
| 132 | // The sizes of the final type are defined directly by the input sizes of |
| 133 | // the subview. |
  // Moreover, because subviews can drop some dimensions, some strides and
  // sizes may not end up in the final <base, offset, sizes, strides> value
  // that we are replacing.
| 137 | // Do the filtering here. |
| 138 | SmallVector<OpFoldResult> subSizes = subview.getMixedSizes(); |
| 139 | llvm::SmallBitVector droppedDims = subview.getDroppedDims(); |
| 140 | |
| 141 | SmallVector<OpFoldResult> finalSizes; |
| 142 | finalSizes.reserve(subRank); |
| 143 | |
| 144 | SmallVector<OpFoldResult> finalStrides; |
| 145 | finalStrides.reserve(subRank); |
| 146 | |
| 147 | #ifndef NDEBUG |
| 148 | // Iteration variable for result dimensions of the subview op. |
| 149 | int64_t j = 0; |
| 150 | #endif // NDEBUG |
| 151 | for (unsigned i = 0; i < sourceRank; ++i) { |
    if (droppedDims.test(i))
| 153 | continue; |
| 154 | |
| 155 | finalSizes.push_back(subSizes[i]); |
| 156 | finalStrides.push_back(strides[i]); |
| 157 | #ifndef NDEBUG |
| 158 | // Assert that the computed stride matches the stride of the result type of |
| 159 | // the subview op (if both are static). |
| 160 | std::optional<int64_t> computedStride = getConstantIntValue(strides[i]); |
| 161 | if (computedStride && !ShapedType::isDynamic(resultStrides[j])) |
| 162 | assert(*computedStride == resultStrides[j] && |
| 163 | "mismatch between computed stride and result type stride" ); |
| 164 | ++j; |
| 165 | #endif // NDEBUG |
| 166 | } |
| 167 | assert(finalSizes.size() == subRank && |
| 168 | "Should have populated all the values at this point" ); |
| 169 | return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset, |
| 170 | finalSizes, finalStrides}; |
| 171 | } |
| 172 | |
/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)`
| 174 | /// With |
| 175 | /// |
| 176 | /// \verbatim |
| 177 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
| 178 | /// extract_strided_metadata(memref) |
/// strides#i = baseStrides#i * subStrides#i
| 180 | /// offset = baseOffset + sum(subOffset#i * baseStrides#i) |
| 181 | /// sizes = subSizes |
| 182 | /// dst = reinterpret_cast baseBuffer, offset, sizes, strides |
| 183 | /// \endverbatim |
| 184 | /// |
| 185 | /// In other words, get rid of the subview in that expression and canonicalize |
| 186 | /// on its effects on the offset, the sizes, and the strides using affine.apply. |
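///
/// A rough IR sketch of the rewrite on a hypothetical 1-D subview (the exact
/// affine maps and foldings produced by the pattern may differ):
///
/// \verbatim
/// %dst = memref.subview %src[%o][%sz][%st]
///     : memref<?xf32, strided<[?], offset: ?>> to
///       memref<?xf32, strided<[?], offset: ?>>
/// \endverbatim
///
/// becomes
///
/// \verbatim
/// %base, %offset, %size, %stride = memref.extract_strided_metadata %src
/// %new_stride = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%st, %stride]
/// %new_offset = affine.apply
///     affine_map<()[s0, s1, s2] -> (s0 + s1 * s2)>()[%offset, %o, %stride]
/// %dst = memref.reinterpret_cast %base to
///     offset: [%new_offset], sizes: [%sz], strides: [%new_stride]
/// \endverbatim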
| 187 | struct SubviewFolder : public OpRewritePattern<memref::SubViewOp> { |
| 188 | public: |
| 189 | using OpRewritePattern<memref::SubViewOp>::OpRewritePattern; |
| 190 | |
| 191 | LogicalResult matchAndRewrite(memref::SubViewOp subview, |
| 192 | PatternRewriter &rewriter) const override { |
| 193 | FailureOr<StridedMetadata> stridedMetadata = |
| 194 | resolveSubviewStridedMetadata(rewriter, subview); |
| 195 | if (failed(stridedMetadata)) { |
| 196 | return rewriter.notifyMatchFailure(subview, |
| 197 | "failed to resolve subview metadata" ); |
| 198 | } |
| 199 | |
| 200 | rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( |
| 201 | subview, subview.getType(), stridedMetadata->basePtr, |
| 202 | stridedMetadata->offset, stridedMetadata->sizes, |
| 203 | stridedMetadata->strides); |
| 204 | return success(); |
| 205 | } |
| 206 | }; |
| 207 | |
| 208 | /// Pattern to replace `extract_strided_metadata(subview)` |
| 209 | /// With |
| 210 | /// |
| 211 | /// \verbatim |
| 212 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
| 213 | /// extract_strided_metadata(memref) |
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \endverbatim
| 218 | /// |
| 219 | /// with `baseBuffer`, `offset`, `sizes` and `strides` being |
| 220 | /// the replacements for the original `extract_strided_metadata`. |
struct ExtractStridedMetadataOpSubviewFolder
| 222 | : OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 223 | using OpRewritePattern::OpRewritePattern; |
| 224 | |
| 225 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
| 226 | PatternRewriter &rewriter) const override { |
| 227 | auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>(); |
| 228 | if (!subviewOp) |
| 229 | return failure(); |
| 230 | |
| 231 | FailureOr<StridedMetadata> stridedMetadata = |
| 232 | resolveSubviewStridedMetadata(rewriter, subviewOp); |
| 233 | if (failed(stridedMetadata)) { |
| 234 | return rewriter.notifyMatchFailure( |
| 235 | op, "failed to resolve metadata in terms of source subview op" ); |
| 236 | } |
| 237 | Location loc = subviewOp.getLoc(); |
| 238 | SmallVector<Value> results; |
| 239 | results.reserve(subviewOp.getType().getRank() * 2 + 2); |
| 240 | results.push_back(stridedMetadata->basePtr); |
| 241 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, |
| 242 | stridedMetadata->offset)); |
| 243 | results.append( |
| 244 | getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); |
| 245 | results.append(getValueOrCreateConstantIndexOp(rewriter, loc, |
| 246 | stridedMetadata->strides)); |
| 247 | rewriter.replaceOp(op, results); |
| 248 | |
| 249 | return success(); |
| 250 | } |
| 251 | }; |
| 252 | |
| 253 | /// Compute the expanded sizes of the given \p expandShape for the |
| 254 | /// \p groupId-th reassociation group. |
| 255 | /// \p origSizes hold the sizes of the source shape as values. |
| 256 | /// This is used to compute the new sizes in cases of dynamic shapes. |
| 257 | /// |
| 258 | /// sizes#i = |
| 259 | /// baseSizes#groupId / product(expandShapeSizes#j, |
| 260 | /// for j in group excluding reassIdx#i) |
| 261 | /// Where reassIdx#i is the reassociation index at index i in \p groupId. |
| 262 | /// |
| 263 | /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() |
| 264 | /// |
| 265 | /// TODO: Move this utility function directly within ExpandShapeOp. For now, |
| 266 | /// this is not possible because this function uses the Affine dialect and the |
| 267 | /// MemRef dialect cannot depend on the Affine dialect. |
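///
/// For example (a hedged sketch): expanding a dynamic source dimension of
/// size %n into a group with result shape [2, ?, 3] yields
///   expandedSizes = [2, %n floordiv 6, 3]
/// i.e. the single dynamic result dimension receives the base size divided
/// by the product of the static sizes in the group.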
| 268 | static SmallVector<OpFoldResult> |
| 269 | getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder, |
| 270 | ArrayRef<OpFoldResult> origSizes, unsigned groupId) { |
| 271 | SmallVector<int64_t, 2> reassocGroup = |
| 272 | expandShape.getReassociationIndices()[groupId]; |
| 273 | assert(!reassocGroup.empty() && |
| 274 | "Reassociation group should have at least one dimension" ); |
| 275 | |
| 276 | unsigned groupSize = reassocGroup.size(); |
| 277 | SmallVector<OpFoldResult> expandedSizes(groupSize); |
| 278 | |
| 279 | uint64_t productOfAllStaticSizes = 1; |
| 280 | std::optional<unsigned> dynSizeIdx; |
| 281 | MemRefType expandShapeType = expandShape.getResultType(); |
| 282 | |
| 283 | // Fill up all the statically known sizes. |
| 284 | for (unsigned i = 0; i < groupSize; ++i) { |
| 285 | uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); |
| 286 | if (ShapedType::isDynamic(dimSize)) { |
| 287 | assert(!dynSizeIdx && "There must be at most one dynamic size per group" ); |
| 288 | dynSizeIdx = i; |
| 289 | continue; |
| 290 | } |
| 291 | productOfAllStaticSizes *= dimSize; |
| 292 | expandedSizes[i] = builder.getIndexAttr(dimSize); |
| 293 | } |
| 294 | |
| 295 | // Compute the dynamic size using the original size and all the other known |
| 296 | // static sizes: |
| 297 | // expandSize = origSize / productOfAllStaticSizes. |
| 298 | if (dynSizeIdx) { |
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
        origSizes[groupId]);
| 303 | } |
| 304 | |
| 305 | return expandedSizes; |
| 306 | } |
| 307 | |
| 308 | /// Compute the expanded strides of the given \p expandShape for the |
| 309 | /// \p groupId-th reassociation group. |
| 310 | /// \p origStrides and \p origSizes hold respectively the strides and sizes |
| 311 | /// of the source shape as values. |
| 312 | /// This is used to compute the strides in cases of dynamic shapes and/or |
| 313 | /// dynamic stride for this reassociation group. |
| 314 | /// |
| 315 | /// strides#i = |
| 316 | /// origStrides#reassDim * product(expandShapeSizes#j, for j in |
| 317 | /// reassIdx#i+1..reassIdx#i+group.size-1) |
| 318 | /// |
/// Where reassIdx#i is the reassociation index at index i in \p groupId
| 320 | /// and expandShapeSizes#j is either: |
| 321 | /// - The constant size at dimension j, derived directly from the result type of |
| 322 | /// the expand_shape op, or |
| 323 | /// - An affine expression: baseSizes#reassDim / product of all constant sizes |
| 324 | /// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic |
| 325 | /// element.) |
| 326 | /// |
| 327 | /// \post result.size() == expandShape.getReassociationIndices()[groupId].size() |
| 328 | /// |
| 329 | /// TODO: Move this utility function directly within ExpandShapeOp. For now, |
| 330 | /// this is not possible because this function uses the Affine dialect and the |
| 331 | /// MemRef dialect cannot depend on the Affine dialect. |
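///
/// For example (a hedged sketch): for a group with result shape [2, ?, 3],
/// base size %n, and base stride %st, the statically known suffix products
/// give initial strides [3, 3, 1]; after scaling by the base stride and
/// fixing up the dimensions left of the dynamic one, the result is roughly
///   expandedStrides = [((%n * 3) floordiv 6) * %st, 3 * %st, 1 * %st]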
| 332 | SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape, |
| 333 | OpBuilder &builder, |
| 334 | ArrayRef<OpFoldResult> origSizes, |
| 335 | ArrayRef<OpFoldResult> origStrides, |
| 336 | unsigned groupId) { |
| 337 | SmallVector<int64_t, 2> reassocGroup = |
| 338 | expandShape.getReassociationIndices()[groupId]; |
| 339 | assert(!reassocGroup.empty() && |
| 340 | "Reassociation group should have at least one dimension" ); |
| 341 | |
| 342 | unsigned groupSize = reassocGroup.size(); |
| 343 | MemRefType expandShapeType = expandShape.getResultType(); |
| 344 | |
| 345 | std::optional<int64_t> dynSizeIdx; |
| 346 | |
| 347 | // Fill up the expanded strides, with the information we can deduce from the |
| 348 | // resulting shape. |
| 349 | uint64_t currentStride = 1; |
| 350 | SmallVector<OpFoldResult> expandedStrides(groupSize); |
| 351 | for (int i = groupSize - 1; i >= 0; --i) { |
| 352 | expandedStrides[i] = builder.getIndexAttr(currentStride); |
| 353 | uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]); |
| 354 | if (ShapedType::isDynamic(dimSize)) { |
| 355 | assert(!dynSizeIdx && "There must be at most one dynamic size per group" ); |
| 356 | dynSizeIdx = i; |
| 357 | continue; |
| 358 | } |
| 359 | |
| 360 | currentStride *= dimSize; |
| 361 | } |
| 362 | |
| 363 | // Collect the statically known information about the original stride. |
| 364 | Value source = expandShape.getSrc(); |
| 365 | auto sourceType = cast<MemRefType>(source.getType()); |
| 366 | auto [strides, offset] = sourceType.getStridesAndOffset(); |
| 367 | |
| 368 | OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) |
| 369 | ? origStrides[groupId] |
| 370 | : builder.getIndexAttr(strides[groupId]); |
| 371 | |
| 372 | // Apply the original stride to all the strides. |
| 373 | int64_t doneStrideIdx = 0; |
| 374 | // If we saw a dynamic dimension, we need to fix-up all the strides up to |
| 375 | // that dimension with the dynamic size. |
| 376 | if (dynSizeIdx) { |
| 377 | int64_t productOfAllStaticSizes = currentStride; |
| 378 | assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) && |
| 379 | "We shouldn't be able to change dynamicity" ); |
| 380 | OpFoldResult origSize = origSizes[groupId]; |
| 381 | |
    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
| 384 | for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) { |
| 385 | int64_t baseExpandedStride = |
| 386 | cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx])) |
| 387 | .getInt(); |
| 388 | expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( |
| 389 | builder, expandShape.getLoc(), |
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
| 391 | {origSize, origStride}); |
| 392 | } |
| 393 | } |
| 394 | |
| 395 | // Now apply the origStride to the remaining dimensions. |
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
| 397 | for (; doneStrideIdx < groupSize; ++doneStrideIdx) { |
| 398 | int64_t baseExpandedStride = |
| 399 | cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx])) |
| 400 | .getInt(); |
| 401 | expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply( |
| 402 | builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride}); |
| 403 | } |
| 404 | |
| 405 | return expandedStrides; |
| 406 | } |
| 407 | |
| 408 | /// Produce an OpFoldResult object with \p builder at \p loc representing |
| 409 | /// `prod(valueOrConstant#i, for i in {indices})`, |
/// where valueOrConstant#i is maybeConstant[i] when \p isDynamic is false,
| 411 | /// values[i] otherwise. |
| 412 | /// |
| 413 | /// \pre for all index in indices: index < values.size() |
| 414 | /// \pre for all index in indices: index < maybeConstants.size() |
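///
/// For example (a hedged sketch): with indices = {0, 2},
/// maybeConstants = [4, ?, ?], and values = [%a, %b, %c], the helper builds
/// the product of two symbols over the inputs [4, %c], which the composed
/// affine apply folds to roughly s0 * 4 applied to %c.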
| 415 | static OpFoldResult |
| 416 | getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc, |
| 417 | ArrayRef<int64_t> maybeConstants, |
| 418 | ArrayRef<OpFoldResult> values, |
| 419 | llvm::function_ref<bool(int64_t)> isDynamic) { |
  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
| 421 | SmallVector<OpFoldResult> inputValues; |
| 422 | unsigned numberOfSymbols = 0; |
| 423 | unsigned groupSize = indices.size(); |
| 424 | for (unsigned i = 0; i < groupSize; ++i) { |
| 425 | productOfValues = |
        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
| 427 | unsigned srcIdx = indices[i]; |
| 428 | int64_t maybeConstant = maybeConstants[srcIdx]; |
| 429 | |
| 430 | inputValues.push_back(isDynamic(maybeConstant) |
| 431 | ? values[srcIdx] |
| 432 | : builder.getIndexAttr(maybeConstant)); |
| 433 | } |
| 434 | |
| 435 | return makeComposedFoldedAffineApply(builder, loc, productOfValues, |
| 436 | inputValues); |
| 437 | } |
| 438 | |
/// Compute the collapsed size of the given \p collapseShape for the
| 440 | /// \p groupId-th reassociation group. |
| 441 | /// \p origSizes hold the sizes of the source shape as values. |
| 442 | /// This is used to compute the new sizes in cases of dynamic shapes. |
| 443 | /// |
| 444 | /// Conceptually this helper function computes: |
/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
| 446 | /// |
/// \post result.size() == 1, in other words, each group collapses to one
| 448 | /// dimension. |
| 449 | /// |
| 450 | /// TODO: Move this utility function directly within CollapseShapeOp. For now, |
| 451 | /// this is not possible because this function uses the Affine dialect and the |
| 452 | /// MemRef dialect cannot depend on the Affine dialect. |
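///
/// For example (a hedged sketch): collapsing a group with source shape
/// [?, 4] into one dynamic result dimension produces a single size equal to
/// origSizes#0 * 4, materialized with affine.apply; a fully static group
/// simply yields the static result dimension size.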
| 453 | static SmallVector<OpFoldResult> |
| 454 | getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder, |
| 455 | ArrayRef<OpFoldResult> origSizes, unsigned groupId) { |
| 456 | SmallVector<OpFoldResult> collapsedSize; |
| 457 | |
| 458 | MemRefType collapseShapeType = collapseShape.getResultType(); |
| 459 | |
| 460 | uint64_t size = collapseShapeType.getDimSize(groupId); |
| 461 | if (!ShapedType::isDynamic(size)) { |
| 462 | collapsedSize.push_back(builder.getIndexAttr(size)); |
| 463 | return collapsedSize; |
| 464 | } |
| 465 | |
| 466 | // We are dealing with a dynamic size. |
| 467 | // Build the affine expr of the product of the original sizes involved in that |
| 468 | // group. |
| 469 | Value source = collapseShape.getSrc(); |
| 470 | auto sourceType = cast<MemRefType>(source.getType()); |
| 471 | |
| 472 | SmallVector<int64_t, 2> reassocGroup = |
| 473 | collapseShape.getReassociationIndices()[groupId]; |
| 474 | |
| 475 | collapsedSize.push_back(getProductOfValues( |
| 476 | reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(), |
| 477 | origSizes, ShapedType::isDynamic)); |
| 478 | |
| 479 | return collapsedSize; |
| 480 | } |
| 481 | |
/// Compute the collapsed stride of the given \p collapseShape for the
| 483 | /// \p groupId-th reassociation group. |
| 484 | /// \p origStrides and \p origSizes hold respectively the strides and sizes |
| 485 | /// of the source shape as values. |
| 486 | /// This is used to compute the strides in cases of dynamic shapes and/or |
| 487 | /// dynamic stride for this reassociation group. |
| 488 | /// |
/// Conceptually this helper function returns the stride of the innermost
| 490 | /// dimension of that group in the original shape. |
| 491 | /// |
/// \post result.size() == 1, in other words, each group collapses to one
| 493 | /// dimension. |
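///
/// For example (a hedged sketch): collapsing a group with source sizes
/// [2, 1, 4] and strides [4, 4, 1] yields {1}, the stride of the innermost
/// non-unit dimension; size-1 dimensions are skipped because their strides
/// may be meaningless.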
| 494 | static SmallVector<OpFoldResult> |
| 495 | getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, |
| 496 | ArrayRef<OpFoldResult> origSizes, |
| 497 | ArrayRef<OpFoldResult> origStrides, unsigned groupId) { |
| 498 | SmallVector<int64_t, 2> reassocGroup = |
| 499 | collapseShape.getReassociationIndices()[groupId]; |
| 500 | assert(!reassocGroup.empty() && |
| 501 | "Reassociation group should have at least one dimension" ); |
| 502 | |
| 503 | Value source = collapseShape.getSrc(); |
| 504 | auto sourceType = cast<MemRefType>(source.getType()); |
| 505 | |
| 506 | auto [strides, offset] = sourceType.getStridesAndOffset(); |
| 507 | |
| 508 | ArrayRef<int64_t> srcShape = sourceType.getShape(); |
| 509 | |
| 510 | OpFoldResult lastValidStride = nullptr; |
| 511 | for (int64_t currentDim : reassocGroup) { |
| 512 | // Skip size-of-1 dimensions, since right now their strides may be |
| 513 | // meaningless. |
| 514 | // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless |
| 515 | // they are truly contiguous. When they are truly contiguous, we shouldn't |
| 516 | // need to skip them. |
| 517 | if (srcShape[currentDim] == 1) |
| 518 | continue; |
| 519 | |
| 520 | int64_t currentStride = strides[currentDim]; |
| 521 | lastValidStride = ShapedType::isDynamic(currentStride) |
| 522 | ? origStrides[currentDim] |
| 523 | : builder.getIndexAttr(currentStride); |
| 524 | } |
| 525 | if (!lastValidStride) { |
| 526 | // We're dealing with a 1x1x...x1 shape. The stride is meaningless, |
| 527 | // but we still have to make the type system happy. |
| 528 | MemRefType collapsedType = collapseShape.getResultType(); |
| 529 | auto [collapsedStrides, collapsedOffset] = |
| 530 | collapsedType.getStridesAndOffset(); |
| 531 | int64_t finalStride = collapsedStrides[groupId]; |
| 532 | if (ShapedType::isDynamic(finalStride)) { |
| 533 | // Look for a dynamic stride. At this point we don't know which one is |
| 534 | // desired, but they are all equally good/bad. |
| 535 | for (int64_t currentDim : reassocGroup) { |
| 536 | assert(srcShape[currentDim] == 1 && |
| 537 | "We should be dealing with 1x1x...x1" ); |
| 538 | |
| 539 | if (ShapedType::isDynamic(strides[currentDim])) |
| 540 | return {origStrides[currentDim]}; |
| 541 | } |
| 542 | llvm_unreachable("We should have found a dynamic stride" ); |
| 543 | } |
| 544 | return {builder.getIndexAttr(finalStride)}; |
| 545 | } |
| 546 | |
| 547 | return {lastValidStride}; |
| 548 | } |
| 549 | |
/// From `reshape_like(memref, subSizes, subStrides)` compute
| 551 | /// |
| 552 | /// \verbatim |
| 553 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
| 554 | /// extract_strided_metadata(memref) |
| 555 | /// strides#i = baseStrides#i * subStrides#i |
| 556 | /// sizes = subSizes |
| 557 | /// \endverbatim |
| 558 | /// |
| 559 | /// and return {baseBuffer, baseOffset, sizes, strides} |
| 560 | template <typename ReassociativeReshapeLikeOp> |
| 561 | static FailureOr<StridedMetadata> resolveReshapeStridedMetadata( |
| 562 | RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape, |
| 563 | function_ref<SmallVector<OpFoldResult>( |
| 564 | ReassociativeReshapeLikeOp, OpBuilder &, |
| 565 | ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)> |
| 566 | getReshapedSizes, |
| 567 | function_ref<SmallVector<OpFoldResult>( |
| 568 | ReassociativeReshapeLikeOp, OpBuilder &, |
| 569 | ArrayRef<OpFoldResult> /*origSizes*/, |
| 570 | ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)> |
| 571 | getReshapedStrides) { |
| 572 | // Build a plain extract_strided_metadata(memref) from |
| 573 | // extract_strided_metadata(reassociative_reshape_like(memref)). |
| 574 | Location origLoc = reshape.getLoc(); |
| 575 | Value source = reshape.getSrc(); |
| 576 | auto sourceType = cast<MemRefType>(source.getType()); |
| 577 | unsigned sourceRank = sourceType.getRank(); |
| 578 | |
  auto newExtractStridedMetadata =
| 580 | rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source); |
| 581 | |
| 582 | // Collect statically known information. |
| 583 | auto [strides, offset] = sourceType.getStridesAndOffset(); |
| 584 | MemRefType reshapeType = reshape.getResultType(); |
| 585 | unsigned reshapeRank = reshapeType.getRank(); |
| 586 | |
| 587 | OpFoldResult offsetOfr = |
| 588 | ShapedType::isDynamic(offset) |
| 589 | ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) |
| 590 | : rewriter.getIndexAttr(offset); |
| 591 | |
| 592 | // Get the special case of 0-D out of the way. |
| 593 | if (sourceRank == 0) { |
| 594 | SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1)); |
| 595 | return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr, |
| 596 | /*sizes=*/ones, /*strides=*/ones}; |
| 597 | } |
| 598 | |
| 599 | SmallVector<OpFoldResult> finalSizes; |
| 600 | finalSizes.reserve(reshapeRank); |
| 601 | SmallVector<OpFoldResult> finalStrides; |
| 602 | finalStrides.reserve(reshapeRank); |
| 603 | |
| 604 | // Compute the reshaped strides and sizes from the base strides and sizes. |
| 605 | SmallVector<OpFoldResult> origSizes = |
| 606 | getAsOpFoldResult(newExtractStridedMetadata.getSizes()); |
| 607 | SmallVector<OpFoldResult> origStrides = |
| 608 | getAsOpFoldResult(newExtractStridedMetadata.getStrides()); |
| 609 | unsigned idx = 0, endIdx = reshape.getReassociationIndices().size(); |
| 610 | for (; idx != endIdx; ++idx) { |
| 611 | SmallVector<OpFoldResult> reshapedSizes = |
| 612 | getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx); |
| 613 | SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides( |
| 614 | reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); |
| 615 | |
| 616 | unsigned groupSize = reshapedSizes.size(); |
| 617 | for (unsigned i = 0; i < groupSize; ++i) { |
| 618 | finalSizes.push_back(reshapedSizes[i]); |
| 619 | finalStrides.push_back(reshapedStrides[i]); |
| 620 | } |
| 621 | } |
| 622 | assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) || |
| 623 | (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) && |
| 624 | "We should have visited all the input dimensions" ); |
| 625 | assert(finalSizes.size() == reshapeRank && |
| 626 | "We should have populated all the values" ); |
| 627 | |
| 628 | return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr, |
| 629 | finalSizes, finalStrides}; |
| 630 | } |
| 631 | |
| 632 | /// Replace `baseBuffer, offset, sizes, strides = |
| 633 | /// extract_strided_metadata(reshapeLike(memref))` |
| 634 | /// With |
| 635 | /// |
| 636 | /// \verbatim |
| 637 | /// baseBuffer, offset, baseSizes, baseStrides = |
| 638 | /// extract_strided_metadata(memref) |
| 639 | /// sizes = getReshapedSizes(reshapeLike) |
| 640 | /// strides = getReshapedStrides(reshapeLike) |
| 641 | /// \endverbatim |
| 642 | /// |
| 643 | /// |
| 644 | /// Notice that `baseBuffer` and `offset` are unchanged. |
| 645 | /// |
/// In other words, get rid of the reshape-like op in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
| 648 | template <typename ReassociativeReshapeLikeOp, |
| 649 | SmallVector<OpFoldResult> (*getReshapedSizes)( |
| 650 | ReassociativeReshapeLikeOp, OpBuilder &, |
| 651 | ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/), |
| 652 | SmallVector<OpFoldResult> (*getReshapedStrides)( |
| 653 | ReassociativeReshapeLikeOp, OpBuilder &, |
| 654 | ArrayRef<OpFoldResult> /*origSizes*/, |
| 655 | ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)> |
| 656 | struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> { |
| 657 | public: |
| 658 | using OpRewritePattern<ReassociativeReshapeLikeOp>::OpRewritePattern; |
| 659 | |
| 660 | LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape, |
| 661 | PatternRewriter &rewriter) const override { |
| 662 | FailureOr<StridedMetadata> stridedMetadata = |
| 663 | resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>( |
| 664 | rewriter, reshape, getReshapedSizes, getReshapedStrides); |
| 665 | if (failed(stridedMetadata)) { |
| 666 | return rewriter.notifyMatchFailure(reshape, |
| 667 | "failed to resolve reshape metadata" ); |
| 668 | } |
| 669 | |
| 670 | rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( |
| 671 | reshape, reshape.getType(), stridedMetadata->basePtr, |
| 672 | stridedMetadata->offset, stridedMetadata->sizes, |
| 673 | stridedMetadata->strides); |
| 674 | return success(); |
| 675 | } |
| 676 | }; |
| 677 | |
| 678 | /// Pattern to replace `extract_strided_metadata(collapse_shape)` |
| 679 | /// With |
| 680 | /// |
| 681 | /// \verbatim |
| 682 | /// baseBuffer, baseOffset, baseSizes, baseStrides = |
| 683 | /// extract_strided_metadata(memref) |
/// sizes#i = getCollapsedSize(collapse_shape, i)
/// strides#i = getCollapsedStride(collapse_shape, i)
/// offset = baseOffset
/// \endverbatim
| 688 | /// |
| 689 | /// with `baseBuffer`, `offset`, `sizes` and `strides` being |
| 690 | /// the replacements for the original `extract_strided_metadata`. |
struct ExtractStridedMetadataOpCollapseShapeFolder
| 692 | : OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 693 | using OpRewritePattern::OpRewritePattern; |
| 694 | |
| 695 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
| 696 | PatternRewriter &rewriter) const override { |
| 697 | auto collapseShapeOp = |
| 698 | op.getSource().getDefiningOp<memref::CollapseShapeOp>(); |
| 699 | if (!collapseShapeOp) |
| 700 | return failure(); |
| 701 | |
| 702 | FailureOr<StridedMetadata> stridedMetadata = |
| 703 | resolveReshapeStridedMetadata<memref::CollapseShapeOp>( |
| 704 | rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride); |
| 705 | if (failed(stridedMetadata)) { |
| 706 | return rewriter.notifyMatchFailure( |
| 707 | op, |
| 708 | "failed to resolve metadata in terms of source collapse_shape op" ); |
| 709 | } |
| 710 | |
| 711 | Location loc = collapseShapeOp.getLoc(); |
| 712 | SmallVector<Value> results; |
| 713 | results.push_back(stridedMetadata->basePtr); |
| 714 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, |
| 715 | stridedMetadata->offset)); |
| 716 | results.append( |
| 717 | getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); |
| 718 | results.append(getValueOrCreateConstantIndexOp(rewriter, loc, |
| 719 | stridedMetadata->strides)); |
| 720 | rewriter.replaceOp(op, results); |
| 721 | return success(); |
| 722 | } |
| 723 | }; |
| 724 | |
| 725 | /// Pattern to replace `extract_strided_metadata(expand_shape)` |
| 726 | /// with the results of computing the sizes and strides on the expanded shape |
| 727 | /// and dividing up dimensions into static and dynamic parts as needed. |
| 728 | struct ExtractStridedMetadataOpExpandShapeFolder |
| 729 | : OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 730 | using OpRewritePattern::OpRewritePattern; |
| 731 | |
| 732 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
| 733 | PatternRewriter &rewriter) const override { |
| 734 | auto expandShapeOp = op.getSource().getDefiningOp<memref::ExpandShapeOp>(); |
| 735 | if (!expandShapeOp) |
| 736 | return failure(); |
| 737 | |
| 738 | FailureOr<StridedMetadata> stridedMetadata = |
| 739 | resolveReshapeStridedMetadata<memref::ExpandShapeOp>( |
| 740 | rewriter, expandShapeOp, getExpandedSizes, getExpandedStrides); |
| 741 | if (failed(stridedMetadata)) { |
| 742 | return rewriter.notifyMatchFailure( |
| 743 | op, "failed to resolve metadata in terms of source expand_shape op" ); |
| 744 | } |
| 745 | |
| 746 | Location loc = expandShapeOp.getLoc(); |
| 747 | SmallVector<Value> results; |
| 748 | results.push_back(stridedMetadata->basePtr); |
| 749 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, |
| 750 | stridedMetadata->offset)); |
| 751 | results.append( |
| 752 | getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes)); |
| 753 | results.append(getValueOrCreateConstantIndexOp(rewriter, loc, |
| 754 | stridedMetadata->strides)); |
| 755 | rewriter.replaceOp(op, results); |
| 756 | return success(); |
| 757 | } |
| 758 | }; |
| 759 | |
| 760 | /// Replace `base, offset, sizes, strides = |
| 761 | /// extract_strided_metadata(allocLikeOp)` |
| 762 | /// |
| 763 | /// With |
| 764 | /// |
| 765 | /// ``` |
| 766 | /// base = reinterpret_cast allocLikeOp(allocSizes) to a flat memref<eltTy> |
| 767 | /// offset = 0 |
| 768 | /// sizes = allocSizes |
| 769 | /// strides#i = prod(allocSizes#j, for j in {i+1..rank-1}) |
| 770 | /// ``` |
| 771 | /// |
| 772 | /// The transformation only applies if the allocLikeOp has been normalized. |
| 773 | /// In other words, the affine_map must be an identity. |
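///
/// For example (a hedged sketch): for `memref.alloc(%d) : memref<?x16x4xf32>`
/// the replacement is offset 0, sizes [%d, 16, 4], and strides [64, 4, 1];
/// a stride only needs an affine.apply when the sizes to the right of its
/// dimension are not all static.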
| 774 | template <typename AllocLikeOp> |
struct ExtractStridedMetadataOpAllocFolder
| 776 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 777 | public: |
| 778 | using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern; |
| 779 | |
| 780 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
| 781 | PatternRewriter &rewriter) const override { |
| 782 | auto allocLikeOp = op.getSource().getDefiningOp<AllocLikeOp>(); |
| 783 | if (!allocLikeOp) |
| 784 | return failure(); |
| 785 | |
| 786 | auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType()); |
| 787 | if (!memRefType.getLayout().isIdentity()) |
| 788 | return rewriter.notifyMatchFailure( |
| 789 | allocLikeOp, "alloc-like operations should have been normalized" ); |
| 790 | |
| 791 | Location loc = op.getLoc(); |
| 792 | int rank = memRefType.getRank(); |
| 793 | |
| 794 | // Collect the sizes. |
| 795 | ValueRange dynamic = allocLikeOp.getDynamicSizes(); |
| 796 | SmallVector<OpFoldResult> sizes; |
| 797 | sizes.reserve(rank); |
| 798 | unsigned dynamicPos = 0; |
| 799 | for (int64_t size : memRefType.getShape()) { |
| 800 | if (ShapedType::isDynamic(size)) |
| 801 | sizes.push_back(dynamic[dynamicPos++]); |
| 802 | else |
| 803 | sizes.push_back(rewriter.getIndexAttr(size)); |
| 804 | } |
| 805 | |
| 806 | // Strides (just creates identity strides). |
| 807 | SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); |
    AffineExpr expr = rewriter.getAffineConstantExpr(1);
| 809 | unsigned symbolNumber = 0; |
| 810 | for (int i = rank - 2; i >= 0; --i) { |
      expr = expr * rewriter.getAffineSymbolExpr(symbolNumber++);
| 812 | assert(i + 1 + symbolNumber == sizes.size() && |
| 813 | "The ArrayRef should encompass the last #symbolNumber sizes" ); |
| 814 | ArrayRef<OpFoldResult> sizesInvolvedInStride(&sizes[i + 1], symbolNumber); |
| 815 | strides[i] = makeComposedFoldedAffineApply(rewriter, loc, expr, |
| 816 | sizesInvolvedInStride); |
| 817 | } |
| 818 | |
| 819 | // Put all the values together to replace the results. |
| 820 | SmallVector<Value> results; |
| 821 | results.reserve(rank * 2 + 2); |
| 822 | |
| 823 | auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType()); |
| 824 | int64_t offset = 0; |
| 825 | if (op.getBaseBuffer().use_empty()) { |
| 826 | results.push_back(nullptr); |
| 827 | } else { |
| 828 | if (allocLikeOp.getType() == baseBufferType) |
| 829 | results.push_back(allocLikeOp); |
| 830 | else |
| 831 | results.push_back(rewriter.create<memref::ReinterpretCastOp>( |
| 832 | loc, baseBufferType, allocLikeOp, offset, |
| 833 | /*sizes=*/ArrayRef<int64_t>(), |
| 834 | /*strides=*/ArrayRef<int64_t>())); |
| 835 | } |
| 836 | |
| 837 | // Offset. |
| 838 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset)); |
| 839 | |
| 840 | for (OpFoldResult size : sizes) |
| 841 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size)); |
| 842 | |
| 843 | for (OpFoldResult stride : strides) |
| 844 | results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, stride)); |
| 845 | |
| 846 | rewriter.replaceOp(op, results); |
| 847 | return success(); |
| 848 | } |
| 849 | }; |
| 850 | |
| 851 | /// Replace `base, offset, sizes, strides = |
| 852 | /// extract_strided_metadata(get_global)` |
| 853 | /// |
| 854 | /// With |
| 855 | /// |
| 856 | /// ``` |
| 857 | /// base = reinterpret_cast get_global to a flat memref<eltTy> |
| 858 | /// offset = 0 |
/// sizes = the static shape of the global
/// strides#i = prod(sizes#j, for j in {i+1..rank-1})
| 861 | /// ``` |
| 862 | /// |
| 863 | /// It is expected that the memref.get_global op has static shapes |
| 864 | /// and identity affine_map for the layout. |
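///
/// For example (a hedged sketch): for a global of type memref<2x3x4xf32> the
/// replacement is offset 0, sizes [2, 3, 4], and strides [12, 4, 1], i.e. the
/// suffix products of the static shape.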
struct ExtractStridedMetadataOpGetGlobalFolder
| 866 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 867 | public: |
| 868 | using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern; |
| 869 | |
| 870 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
| 871 | PatternRewriter &rewriter) const override { |
| 872 | auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>(); |
| 873 | if (!getGlobalOp) |
| 874 | return failure(); |
| 875 | |
| 876 | auto memRefType = cast<MemRefType>(getGlobalOp.getResult().getType()); |
| 877 | if (!memRefType.getLayout().isIdentity()) { |
| 878 | return rewriter.notifyMatchFailure( |
| 879 | getGlobalOp, |
| 880 | "get-global operation result should have been normalized" ); |
| 881 | } |
| 882 | |
| 883 | Location loc = op.getLoc(); |
| 884 | int rank = memRefType.getRank(); |
| 885 | |
| 886 | // Collect the sizes. |
| 887 | ArrayRef<int64_t> sizes = memRefType.getShape(); |
| 888 | assert(!llvm::any_of(sizes, ShapedType::isDynamic) && |
| 889 | "unexpected dynamic shape for result of `memref.get_global` op" ); |
| 890 | |
| 891 | // Strides (just creates identity strides). |
| 892 | SmallVector<int64_t> strides = computeSuffixProduct(sizes); |
| 893 | |
| 894 | // Put all the values together to replace the results. |
| 895 | SmallVector<Value> results; |
| 896 | results.reserve(rank * 2 + 2); |
| 897 | |
| 898 | auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType()); |
| 899 | int64_t offset = 0; |
| 900 | if (getGlobalOp.getType() == baseBufferType) |
| 901 | results.push_back(getGlobalOp); |
| 902 | else |
| 903 | results.push_back(rewriter.create<memref::ReinterpretCastOp>( |
| 904 | loc, baseBufferType, getGlobalOp, offset, |
| 905 | /*sizes=*/ArrayRef<int64_t>(), |
| 906 | /*strides=*/ArrayRef<int64_t>())); |
| 907 | |
| 908 | // Offset. |
| 909 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset)); |
| 910 | |
| 911 | for (auto size : sizes) |
| 912 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size)); |
| 913 | |
| 914 | for (auto stride : strides) |
| 915 | results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride)); |
| 916 | |
| 917 | rewriter.replaceOp(op, results); |
| 918 | return success(); |
| 919 | } |
| 920 | }; |
| 921 | |
| 922 | /// Pattern to replace `extract_strided_metadata(assume_alignment)` |
| 923 | /// |
| 924 | /// With |
| 925 | /// \verbatim |
| 926 | /// extract_strided_metadata(memref) |
| 927 | /// \endverbatim |
| 928 | /// |
| 929 | /// Since `assume_alignment` is a view-like op that does not modify the |
| 930 | /// underlying buffer, offset, sizes, or strides, extracting strided metadata |
| 931 | /// from its result is equivalent to extracting it from its source. This |
| 932 | /// canonicalization removes the unnecessary indirection. |
struct ExtractStridedMetadataOpAssumeAlignmentFolder
| 934 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 935 | public: |
| 936 | using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern; |
| 937 | |
| 938 | LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, |
| 939 | PatternRewriter &rewriter) const override { |
| 940 | auto assumeAlignmentOp = |
| 941 | op.getSource().getDefiningOp<memref::AssumeAlignmentOp>(); |
| 942 | if (!assumeAlignmentOp) |
| 943 | return failure(); |
| 944 | |
| 945 | rewriter.replaceOpWithNewOp<memref::ExtractStridedMetadataOp>( |
| 946 | op, assumeAlignmentOp.getViewSource()); |
| 947 | return success(); |
| 948 | } |
| 949 | }; |
| 950 | |
| 951 | /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the |
| 952 | /// source of the ViewLikeOp. |
class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
| 954 | : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> { |
| 955 | using OpRewritePattern::OpRewritePattern; |
| 956 | |
| 957 | LogicalResult |
  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
| 959 | PatternRewriter &rewriter) const override { |
| 960 | auto viewLikeOp = |
| 961 | extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); |
| 962 | if (!viewLikeOp) |
| 963 | return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source" ); |
| 964 | rewriter.modifyOpInPlace(extractOp, [&]() { |
| 965 | extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); |
| 966 | }); |
| 967 | return success(); |
| 968 | } |
| 969 | }; |
| 970 | |
| 971 | /// Replace `base, offset, sizes, strides = |
| 972 | /// extract_strided_metadata( |
| 973 | /// reinterpret_cast(src, srcOffset, srcSizes, srcStrides))` |
| 974 | /// With |
| 975 | /// ``` |
| 976 | /// base, ... = extract_strided_metadata(src) |
| 977 | /// offset = srcOffset |
| 978 | /// sizes = srcSizes |
| 979 | /// strides = srcStrides |
| 980 | /// ``` |
| 981 | /// |
| 982 | /// In other words, consume the `reinterpret_cast` and apply its effects |
| 983 | /// on the offset, sizes, and strides. |
class ExtractStridedMetadataOpReinterpretCastFolder
| 985 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 986 | using OpRewritePattern::OpRewritePattern; |
| 987 | |
| 988 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
| 990 | PatternRewriter &rewriter) const override { |
| 991 | auto reinterpretCastOp = extractStridedMetadataOp.getSource() |
| 992 | .getDefiningOp<memref::ReinterpretCastOp>(); |
| 993 | if (!reinterpretCastOp) |
| 994 | return failure(); |
| 995 | |
| 996 | Location loc = extractStridedMetadataOp.getLoc(); |
| 997 | // Check if the source is suitable for extract_strided_metadata. |
| 998 | SmallVector<Type> inferredReturnTypes; |
| 999 | if (failed(extractStridedMetadataOp.inferReturnTypes( |
| 1000 | rewriter.getContext(), loc, {reinterpretCastOp.getSource()}, |
| 1001 | /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, |
| 1002 | inferredReturnTypes))) |
| 1003 | return rewriter.notifyMatchFailure( |
| 1004 | reinterpretCastOp, "reinterpret_cast source's type is incompatible" ); |
| 1005 | |
| 1006 | auto memrefType = cast<MemRefType>(reinterpretCastOp.getResult().getType()); |
| 1007 | unsigned rank = memrefType.getRank(); |
| 1008 | SmallVector<OpFoldResult> results; |
| 1009 | results.resize_for_overwrite(rank * 2 + 2); |
| 1010 | |
    auto newExtractStridedMetadata =
| 1012 | rewriter.create<memref::ExtractStridedMetadataOp>( |
| 1013 | loc, reinterpretCastOp.getSource()); |
| 1014 | |
| 1015 | // Register the base_buffer. |
| 1016 | results[0] = newExtractStridedMetadata.getBaseBuffer(); |
| 1017 | |
| 1018 | // Register the new offset. |
| 1019 | results[1] = getValueOrCreateConstantIndexOp( |
| 1020 | rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]); |
| 1021 | |
| 1022 | const unsigned sizeStartIdx = 2; |
| 1023 | const unsigned strideStartIdx = sizeStartIdx + rank; |
| 1024 | |
| 1025 | SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes(); |
| 1026 | SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides(); |
| 1027 | for (unsigned i = 0; i < rank; ++i) { |
| 1028 | results[sizeStartIdx + i] = sizes[i]; |
| 1029 | results[strideStartIdx + i] = strides[i]; |
| 1030 | } |
| 1031 | rewriter.replaceOp(extractStridedMetadataOp, |
| 1032 | getValueOrCreateConstantIndexOp(rewriter, loc, results)); |
| 1033 | return success(); |
| 1034 | } |
| 1035 | }; |
| 1036 | |
| 1037 | /// Replace `base, offset, sizes, strides = |
| 1038 | /// extract_strided_metadata( |
| 1039 | /// cast(src) to dstTy)` |
| 1040 | /// With |
| 1041 | /// ``` |
| 1042 | /// base, ... = extract_strided_metadata(src) |
| 1043 | /// offset = !dstTy.srcOffset.isDynamic() |
| 1044 | /// ? dstTy.srcOffset |
| 1045 | /// : extract_strided_metadata(src).offset |
| 1046 | /// sizes = for each srcSize in dstTy.srcSizes: |
| 1047 | /// !srcSize.isDynamic() |
| 1048 | /// ? srcSize |
///           : extract_strided_metadata(src).sizes[i]
| 1050 | /// strides = for each srcStride in dstTy.srcStrides: |
| 1051 | /// !srcStrides.isDynamic() |
| 1052 | /// ? srcStrides |
| 1053 | /// : extract_strided_metadata(src).strides[i] |
| 1054 | /// ``` |
| 1055 | /// |
| 1056 | /// In other words, consume the `cast` and apply its effects |
| 1057 | /// on the offset, sizes, and strides or compute them directly from `src`. |
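///
/// For example (a hedged sketch): for
/// `memref.cast %src : memref<?x?xf32> to memref<4x?xf32>` the replacement
/// uses offset 0, sizes [4, %sizes#1], and strides [%strides#0, 1], where
/// %sizes and %strides come from extract_strided_metadata(%src).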
class ExtractStridedMetadataOpCastFolder
| 1059 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 1060 | using OpRewritePattern::OpRewritePattern; |
| 1061 | |
| 1062 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
| 1064 | PatternRewriter &rewriter) const override { |
| 1065 | Value source = extractStridedMetadataOp.getSource(); |
| 1066 | auto castOp = source.getDefiningOp<memref::CastOp>(); |
| 1067 | if (!castOp) |
| 1068 | return failure(); |
| 1069 | |
| 1070 | Location loc = extractStridedMetadataOp.getLoc(); |
| 1071 | // Check if the source is suitable for extract_strided_metadata. |
| 1072 | SmallVector<Type> inferredReturnTypes; |
| 1073 | if (failed(extractStridedMetadataOp.inferReturnTypes( |
| 1074 | rewriter.getContext(), loc, {castOp.getSource()}, |
| 1075 | /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{}, |
| 1076 | inferredReturnTypes))) |
| 1077 | return rewriter.notifyMatchFailure(castOp, |
| 1078 | "cast source's type is incompatible" ); |
| 1079 | |
| 1080 | auto memrefType = cast<MemRefType>(source.getType()); |
| 1081 | unsigned rank = memrefType.getRank(); |
| 1082 | SmallVector<OpFoldResult> results; |
| 1083 | results.resize_for_overwrite(rank * 2 + 2); |
| 1084 | |
    auto newExtractStridedMetadata =
| 1086 | rewriter.create<memref::ExtractStridedMetadataOp>(loc, |
| 1087 | castOp.getSource()); |
| 1088 | |
| 1089 | // Register the base_buffer. |
| 1090 | results[0] = newExtractStridedMetadata.getBaseBuffer(); |
| 1091 | |
| 1092 | auto getConstantOrValue = [&rewriter](int64_t constant, |
| 1093 | OpFoldResult ofr) -> OpFoldResult { |
| 1094 | return !ShapedType::isDynamic(constant) |
| 1095 | ? OpFoldResult(rewriter.getIndexAttr(constant)) |
| 1096 | : ofr; |
| 1097 | }; |
| 1098 | |
| 1099 | auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset(); |
| 1100 | assert(sourceStrides.size() == rank && "unexpected number of strides" ); |
| 1101 | |
| 1102 | // Register the new offset. |
| 1103 | results[1] = |
| 1104 | getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset()); |
| 1105 | |
| 1106 | const unsigned sizeStartIdx = 2; |
| 1107 | const unsigned strideStartIdx = sizeStartIdx + rank; |
| 1108 | ArrayRef<int64_t> sourceSizes = memrefType.getShape(); |
| 1109 | |
| 1110 | SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes(); |
| 1111 | SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides(); |
| 1112 | for (unsigned i = 0; i < rank; ++i) { |
| 1113 | results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]); |
| 1114 | results[strideStartIdx + i] = |
| 1115 | getConstantOrValue(sourceStrides[i], strides[i]); |
| 1116 | } |
| 1117 | rewriter.replaceOp(extractStridedMetadataOp, |
| 1118 | getValueOrCreateConstantIndexOp(rewriter, loc, results)); |
| 1119 | return success(); |
| 1120 | } |
| 1121 | }; |
| 1122 | |
| 1123 | /// Replace `base, offset, sizes, strides = extract_strided_metadata( |
| 1124 | /// memory_space_cast(src) to dstTy)` |
| 1125 | /// with |
| 1126 | /// ``` |
| 1127 | /// oldBase, offset, sizes, strides = extract_strided_metadata(src) |
| 1128 | /// destBaseTy = type(oldBase) with memory space from destTy |
| 1129 | /// base = memory_space_cast(oldBase) to destBaseTy |
| 1130 | /// ``` |
| 1131 | /// |
/// In other words, propagate metadata extraction across memory space casts.
class ExtractStridedMetadataOpMemorySpaceCastFolder
| 1134 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 1135 | using OpRewritePattern::OpRewritePattern; |
| 1136 | |
| 1137 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
| 1139 | PatternRewriter &rewriter) const override { |
| 1140 | Location loc = extractStridedMetadataOp.getLoc(); |
| 1141 | Value source = extractStridedMetadataOp.getSource(); |
| 1142 | auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>(); |
| 1143 | if (!memSpaceCastOp) |
| 1144 | return failure(); |
    auto newExtractStridedMetadata =
| 1146 | rewriter.create<memref::ExtractStridedMetadataOp>( |
| 1147 | loc, memSpaceCastOp.getSource()); |
| 1148 | SmallVector<Value> results(newExtractStridedMetadata.getResults()); |
| 1149 | // As with most other strided metadata rewrite patterns, don't introduce |
    // a use of the base pointer where none existed. This needs to happen here,
| 1151 | // as opposed to in later dead-code elimination, because these patterns are |
| 1152 | // sometimes used during dialect conversion (see EmulateNarrowType, for |
| 1153 | // example), so adding spurious usages would cause a pre-legalization value |
| 1154 | // to be live that would be dead had this pattern not run. |
| 1155 | if (!extractStridedMetadataOp.getBaseBuffer().use_empty()) { |
| 1156 | auto baseBuffer = results[0]; |
| 1157 | auto baseBufferType = cast<MemRefType>(baseBuffer.getType()); |
| 1158 | MemRefType::Builder newTypeBuilder(baseBufferType); |
| 1159 | newTypeBuilder.setMemorySpace( |
| 1160 | memSpaceCastOp.getResult().getType().getMemorySpace()); |
| 1161 | results[0] = rewriter.create<memref::MemorySpaceCastOp>( |
| 1162 | loc, Type{newTypeBuilder}, baseBuffer); |
| 1163 | } else { |
| 1164 | results[0] = nullptr; |
| 1165 | } |
| 1166 | rewriter.replaceOp(extractStridedMetadataOp, results); |
| 1167 | return success(); |
| 1168 | } |
| 1169 | }; |
| 1170 | |
| 1171 | /// Replace `base, offset = |
| 1172 | /// extract_strided_metadata(extract_strided_metadata(src)#0)` |
| 1173 | /// With |
| 1174 | /// ``` |
| 1175 | /// base, ... = extract_strided_metadata(src) |
| 1176 | /// offset = 0 |
| 1177 | /// ``` |
class ExtractStridedMetadataOpExtractStridedMetadataFolder
| 1179 | : public OpRewritePattern<memref::ExtractStridedMetadataOp> { |
| 1180 | using OpRewritePattern::OpRewritePattern; |
| 1181 | |
| 1182 | LogicalResult |
  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
| 1184 | PatternRewriter &rewriter) const override { |
    auto sourceExtractStridedMetadataOp =
| 1186 | extractStridedMetadataOp.getSource() |
| 1187 | .getDefiningOp<memref::ExtractStridedMetadataOp>(); |
| 1188 | if (!sourceExtractStridedMetadataOp) |
| 1189 | return failure(); |
| 1190 | Location loc = extractStridedMetadataOp.getLoc(); |
| 1191 | rewriter.replaceOp(extractStridedMetadataOp, |
| 1192 | {sourceExtractStridedMetadataOp.getBaseBuffer(), |
| 1193 | getValueOrCreateConstantIndexOp( |
| 1194 | rewriter, loc, rewriter.getIndexAttr(0))}); |
| 1195 | return success(); |
| 1196 | } |
| 1197 | }; |
| 1198 | } // namespace |
| 1199 | |
| 1200 | void memref::populateExpandStridedMetadataPatterns( |
| 1201 | RewritePatternSet &patterns) { |
| 1202 | patterns.add<SubviewFolder, |
| 1203 | ReshapeFolder<memref::ExpandShapeOp, getExpandedSizes, |
| 1204 | getExpandedStrides>, |
| 1205 | ReshapeFolder<memref::CollapseShapeOp, getCollapsedSize, |
| 1206 | getCollapsedStride>, |
| 1207 | ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, |
| 1208 | ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, |
| 1209 | ExtractStridedMetadataOpCollapseShapeFolder, |
| 1210 | ExtractStridedMetadataOpExpandShapeFolder, |
| 1211 | ExtractStridedMetadataOpGetGlobalFolder, |
| 1212 | RewriteExtractAlignedPointerAsIndexOfViewLikeOp, |
| 1213 | ExtractStridedMetadataOpReinterpretCastFolder, |
| 1214 | ExtractStridedMetadataOpSubviewFolder, |
| 1215 | ExtractStridedMetadataOpCastFolder, |
| 1216 | ExtractStridedMetadataOpMemorySpaceCastFolder, |
| 1217 | ExtractStridedMetadataOpAssumeAlignmentFolder, |
| 1218 | ExtractStridedMetadataOpExtractStridedMetadataFolder>( |
| 1219 | patterns.getContext()); |
| 1220 | } |
| 1221 | |
void memref::populateResolveExtractStridedMetadataPatterns(
| 1223 | RewritePatternSet &patterns) { |
| 1224 | patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>, |
| 1225 | ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>, |
| 1226 | ExtractStridedMetadataOpCollapseShapeFolder, |
| 1227 | ExtractStridedMetadataOpExpandShapeFolder, |
| 1228 | ExtractStridedMetadataOpGetGlobalFolder, |
| 1229 | ExtractStridedMetadataOpSubviewFolder, |
| 1230 | RewriteExtractAlignedPointerAsIndexOfViewLikeOp, |
| 1231 | ExtractStridedMetadataOpReinterpretCastFolder, |
| 1232 | ExtractStridedMetadataOpCastFolder, |
| 1233 | ExtractStridedMetadataOpMemorySpaceCastFolder, |
| 1234 | ExtractStridedMetadataOpAssumeAlignmentFolder, |
| 1235 | ExtractStridedMetadataOpExtractStridedMetadataFolder>( |
      patterns.getContext());
| 1237 | } |
| 1238 | |
| 1239 | //===----------------------------------------------------------------------===// |
| 1240 | // Pass registration |
| 1241 | //===----------------------------------------------------------------------===// |
| 1242 | |
| 1243 | namespace { |
| 1244 | |
| 1245 | struct ExpandStridedMetadataPass final |
| 1246 | : public memref::impl::ExpandStridedMetadataPassBase< |
| 1247 | ExpandStridedMetadataPass> { |
| 1248 | void runOnOperation() override; |
| 1249 | }; |
| 1250 | |
| 1251 | } // namespace |
| 1252 | |
| 1253 | void ExpandStridedMetadataPass::runOnOperation() { |
| 1254 | RewritePatternSet patterns(&getContext()); |
| 1255 | memref::populateExpandStridedMetadataPatterns(patterns); |
| 1256 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
| 1257 | } |
| 1258 | |