| 1 | //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass legalizes vector operations so they can be lowered to ArmSME. |
| 10 | // |
| 11 | // Note: In the context of this pass 'tile' always refers to an SME tile. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 16 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| 17 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
| 18 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
| 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 20 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| 21 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| 22 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
| 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 26 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 27 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 28 | #include "mlir/Transforms/DialectConversion.h" |
| 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 30 | |
| 31 | #define DEBUG_TYPE "arm-sme-vector-legalization" |
| 32 | |
| 33 | namespace mlir::arm_sme { |
| 34 | #define GEN_PASS_DEF_VECTORLEGALIZATION |
| 35 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
| 36 | } // namespace mlir::arm_sme |
| 37 | |
| 38 | using namespace mlir; |
| 39 | using namespace mlir::arm_sme; |
| 40 | |
| 41 | namespace { |
| 42 | |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | // Decomposition of vector operations larger than an SME tile |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | |
// Common match failure reasons.
// Emitted when a 2D vector's size is not an exact multiple of the SME tile
// size for its element type (see `isMultipleOfSMETileVectorType`).
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
// Emitted when the op's mask is not absent or a `vector.create_mask`
// (see `isSupportedMaskOp`).
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
// Emitted when a transfer op's permutation map is not a permutation.
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
// Emitted when a transpose does not go from an illegal to a legal vector type.
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");
| 56 | |
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  // The runtime offset of the sub-tile is (row * vscale, col * vscale).
  int row{0};
  int col{0};
  // The SME tile type (a legal 2D scalable vector type, e.g.
  // vector<[4]x[4]xf32>).
  VectorType type;
};
| 77 | |
| 78 | /// Adds a constant elementwise scalable offset to `indices` (which are of equal |
| 79 | /// length). For example, in the 2D case this would return: |
| 80 | // { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale } |
| 81 | SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder, |
| 82 | Location loc, |
| 83 | ValueRange indices, |
| 84 | ArrayRef<int> scalableOffsets) { |
| 85 | auto vscale = builder.create<vector::VectorScaleOp>(loc); |
| 86 | return llvm::map_to_vector( |
| 87 | C: llvm::zip_equal(t&: indices, u&: scalableOffsets), F: [&](auto pair) -> Value { |
| 88 | auto [index, base] = pair; |
| 89 | auto offset = builder.create<arith::MulIOp>( |
| 90 | loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale); |
| 91 | return builder.create<arith::AddIOp>(loc, index, offset); |
| 92 | }); |
| 93 | } |
| 94 | |
| 95 | /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to |
| 96 | /// indices for one of the SME sub-tiles it will decompose into. |
| 97 | /// |
| 98 | /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the |
| 99 | /// indices for each tile would need to be adjusted as follows: |
| 100 | /// |
| 101 | /// initial indices = [a,b], inital size = 8x8, target size = 4x4 |
| 102 | /// ┌─────────────┬─────────────┐ |
| 103 | /// │[a,b] │[a,b+4] │ |
| 104 | /// │ │ │ |
| 105 | /// ├─────────────┼─────────────┤ |
| 106 | /// │[a+4,b] │[a+4,b+4] │ |
| 107 | /// │ │ │ |
| 108 | /// └─────────────┴─────────────┘ |
| 109 | SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc, |
| 110 | ValueRange indices, |
| 111 | SMESubTile smeTile) { |
| 112 | return addConstantScalableOffset(builder, loc, indices, |
| 113 | scalableOffsets: {smeTile.row, smeTile.col}); |
| 114 | } |
| 115 | |
| 116 | /// Returns true if `mask` is generated by an operation that can be decomposed |
| 117 | /// for SME. Currently, that is just no mask, or vector.create_mask. |
| 118 | /// TODO: Add support for vector.constant_mask once required for SME. |
| 119 | bool isSupportedMaskOp(Value mask) { |
| 120 | return !mask || mask.getDefiningOp<vector::CreateMaskOp>(); |
| 121 | } |
| 122 | |
| 123 | /// Extracts a mask for an SME sub-tile from the mask of a larger vector type. |
| 124 | Value (OpBuilder &builder, Location loc, Value mask, |
| 125 | SMESubTile smeTile) { |
| 126 | assert(isSupportedMaskOp(mask)); |
| 127 | if (!mask) |
| 128 | return Value{}; |
| 129 | auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
| 130 | // The operands of `vector.create_mask` (from a 2D perspective) are the |
| 131 | // coordinates where the mask ends. So we subtract where this tile starts, |
| 132 | // from the mask operands to get the parameters for this sub-tile. |
| 133 | auto smeTileMaskDims = addConstantScalableOffset( |
| 134 | builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col}); |
| 135 | auto smeTileCreateMask = builder.create<vector::CreateMaskOp>( |
| 136 | loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); |
| 137 | return smeTileCreateMask.getResult(); |
| 138 | } |
| 139 | |
| 140 | /// Constructs an iterator that returns each SME tile (with coordinates) |
| 141 | /// contained within a VectorType. For example, if decomposing an [8]x[8] into |
| 142 | /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0), |
| 143 | /// (4, 4). |
| 144 | auto decomposeToSMETiles(OpBuilder &builder, VectorType type, |
| 145 | VectorType smeTileType, |
| 146 | bool transposeIndices = false) { |
| 147 | return llvm::map_range( |
| 148 | StaticTileOffsetRange( |
| 149 | type.getShape(), |
| 150 | {std::min(type.getDimSize(0), smeTileType.getDimSize(0)), |
| 151 | std::min(type.getDimSize(1), smeTileType.getDimSize(1))}), |
| 152 | [=](auto indices) { |
| 153 | int row = int(indices[0]); |
| 154 | int col = int(indices[1]); |
| 155 | if (transposeIndices) |
| 156 | std::swap(a&: row, b&: col); |
| 157 | return SMESubTile{row, col, smeTileType}; |
| 158 | }); |
| 159 | } |
| 160 | |
| 161 | /// Returns the number of SME tiles that fit into the (2D-scalable) vector type |
| 162 | /// `type`. |
| 163 | int getNumberOfSMETilesForVectorType(VectorType type) { |
| 164 | assert(isMultipleOfSMETileVectorType(type) && |
| 165 | "`type` not multiple of SME tiles" ); |
| 166 | int64_t vectorRows = type.getDimSize(0); |
| 167 | int64_t vectorCols = type.getDimSize(1); |
| 168 | auto elementType = type.getElementType(); |
| 169 | unsigned minNumElts = getSMETileSliceMinNumElts(elementType); |
| 170 | return (vectorRows * vectorCols) / (minNumElts * minNumElts); |
| 171 | } |
| 172 | |
| 173 | /// Legalize `arith.constant dense<value>` splat operations to fit within SME |
| 174 | /// tiles by decomposing them into tile-sized operations. |
| 175 | struct LegalizeArithConstantOpsByDecomposition |
| 176 | : public OpConversionPattern<arith::ConstantOp> { |
| 177 | using OpConversionPattern::OpConversionPattern; |
| 178 | |
| 179 | LogicalResult |
| 180 | matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, |
| 181 | ConversionPatternRewriter &rewriter) const override { |
| 182 | auto vectorType = dyn_cast<VectorType>(constantOp.getType()); |
| 183 | auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); |
| 184 | if (!vectorType || !denseAttr || !denseAttr.isSplat()) |
| 185 | return failure(); |
| 186 | |
| 187 | if (!isMultipleOfSMETileVectorType(vectorType)) |
| 188 | return rewriter.notifyMatchFailure(constantOp, |
| 189 | kMatchFailureNotSMETileTypeMultiple); |
| 190 | |
| 191 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| 192 | auto tileCount = getNumberOfSMETilesForVectorType(vectorType); |
| 193 | auto tileSplat = rewriter.create<arith::ConstantOp>( |
| 194 | constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); |
| 195 | SmallVector<Value> repl(tileCount, tileSplat); |
| 196 | rewriter.replaceOpWithMultiple(constantOp, {repl}); |
| 197 | |
| 198 | return success(); |
| 199 | } |
| 200 | }; |
| 201 | |
| 202 | /// Legalize `vector.outerproduct` operations to fit within SME tiles by |
| 203 | /// decomposing them into tile-sized operations. |
| 204 | struct LegalizeVectorOuterProductOpsByDecomposition |
| 205 | : public OpConversionPattern<vector::OuterProductOp> { |
| 206 | using OpConversionPattern::OpConversionPattern; |
| 207 | |
| 208 | LogicalResult |
| 209 | matchAndRewrite(vector::OuterProductOp outerProductOp, |
| 210 | OneToNOpAdaptor adaptor, |
| 211 | ConversionPatternRewriter &rewriter) const override { |
| 212 | auto vectorType = outerProductOp.getResultVectorType(); |
| 213 | if (!isMultipleOfSMETileVectorType(vectorType)) |
| 214 | return rewriter.notifyMatchFailure(outerProductOp, |
| 215 | kMatchFailureNotSMETileTypeMultiple); |
| 216 | |
| 217 | Value mask; |
| 218 | Operation *rootOp = outerProductOp; |
| 219 | auto loc = outerProductOp.getLoc(); |
| 220 | if (outerProductOp.isMasked()) { |
| 221 | auto maskOp = outerProductOp.getMaskingOp(); |
| 222 | mask = maskOp.getMask(); |
| 223 | rootOp = maskOp; |
| 224 | rewriter.setInsertionPoint(rootOp); |
| 225 | } |
| 226 | |
| 227 | if (!isSupportedMaskOp(mask)) |
| 228 | return rewriter.notifyMatchFailure(outerProductOp, |
| 229 | kMatchFailureUnsupportedMaskOp); |
| 230 | |
| 231 | ValueRange accSMETiles = adaptor.getAcc(); |
| 232 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| 233 | VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0); |
| 234 | |
| 235 | SmallVector<Value> resultSMETiles; |
| 236 | for (auto [index, smeTile] : llvm::enumerate( |
| 237 | decomposeToSMETiles(rewriter, vectorType, smeTileType))) { |
| 238 | |
| 239 | auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
| 240 | auto lhs = rewriter.create<vector::ScalableExtractOp>( |
| 241 | loc, sliceType, outerProductOp.getLhs(), smeTile.row); |
| 242 | auto rhs = rewriter.create<vector::ScalableExtractOp>( |
| 243 | loc, sliceType, outerProductOp.getRhs(), smeTile.col); |
| 244 | auto smeOuterProduct = rewriter.create<vector::OuterProductOp>( |
| 245 | loc, smeTileType, lhs, rhs, |
| 246 | !accSMETiles.empty() ? accSMETiles[index] : Value{}, |
| 247 | outerProductOp.getKind()); |
| 248 | |
| 249 | auto maskedOuterProduct = |
| 250 | vector::maskOperation(rewriter, smeOuterProduct, smeMask); |
| 251 | resultSMETiles.push_back(maskedOuterProduct->getResult(0)); |
| 252 | } |
| 253 | |
| 254 | rewriter.replaceOpWithMultiple(op: rootOp, newValues: {resultSMETiles}); |
| 255 | return success(); |
| 256 | } |
| 257 | }; |
| 258 | |
| 259 | // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to |
| 260 | // get the help of the type conversion), but doing so results in the type |
| 261 | // conversion adding target materializations in the `vector.mask` region |
| 262 | // (invalid). This pattern matches on `vector.mask` then calls into the |
| 263 | // `vector.outerproduct` pattern to work around this issue. |
| 264 | struct LegalizeMaskedVectorOuterProductOpsByDecomposition |
| 265 | : public OpConversionPattern<vector::MaskOp> { |
| 266 | using OpConversionPattern::OpConversionPattern; |
| 267 | |
| 268 | LogicalResult |
| 269 | matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, |
| 270 | ConversionPatternRewriter &rewriter) const override { |
| 271 | if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>( |
| 272 | maskOp.getMaskableOp())) { |
| 273 | LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), |
| 274 | getContext()); |
| 275 | return static_cast<RewritePattern &>(pattern).matchAndRewrite( |
| 276 | outerProductOp, rewriter); |
| 277 | } |
| 278 | return failure(); |
| 279 | } |
| 280 | }; |
| 281 | |
| 282 | /// Legalize `vector.transfer_read` operations to fit within SME tiles by |
| 283 | /// decomposing them into tile-sized operations. |
| 284 | struct LegalizeTransferReadOpsByDecomposition |
| 285 | : public OpConversionPattern<vector::TransferReadOp> { |
| 286 | using OpConversionPattern::OpConversionPattern; |
| 287 | |
| 288 | LogicalResult |
| 289 | matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, |
| 290 | ConversionPatternRewriter &rewriter) const override { |
| 291 | auto vectorType = readOp.getVectorType(); |
| 292 | if (!isMultipleOfSMETileVectorType(vectorType)) |
| 293 | return rewriter.notifyMatchFailure(readOp, |
| 294 | kMatchFailureNotSMETileTypeMultiple); |
| 295 | |
| 296 | auto mask = readOp.getMask(); |
| 297 | if (!isSupportedMaskOp(mask)) |
| 298 | return rewriter.notifyMatchFailure(readOp, |
| 299 | kMatchFailureUnsupportedMaskOp); |
| 300 | |
| 301 | auto permutationMap = readOp.getPermutationMap(); |
| 302 | if (!permutationMap.isPermutation()) |
| 303 | return rewriter.notifyMatchFailure(readOp, |
| 304 | kMatchFailureNonPermutationMap); |
| 305 | |
| 306 | // Note: For 2D vector types the only non-identity permutation is a simple |
| 307 | // transpose [1, 0]. |
| 308 | bool transposed = !permutationMap.isIdentity(); |
| 309 | |
| 310 | auto loc = readOp.getLoc(); |
| 311 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| 312 | |
| 313 | SmallVector<Value> resultSMETiles; |
| 314 | for (SMESubTile smeTile : |
| 315 | decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) { |
| 316 | auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
| 317 | auto smeRead = rewriter.create<vector::TransferReadOp>( |
| 318 | loc, smeTileType, readOp.getBase(), |
| 319 | getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile), |
| 320 | readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask, |
| 321 | readOp.getInBoundsAttr()); |
| 322 | resultSMETiles.push_back(smeRead); |
| 323 | } |
| 324 | |
| 325 | rewriter.replaceOpWithMultiple(readOp, {resultSMETiles}); |
| 326 | return success(); |
| 327 | } |
| 328 | }; |
| 329 | |
| 330 | /// Legalize `vector.transfer_write` operations to fit within SME tiles by |
| 331 | /// decomposing them into tile-sized operations. |
| 332 | struct LegalizeTransferWriteOpsByDecomposition |
| 333 | : public OpConversionPattern<vector::TransferWriteOp> { |
| 334 | using OpConversionPattern::OpConversionPattern; |
| 335 | |
| 336 | LogicalResult |
| 337 | matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, |
| 338 | ConversionPatternRewriter &rewriter) const override { |
| 339 | auto vectorType = writeOp.getVectorType(); |
| 340 | if (!isMultipleOfSMETileVectorType(vectorType)) |
| 341 | return rewriter.notifyMatchFailure(writeOp, |
| 342 | kMatchFailureNotSMETileTypeMultiple); |
| 343 | |
| 344 | auto mask = writeOp.getMask(); |
| 345 | if (!isSupportedMaskOp(mask)) |
| 346 | return rewriter.notifyMatchFailure(writeOp, |
| 347 | kMatchFailureUnsupportedMaskOp); |
| 348 | |
| 349 | auto permutationMap = writeOp.getPermutationMap(); |
| 350 | if (!permutationMap.isPermutation()) |
| 351 | return rewriter.notifyMatchFailure(writeOp, |
| 352 | kMatchFailureNonPermutationMap); |
| 353 | |
| 354 | // Note: For 2D vector types the only non-identity permutation is a simple |
| 355 | // transpose [1, 0]. |
| 356 | bool transposed = !permutationMap.isIdentity(); |
| 357 | |
| 358 | auto loc = writeOp.getLoc(); |
| 359 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| 360 | auto inputSMETiles = adaptor.getValueToStore(); |
| 361 | |
| 362 | Value destTensorOrMemref = writeOp.getBase(); |
| 363 | for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles( |
| 364 | rewriter, vectorType, smeTileType, transposed))) { |
| 365 | auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); |
| 366 | auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
| 367 | loc, inputSMETiles[index], destTensorOrMemref, |
| 368 | getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile), |
| 369 | writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr()); |
| 370 | if (writeOp.hasPureTensorSemantics()) |
| 371 | destTensorOrMemref = smeWrite.getResult(); |
| 372 | } |
| 373 | |
| 374 | if (writeOp.hasPureTensorSemantics()) |
| 375 | rewriter.replaceOp(writeOp, destTensorOrMemref); |
| 376 | else |
| 377 | rewriter.eraseOp(op: writeOp); |
| 378 | |
| 379 | return success(); |
| 380 | } |
| 381 | }; |
| 382 | |
| 383 | /// Legalize a multi-tile transfer_write as a single store loop. This is done as |
| 384 | /// part of type decomposition as at this level we know each tile write is |
| 385 | /// disjoint, but that information is lost after decomposition (without analysis |
| 386 | /// to reconstruct it). |
| 387 | /// |
| 388 | /// Example (pseudo-MLIR): |
| 389 | /// |
| 390 | /// ``` |
| 391 | /// vector.transfer_write %vector, %dest[%y, %x], %mask |
| 392 | /// : vector<[16]x[8]xi16>, memref<?x?xi16> |
| 393 | /// ``` |
| 394 | /// Is rewritten to: |
| 395 | /// ``` |
| 396 | /// scf.for %slice_idx = %c0 to %c8_vscale step %c1 { |
| 397 | /// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐ |
| 398 | /// : vector<[8]xi1> from vector<[16]x[8]xi1> | |
| 399 | /// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile |
| 400 | /// : vector<[8]xi16> from vector<[8]x[8]xi16> | |
| 401 | /// vector.transfer_write %upper_slice, | |
| 402 | /// %dest[%slice_idx + %y, %x], %upper_slice_mask | |
| 403 | /// : vector<[8]xi16>, memref<?x?xi16> ┘ |
| 404 | /// %lower_slice_idx = %slice_idx + %c8_vscale ─┐ |
| 405 | /// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] | |
| 406 | /// : vector<[8]xi1> from vector<[16]x[8]xi1> | |
| 407 | /// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower |
| 408 | /// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile |
| 409 | /// vector.transfer_write %lower_slice, | |
| 410 | /// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask | |
| 411 | /// : vector<[8]xi16>, memref<?x?xi16> ┘ |
| 412 | /// } |
| 413 | /// ``` |
| 414 | struct LegalizeMultiTileTransferWriteAsStoreLoop |
| 415 | : public OpConversionPattern<vector::TransferWriteOp> { |
| 416 | using OpConversionPattern::OpConversionPattern; |
| 417 | |
| 418 | LogicalResult |
| 419 | matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, |
| 420 | ConversionPatternRewriter &rewriter) const override { |
| 421 | if (writeOp.hasPureTensorSemantics()) |
| 422 | return rewriter.notifyMatchFailure( |
| 423 | writeOp, "TODO: tensor semantics are unsupported" ); |
| 424 | |
| 425 | auto permutationMap = writeOp.getPermutationMap(); |
| 426 | if (!permutationMap.isPermutation()) |
| 427 | return rewriter.notifyMatchFailure(writeOp, |
| 428 | kMatchFailureNonPermutationMap); |
| 429 | |
| 430 | bool transposed = !permutationMap.isIdentity(); |
| 431 | if (transposed) |
| 432 | return rewriter.notifyMatchFailure(writeOp, |
| 433 | "TODO: transpose unsupported" ); |
| 434 | |
| 435 | auto vectorType = writeOp.getVectorType(); |
| 436 | if (!isMultipleOfSMETileVectorType(vectorType)) |
| 437 | return rewriter.notifyMatchFailure(writeOp, |
| 438 | kMatchFailureNotSMETileTypeMultiple); |
| 439 | |
| 440 | // Note: We also disallow masks where any dimension is > 16 because that |
| 441 | // prevents the masking from being lowered to use arm_sve.psel. |
| 442 | auto mask = writeOp.getMask(); |
| 443 | if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 || |
| 444 | vectorType.getDimSize(1) > 16))) |
| 445 | return rewriter.notifyMatchFailure(writeOp, |
| 446 | kMatchFailureUnsupportedMaskOp); |
| 447 | |
| 448 | auto loc = writeOp.getLoc(); |
| 449 | auto createVscaleMultiple = |
| 450 | vector::makeVscaleConstantBuilder(rewriter, loc: loc); |
| 451 | |
| 452 | // Get SME tile and slice types. |
| 453 | auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); |
| 454 | auto minTileSlices = smeTileType.getDimSize(0); |
| 455 | VectorType sliceMaskType = |
| 456 | VectorType::get(minTileSlices, rewriter.getI1Type(), true); |
| 457 | |
| 458 | // Create loop over all tile slices. |
| 459 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 460 | auto upperBound = createVscaleMultiple(minTileSlices); |
| 461 | auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
| 462 | auto storeLoop = |
| 463 | rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); |
| 464 | rewriter.setInsertionPointToStart(storeLoop.getBody()); |
| 465 | |
| 466 | // For each sub-tile of the multi-tile `vectorType`. |
| 467 | auto inputSMETiles = adaptor.getValueToStore(); |
| 468 | auto tileSliceIndex = storeLoop.getInductionVar(); |
| 469 | for (auto [index, smeTile] : llvm::enumerate( |
| 470 | decomposeToSMETiles(rewriter, vectorType, smeTileType))) { |
| 471 | // The coordinates of the tile within `vectorType`. |
| 472 | auto tileRow = createVscaleMultiple(smeTile.row); |
| 473 | auto tileCol = createVscaleMultiple(smeTile.col); |
| 474 | |
| 475 | // The current slice of `vectorType` we are processing. |
| 476 | auto sliceIndex = |
| 477 | rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex); |
| 478 | |
| 479 | // Where in the destination memref the current slice will be stored. |
| 480 | auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex, |
| 481 | writeOp.getIndices()[0]); |
| 482 | auto storeCol = |
| 483 | rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]); |
| 484 | |
| 485 | // Extract the mask for the current slice. |
| 486 | Value sliceMask = nullptr; |
| 487 | if (mask) { |
| 488 | sliceMask = rewriter.create<vector::ExtractOp>( |
| 489 | loc, mask, OpFoldResult(sliceIndex)); |
| 490 | if (sliceMaskType != sliceMask.getType()) |
| 491 | sliceMask = rewriter.create<vector::ScalableExtractOp>( |
| 492 | loc, sliceMaskType, sliceMask, smeTile.col); |
| 493 | } |
| 494 | |
| 495 | // Extract and store the current slice. |
| 496 | Value tile = inputSMETiles[index]; |
| 497 | auto slice = |
| 498 | rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex); |
| 499 | rewriter.create<vector::TransferWriteOp>( |
| 500 | loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol}, |
| 501 | AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)), |
| 502 | sliceMask, |
| 503 | rewriter.getBoolArrayAttr( |
| 504 | ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front())); |
| 505 | } |
| 506 | |
| 507 | rewriter.eraseOp(op: writeOp); |
| 508 | return success(); |
| 509 | } |
| 510 | }; |
| 511 | |
| 512 | //===----------------------------------------------------------------------===// |
| 513 | // ArmSME-specific fixup canonicalizations/folds |
| 514 | //===----------------------------------------------------------------------===// |
| 515 | |
| 516 | /// Folds an extract from a 3D `vector.create_mask` (which is a vector of |
| 517 | /// SME-like masks), into a compare and a 2D `vector.create_mask`. This is |
| 518 | /// necessary for the mask to be lowered to ArmSME. |
| 519 | /// |
| 520 | /// Example: |
| 521 | /// |
| 522 | /// BEFORE: |
| 523 | /// ```mlir |
| 524 | /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1> |
| 525 | /// %subMask = vector.extract %mask[2] |
| 526 | /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> |
| 527 | /// ``` |
| 528 | /// |
| 529 | /// AFTER: |
| 530 | /// ```mlir |
| 531 | /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index |
| 532 | /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index |
| 533 | /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1> |
| 534 | /// ``` |
| 535 | struct |
| 536 | : public OpRewritePattern<vector::ExtractOp> { |
| 537 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
| 538 | |
| 539 | LogicalResult matchAndRewrite(vector::ExtractOp , |
| 540 | PatternRewriter &rewriter) const override { |
| 541 | auto loc = extractOp.getLoc(); |
| 542 | auto createMaskOp = |
| 543 | extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); |
| 544 | if (!createMaskOp) |
| 545 | return rewriter.notifyMatchFailure( |
| 546 | extractOp, "extract not from vector.create_mask op" ); |
| 547 | |
| 548 | VectorType = |
| 549 | llvm::dyn_cast<VectorType>(extractOp.getResult().getType()); |
| 550 | if (!extractedMaskType) |
| 551 | return rewriter.notifyMatchFailure(extractOp, |
| 552 | "extracted type is not a vector type" ); |
| 553 | |
| 554 | auto numScalable = extractedMaskType.getNumScalableDims(); |
| 555 | if (numScalable != 2) |
| 556 | return rewriter.notifyMatchFailure( |
| 557 | extractOp, "expected extracted type to be an SME-like mask" ); |
| 558 | |
| 559 | // TODO: Support multiple extraction indices. |
| 560 | if (extractOp.getStaticPosition().size() != 1) |
| 561 | return rewriter.notifyMatchFailure( |
| 562 | extractOp, "only a single extraction index is supported" ); |
| 563 | |
| 564 | auto frontMaskDim = createMaskOp.getOperand(0); |
| 565 | if (frontMaskDim.getDefiningOp<arith::ConstantOp>()) |
| 566 | return rewriter.notifyMatchFailure( |
| 567 | extractOp, |
| 568 | "constant vector.create_masks dims should be folded elsewhere" ); |
| 569 | |
| 570 | auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 571 | auto = getValueOrCreateConstantIndexOp( |
| 572 | rewriter, loc, extractOp.getMixedPosition()[0]); |
| 573 | auto = rewriter.create<arith::CmpIOp>( |
| 574 | loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex, |
| 575 | frontMaskDim); |
| 576 | auto newMaskFrontDim = rewriter.create<arith::SelectOp>( |
| 577 | loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero); |
| 578 | |
| 579 | rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( |
| 580 | extractOp, extractedMaskType, |
| 581 | ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)}); |
| 582 | return success(); |
| 583 | } |
| 584 | }; |
| 585 | |
| 586 | /// A vector type where no fixed dimension comes after a scalable dimension. |
| 587 | bool isLegalVectorType(VectorType vType) { |
| 588 | bool seenFixedDim = false; |
| 589 | for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) { |
| 590 | seenFixedDim |= !scalableFlag; |
| 591 | if (seenFixedDim && scalableFlag) |
| 592 | return false; |
| 593 | } |
| 594 | return true; |
| 595 | } |
| 596 | |
| 597 | /// Lifts an illegal vector.transpose and vector.transfer_read to a |
| 598 | /// memref.subview + memref.transpose, followed by a legal read. |
| 599 | /// |
| 600 | /// 'Illegal' here means a leading scalable dimension and a fixed trailing |
| 601 | /// dimension, which has no valid lowering. |
| 602 | /// |
| 603 | /// The memref.transpose is metadata-only transpose that produces a strided |
| 604 | /// memref, which eventually becomes a loop reading individual elements. |
| 605 | /// |
| 606 | /// Example: |
| 607 | /// |
| 608 | /// BEFORE: |
| 609 | /// ```mlir |
| 610 | /// %illegalRead = vector.transfer_read %memref[%a, %b] |
| 611 | /// : memref<?x?xf32>, vector<[8]x4xf32> |
| 612 | /// %legalType = vector.transpose %illegalRead, [1, 0] |
| 613 | /// : vector<[8]x4xf32> to vector<4x[8]xf32> |
| 614 | /// ``` |
| 615 | /// |
| 616 | /// AFTER: |
| 617 | /// ```mlir |
| 618 | /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] |
| 619 | /// : memref<?x?xf32> to memref<?x?xf32> |
| 620 | /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) |
| 621 | /// : memref<?x?xf32> to memref<?x?xf32> |
| 622 | /// %legalType = vector.transfer_read %transpose[%c0, %c0] |
| 623 | /// : memref<?x?xf32>, vector<4x[8]xf32> |
| 624 | /// ``` |
| 625 | struct LiftIllegalVectorTransposeToMemory |
| 626 | : public OpRewritePattern<vector::TransposeOp> { |
| 627 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
| 628 | |
| 629 | static Value getExtensionSource(Operation *op) { |
| 630 | if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op)) |
| 631 | return op->getOperand(idx: 0); |
| 632 | return {}; |
| 633 | } |
| 634 | |
| 635 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| 636 | PatternRewriter &rewriter) const override { |
| 637 | auto sourceType = transposeOp.getSourceVectorType(); |
| 638 | auto resultType = transposeOp.getResultVectorType(); |
| 639 | if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) |
| 640 | return rewriter.notifyMatchFailure(transposeOp, |
| 641 | kMatchFailureNotIllegalToLegal); |
| 642 | |
| 643 | // Look through extend for transfer_read. |
| 644 | Value maybeRead = transposeOp.getVector(); |
| 645 | auto *transposeSourceOp = maybeRead.getDefiningOp(); |
| 646 | Operation *extendOp = nullptr; |
| 647 | if (Value extendSource = getExtensionSource(op: transposeSourceOp)) { |
| 648 | maybeRead = extendSource; |
| 649 | extendOp = transposeSourceOp; |
| 650 | } |
| 651 | |
| 652 | auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>(); |
| 653 | if (!illegalRead) |
| 654 | return rewriter.notifyMatchFailure( |
| 655 | transposeOp, |
| 656 | "expected source to be (possibly extended) transfer_read" ); |
| 657 | |
| 658 | if (!illegalRead.getPermutationMap().isIdentity()) |
| 659 | return rewriter.notifyMatchFailure( |
| 660 | illegalRead, "expected read to have identity permutation map" ); |
| 661 | |
| 662 | auto loc = transposeOp.getLoc(); |
| 663 | auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
| 664 | auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
| 665 | |
| 666 | // Create a subview that matches the size of the illegal read vector type. |
| 667 | auto readType = illegalRead.getVectorType(); |
| 668 | auto readSizes = llvm::map_to_vector( |
| 669 | llvm::zip_equal(readType.getShape(), readType.getScalableDims()), |
| 670 | [&](auto dim) -> Value { |
| 671 | auto [size, isScalable] = dim; |
| 672 | auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size); |
| 673 | if (!isScalable) |
| 674 | return dimSize; |
| 675 | auto vscale = rewriter.create<vector::VectorScaleOp>(loc); |
| 676 | return rewriter.create<arith::MulIOp>(loc, vscale, dimSize); |
| 677 | }); |
| 678 | SmallVector<Value> strides(readType.getRank(), Value(one)); |
| 679 | auto readSubview = rewriter.create<memref::SubViewOp>( |
| 680 | loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes, |
| 681 | strides); |
| 682 | |
| 683 | // Apply the transpose to all values/attributes of the transfer_read: |
| 684 | // - The mask |
| 685 | Value mask = illegalRead.getMask(); |
| 686 | if (mask) { |
| 687 | // Note: The transpose for the mask should fold into the |
| 688 | // vector.create_mask/constant_mask op, which will then become legal. |
| 689 | mask = rewriter.create<vector::TransposeOp>(loc, mask, |
| 690 | transposeOp.getPermutation()); |
| 691 | } |
| 692 | // - The source memref |
| 693 | mlir::AffineMap transposeMap = AffineMap::getPermutationMap( |
| 694 | transposeOp.getPermutation(), getContext()); |
| 695 | auto transposedSubview = rewriter.create<memref::TransposeOp>( |
| 696 | loc, readSubview, AffineMapAttr::get(transposeMap)); |
| 697 | ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); |
| 698 | // - The `in_bounds` attribute |
| 699 | if (inBoundsAttr) { |
| 700 | SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(), |
| 701 | inBoundsAttr.end()); |
| 702 | applyPermutationToVector(inBoundsValues, transposeOp.getPermutation()); |
| 703 | inBoundsAttr = rewriter.getArrayAttr(inBoundsValues); |
| 704 | } |
| 705 | |
| 706 | VectorType legalReadType = resultType.clone(readType.getElementType()); |
| 707 | // Note: The indices are all zero as the subview is already offset. |
| 708 | SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero); |
| 709 | auto legalRead = rewriter.create<vector::TransferReadOp>( |
| 710 | loc, legalReadType, transposedSubview, readIndices, |
| 711 | illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask, |
| 712 | inBoundsAttr); |
| 713 | |
| 714 | // Replace the transpose with the new read, extending the result if |
| 715 | // necessary. |
| 716 | rewriter.replaceOp(transposeOp, [&]() -> Operation * { |
| 717 | if (extendOp) |
| 718 | return rewriter.create(loc, extendOp->getName().getIdentifier(), |
| 719 | Value(legalRead), resultType); |
| 720 | return legalRead; |
| 721 | }()); |
| 722 | |
| 723 | return success(); |
| 724 | } |
| 725 | }; |
| 726 | |
| 727 | /// A rewrite to turn unit dim transpose-like vector.shape_casts into |
| 728 | /// vector.transposes. The shape_cast has to be from an illegal vector type to a |
| 729 | /// legal one (as defined by isLegalVectorType). |
| 730 | /// |
| 731 | /// The reasoning for this is if we've got to this pass and we still have |
| 732 | /// shape_casts of illegal types, then they likely will not cancel out. Turning |
| 733 | /// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to |
| 734 | /// eliminate them. |
| 735 | /// |
| 736 | /// Example: |
| 737 | /// |
| 738 | /// BEFORE: |
| 739 | /// ```mlir |
| 740 | /// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32> |
| 741 | /// ``` |
| 742 | /// |
| 743 | /// AFTER: |
| 744 | /// ```mlir |
| 745 | /// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> |
| 746 | /// ``` |
| 747 | struct ConvertIllegalShapeCastOpsToTransposes |
| 748 | : public OpRewritePattern<vector::ShapeCastOp> { |
| 749 | using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; |
| 750 | |
| 751 | LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, |
| 752 | PatternRewriter &rewriter) const override { |
| 753 | auto sourceType = shapeCastOp.getSourceVectorType(); |
| 754 | auto resultType = shapeCastOp.getResultVectorType(); |
| 755 | if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) |
| 756 | return rewriter.notifyMatchFailure(shapeCastOp, |
| 757 | kMatchFailureNotIllegalToLegal); |
| 758 | |
| 759 | // Note: If we know that `sourceType` is an illegal vector type (and 2D) |
| 760 | // then dim 0 is scalable and dim 1 is fixed. |
| 761 | if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1) |
| 762 | return rewriter.notifyMatchFailure( |
| 763 | shapeCastOp, "expected source to be a 2D scalable vector with a " |
| 764 | "trailing unit dim" ); |
| 765 | |
| 766 | auto loc = shapeCastOp.getLoc(); |
| 767 | auto transpose = rewriter.create<vector::TransposeOp>( |
| 768 | loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0}); |
| 769 | |
| 770 | if (resultType.getRank() == 1) |
| 771 | rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType, |
| 772 | transpose); |
| 773 | else |
| 774 | rewriter.replaceOp(shapeCastOp, transpose); |
| 775 | |
| 776 | return success(); |
| 777 | } |
| 778 | }; |
| 779 | |
| 780 | /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use |
| 781 | /// the ZA state. This workaround rewrite to support these transposes when ZA is |
| 782 | /// available. |
| 783 | /// |
| 784 | /// Example: |
| 785 | /// |
| 786 | /// BEFORE: |
| 787 | /// ```mlir |
| 788 | /// %transpose = vector.transpose %vec, [1, 0] |
| 789 | /// : vector<2x[4]xf32> to vector<[4]x2xf32> |
| 790 | /// vector.transfer_write %transpose, %dest[%y, %x] |
| 791 | /// : vector<[4]x2xf32>, memref<?x?xf32> |
| 792 | /// ``` |
| 793 | /// |
| 794 | /// AFTER: |
| 795 | /// ```mlir |
| 796 | /// %0 = arm_sme.get_tile : vector<[4]x[4]xf32> |
| 797 | /// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32> |
| 798 | /// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32> |
| 799 | /// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32> |
| 800 | /// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32> |
| 801 | /// %c4_vscale = arith.muli %vscale, %c4 : index |
| 802 | /// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1> |
| 803 | /// vector.transfer_write %4, %dest[%y, %x], %mask |
| 804 | /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} |
| 805 | /// : vector<[4]x[4]xf32>, memref<?x?xf32> |
| 806 | /// ``` |
| 807 | /// |
| 808 | /// Values larger than a single tile are supported via decomposition. |
| 809 | struct LowerIllegalTransposeStoreViaZA |
| 810 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 811 | using OpRewritePattern::OpRewritePattern; |
| 812 | |
| 813 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| 814 | PatternRewriter &rewriter) const override { |
| 815 | if (!isSupportedMaskOp(writeOp.getMask())) |
| 816 | return rewriter.notifyMatchFailure(writeOp, |
| 817 | kMatchFailureUnsupportedMaskOp); |
| 818 | |
| 819 | auto permutationMap = writeOp.getPermutationMap(); |
| 820 | if (!permutationMap.isIdentity()) |
| 821 | return rewriter.notifyMatchFailure(writeOp, |
| 822 | kMatchFailureNonPermutationMap); |
| 823 | |
| 824 | auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>(); |
| 825 | if (!transposeOp) |
| 826 | return failure(); |
| 827 | |
| 828 | auto sourceType = transposeOp.getSourceVectorType(); |
| 829 | auto resultType = transposeOp.getResultVectorType(); |
| 830 | |
| 831 | if (resultType.getRank() != 2) |
| 832 | return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2" ); |
| 833 | |
| 834 | if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType)) |
| 835 | return rewriter.notifyMatchFailure( |
| 836 | transposeOp, "not illegal/unsupported SVE transpose" ); |
| 837 | |
| 838 | auto smeTileType = getSMETileTypeForElement(resultType.getElementType()); |
| 839 | VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0); |
| 840 | |
| 841 | if (sourceType.getDimSize(0) <= 1 || |
| 842 | sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0) |
| 843 | return rewriter.notifyMatchFailure(writeOp, "unsupported source shape" ); |
| 844 | |
| 845 | auto loc = writeOp.getLoc(); |
| 846 | auto createVscaleMultiple = |
| 847 | vector::makeVscaleConstantBuilder(rewriter, loc: loc); |
| 848 | |
| 849 | auto transposeMap = AffineMapAttr::get( |
| 850 | AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext())); |
| 851 | |
| 852 | // Note: We need to use `get_tile` as there's no vector-level `undef`. |
| 853 | Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType); |
| 854 | Value destTensorOrMemref = writeOp.getBase(); |
| 855 | auto numSlicesPerTile = |
| 856 | std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0)); |
| 857 | auto numSlices = |
| 858 | rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile); |
| 859 | for (auto [index, smeTile] : llvm::enumerate( |
| 860 | decomposeToSMETiles(rewriter, sourceType, smeTileType))) { |
| 861 | // 1. _Deliberately_ drop a scalable dimension and insert a fixed number |
| 862 | // of slices from the source type into the SME tile. Without checking |
| 863 | // vscale (and emitting multiple implementations) we can't make use of the |
| 864 | // rows of the tile after 1*vscale rows. |
| 865 | Value tile = undefTile; |
| 866 | for (int d = 0; d < numSlicesPerTile; ++d) { |
| 867 | Value vector = rewriter.create<vector::ExtractOp>( |
| 868 | loc, transposeOp.getVector(), |
| 869 | rewriter.getIndexAttr(d + smeTile.row)); |
| 870 | if (vector.getType() != smeSliceType) { |
| 871 | vector = rewriter.create<vector::ScalableExtractOp>( |
| 872 | loc, smeSliceType, vector, smeTile.col); |
| 873 | } |
| 874 | tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d); |
| 875 | } |
| 876 | |
| 877 | // 2. Transpose the tile position. |
| 878 | auto transposedRow = createVscaleMultiple(smeTile.col); |
| 879 | auto transposedCol = |
| 880 | rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row); |
| 881 | |
| 882 | // 3. Compute mask for tile store. |
| 883 | Value maskRows; |
| 884 | Value maskCols; |
| 885 | if (auto mask = writeOp.getMask()) { |
| 886 | auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
| 887 | maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0), |
| 888 | transposedRow); |
| 889 | maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1), |
| 890 | transposedCol); |
| 891 | maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices); |
| 892 | } else { |
| 893 | maskRows = createVscaleMultiple(smeTileType.getDimSize(0)); |
| 894 | maskCols = numSlices; |
| 895 | } |
| 896 | auto subMask = rewriter.create<vector::CreateMaskOp>( |
| 897 | loc, smeTileType.clone(rewriter.getI1Type()), |
| 898 | ValueRange{maskRows, maskCols}); |
| 899 | |
| 900 | // 4. Emit a transposed tile write. |
| 901 | auto writeIndices = writeOp.getIndices(); |
| 902 | Value destRow = |
| 903 | rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]); |
| 904 | Value destCol = |
| 905 | rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]); |
| 906 | auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
| 907 | loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, |
| 908 | transposeMap, subMask, writeOp.getInBounds()); |
| 909 | |
| 910 | if (writeOp.hasPureTensorSemantics()) |
| 911 | destTensorOrMemref = smeWrite.getResult(); |
| 912 | } |
| 913 | |
| 914 | if (writeOp.hasPureTensorSemantics()) |
| 915 | rewriter.replaceOp(writeOp, destTensorOrMemref); |
| 916 | else |
| 917 | rewriter.eraseOp(op: writeOp); |
| 918 | |
| 919 | return success(); |
| 920 | } |
| 921 | }; |
| 922 | |
| 923 | struct VectorLegalizationPass |
| 924 | : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { |
| 925 | void runOnOperation() override { |
| 926 | auto *context = &getContext(); |
| 927 | TypeConverter converter; |
| 928 | RewritePatternSet patterns(context); |
| 929 | converter.addConversion(callback: [](Type type) { return type; }); |
| 930 | converter.addConversion( |
| 931 | callback: [](VectorType vectorType, |
| 932 | SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { |
| 933 | if (!isMultipleOfSMETileVectorType(vectorType)) |
| 934 | return std::nullopt; |
| 935 | auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType); |
| 936 | auto smeTileType = |
| 937 | getSMETileTypeForElement(vectorType.getElementType()); |
| 938 | types = SmallVector<Type>(smeTileCount, smeTileType); |
| 939 | return success(); |
| 940 | }); |
| 941 | |
| 942 | // Apply preprocessing patterns. |
| 943 | RewritePatternSet rewritePatterns(context); |
| 944 | rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, |
| 945 | LiftIllegalVectorTransposeToMemory, |
| 946 | ConvertIllegalShapeCastOpsToTransposes, |
| 947 | LowerIllegalTransposeStoreViaZA>(context); |
| 948 | if (failed( |
| 949 | applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) |
| 950 | return signalPassFailure(); |
| 951 | |
| 952 | // Note: These two patterns are added with a high benefit to ensure: |
| 953 | // - Masked outer products are handled before unmasked ones |
| 954 | // - Multi-tile writes are lowered as a store loop (if possible) |
| 955 | patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition, |
| 956 | LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context, |
| 957 | /*benefit=*/1024); |
| 958 | patterns.add<LegalizeArithConstantOpsByDecomposition, |
| 959 | LegalizeVectorOuterProductOpsByDecomposition, |
| 960 | LegalizeTransferReadOpsByDecomposition, |
| 961 | LegalizeTransferWriteOpsByDecomposition>(converter, context); |
| 962 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
| 963 | converter); |
| 964 | populateCallOpTypeConversionPattern(patterns, converter); |
| 965 | populateReturnOpTypeConversionPattern(patterns, converter); |
| 966 | scf::populateSCFStructuralTypeConversions(typeConverter: converter, patterns); |
| 967 | |
| 968 | ConversionTarget target(getContext()); |
| 969 | target.markUnknownOpDynamicallyLegal( |
| 970 | fn: [&](Operation *op) { return converter.isLegal(op); }); |
| 971 | target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { |
| 972 | return converter.isSignatureLegal(op.getFunctionType()); |
| 973 | }); |
| 974 | if (failed(applyPartialConversion(getOperation(), target, |
| 975 | std::move(patterns)))) |
| 976 | return signalPassFailure(); |
| 977 | } |
| 978 | }; |
| 979 | |
| 980 | } // namespace |
| 981 | |
| 982 | std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() { |
| 983 | return std::make_unique<VectorLegalizationPass>(); |
| 984 | } |
| 985 | |