| 1 | //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass legalizes vector operations so they can be lowered to ArmSME. |
| 10 | // |
| 11 | // Note: In the context of this pass 'tile' always refers to an SME tile. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
| 16 | #include "mlir/Dialect/ArmSME/IR/ArmSME.h" |
| 17 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h" |
| 18 | #include "mlir/Dialect/ArmSME/Utils/Utils.h" |
| 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 20 | #include "mlir/Dialect/Func/Transforms/FuncConversions.h" |
| 21 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
| 22 | #include "mlir/Dialect/Index/IR/IndexOps.h" |
| 23 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 25 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 26 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 27 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 28 | #include "mlir/Transforms/DialectConversion.h" |
| 29 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 30 | |
| 31 | #define DEBUG_TYPE "arm-sme-vector-legalization" |
| 32 | |
| 33 | namespace mlir::arm_sme { |
| 34 | #define GEN_PASS_DEF_VECTORLEGALIZATION |
| 35 | #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
| 36 | } // namespace mlir::arm_sme |
| 37 | |
| 38 | using namespace mlir; |
| 39 | using namespace mlir::arm_sme; |
| 40 | |
| 41 | namespace { |
| 42 | |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | // Decomposition of vector operations larger than an SME tile |
| 45 | //===----------------------------------------------------------------------===// |
| 46 | |
// Common match failure reasons, shared across the decomposition patterns below
// so that `notifyMatchFailure` diagnostics stay consistent.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");
| 56 | |
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type (e.g. vector<[4]x[4]xf32>).
  VectorType type;
};
| 77 | |
| 78 | /// Adds a constant elementwise scalable offset to `indices` (which are of equal |
| 79 | /// length). For example, in the 2D case this would return: |
| 80 | // { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale } |
| 81 | SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder, |
| 82 | Location loc, |
| 83 | ValueRange indices, |
| 84 | ArrayRef<int> scalableOffsets) { |
| 85 | auto vscale = builder.create<vector::VectorScaleOp>(location: loc); |
| 86 | return llvm::map_to_vector( |
| 87 | C: llvm::zip_equal(t&: indices, u&: scalableOffsets), F: [&](auto pair) -> Value { |
| 88 | auto [index, base] = pair; |
| 89 | auto offset = builder.create<arith::MulIOp>( |
| 90 | loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale); |
| 91 | return builder.create<arith::AddIOp>(loc, index, offset); |
| 92 | }); |
| 93 | } |
| 94 | |
| 95 | /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to |
| 96 | /// indices for one of the SME sub-tiles it will decompose into. |
| 97 | /// |
| 98 | /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the |
| 99 | /// indices for each tile would need to be adjusted as follows: |
| 100 | /// |
| 101 | /// initial indices = [a,b], inital size = 8x8, target size = 4x4 |
| 102 | /// ┌─────────────┬─────────────┐ |
| 103 | /// │[a,b] │[a,b+4] │ |
| 104 | /// │ │ │ |
| 105 | /// ├─────────────┼─────────────┤ |
| 106 | /// │[a+4,b] │[a+4,b+4] │ |
| 107 | /// │ │ │ |
| 108 | /// └─────────────┴─────────────┘ |
| 109 | SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc, |
| 110 | ValueRange indices, |
| 111 | SMESubTile smeTile) { |
| 112 | return addConstantScalableOffset(builder, loc, indices, |
| 113 | scalableOffsets: {smeTile.row, smeTile.col}); |
| 114 | } |
| 115 | |
| 116 | /// Returns true if `mask` is generated by an operation that can be decomposed |
| 117 | /// for SME. Currently, that is just no mask, or vector.create_mask. |
| 118 | /// TODO: Add support for vector.constant_mask once required for SME. |
| 119 | bool isSupportedMaskOp(Value mask) { |
| 120 | return !mask || mask.getDefiningOp<vector::CreateMaskOp>(); |
| 121 | } |
| 122 | |
| 123 | /// Extracts a mask for an SME sub-tile from the mask of a larger vector type. |
| 124 | Value (OpBuilder &builder, Location loc, Value mask, |
| 125 | SMESubTile smeTile) { |
| 126 | assert(isSupportedMaskOp(mask)); |
| 127 | if (!mask) |
| 128 | return Value{}; |
| 129 | auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
| 130 | // The operands of `vector.create_mask` (from a 2D perspective) are the |
| 131 | // coordinates where the mask ends. So we subtract where this tile starts, |
| 132 | // from the mask operands to get the parameters for this sub-tile. |
| 133 | auto smeTileMaskDims = addConstantScalableOffset( |
| 134 | builder, loc, indices: createMask.getOperands(), scalableOffsets: {-smeTile.row, -smeTile.col}); |
| 135 | auto smeTileCreateMask = builder.create<vector::CreateMaskOp>( |
| 136 | location: loc, args: smeTile.type.clone(elementType: builder.getI1Type()), args&: smeTileMaskDims); |
| 137 | return smeTileCreateMask.getResult(); |
| 138 | } |
| 139 | |
| 140 | /// Constructs an iterator that returns each SME tile (with coordinates) |
| 141 | /// contained within a VectorType. For example, if decomposing an [8]x[8] into |
| 142 | /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0), |
| 143 | /// (4, 4). |
| 144 | auto decomposeToSMETiles(OpBuilder &builder, VectorType type, |
| 145 | VectorType smeTileType, |
| 146 | bool transposeIndices = false) { |
| 147 | return llvm::map_range( |
| 148 | C: StaticTileOffsetRange( |
| 149 | type.getShape(), |
| 150 | {std::min(a: type.getDimSize(idx: 0), b: smeTileType.getDimSize(idx: 0)), |
| 151 | std::min(a: type.getDimSize(idx: 1), b: smeTileType.getDimSize(idx: 1))}), |
| 152 | F: [=](auto indices) { |
| 153 | int row = int(indices[0]); |
| 154 | int col = int(indices[1]); |
| 155 | if (transposeIndices) |
| 156 | std::swap(a&: row, b&: col); |
| 157 | return SMESubTile{.row: row, .col: col, .type: smeTileType}; |
| 158 | }); |
| 159 | } |
| 160 | |
| 161 | /// Returns the number of SME tiles that fit into the (2D-scalable) vector type |
| 162 | /// `type`. |
| 163 | int getNumberOfSMETilesForVectorType(VectorType type) { |
| 164 | assert(isMultipleOfSMETileVectorType(type) && |
| 165 | "`type` not multiple of SME tiles" ); |
| 166 | int64_t vectorRows = type.getDimSize(idx: 0); |
| 167 | int64_t vectorCols = type.getDimSize(idx: 1); |
| 168 | auto elementType = type.getElementType(); |
| 169 | unsigned minNumElts = getSMETileSliceMinNumElts(type: elementType); |
| 170 | return (vectorRows * vectorCols) / (minNumElts * minNumElts); |
| 171 | } |
| 172 | |
| 173 | /// Legalize `arith.constant dense<value>` splat operations to fit within SME |
| 174 | /// tiles by decomposing them into tile-sized operations. |
| 175 | struct LegalizeArithConstantOpsByDecomposition |
| 176 | : public OpConversionPattern<arith::ConstantOp> { |
| 177 | using OpConversionPattern::OpConversionPattern; |
| 178 | |
| 179 | LogicalResult |
| 180 | matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, |
| 181 | ConversionPatternRewriter &rewriter) const override { |
| 182 | auto vectorType = dyn_cast<VectorType>(Val: constantOp.getType()); |
| 183 | auto denseAttr = dyn_cast<DenseElementsAttr>(Val: constantOp.getValueAttr()); |
| 184 | if (!vectorType || !denseAttr || !denseAttr.isSplat()) |
| 185 | return failure(); |
| 186 | |
| 187 | if (!isMultipleOfSMETileVectorType(vType: vectorType)) |
| 188 | return rewriter.notifyMatchFailure(arg&: constantOp, |
| 189 | msg: kMatchFailureNotSMETileTypeMultiple); |
| 190 | |
| 191 | auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType()); |
| 192 | auto tileCount = getNumberOfSMETilesForVectorType(type: vectorType); |
| 193 | auto tileSplat = rewriter.create<arith::ConstantOp>( |
| 194 | location: constantOp.getLoc(), args: denseAttr.resizeSplat(newType: smeTileType)); |
| 195 | SmallVector<Value> repl(tileCount, tileSplat); |
| 196 | rewriter.replaceOpWithMultiple(op: constantOp, newValues: {repl}); |
| 197 | |
| 198 | return success(); |
| 199 | } |
| 200 | }; |
| 201 | |
| 202 | /// Legalize `vector.outerproduct` operations to fit within SME tiles by |
| 203 | /// decomposing them into tile-sized operations. |
| 204 | struct LegalizeVectorOuterProductOpsByDecomposition |
| 205 | : public OpConversionPattern<vector::OuterProductOp> { |
| 206 | using OpConversionPattern::OpConversionPattern; |
| 207 | |
| 208 | LogicalResult |
| 209 | matchAndRewrite(vector::OuterProductOp outerProductOp, |
| 210 | OneToNOpAdaptor adaptor, |
| 211 | ConversionPatternRewriter &rewriter) const override { |
| 212 | auto vectorType = outerProductOp.getResultVectorType(); |
| 213 | if (!isMultipleOfSMETileVectorType(vType: vectorType)) |
| 214 | return rewriter.notifyMatchFailure(arg&: outerProductOp, |
| 215 | msg: kMatchFailureNotSMETileTypeMultiple); |
| 216 | |
| 217 | Value mask; |
| 218 | Operation *rootOp = outerProductOp; |
| 219 | auto loc = outerProductOp.getLoc(); |
| 220 | if (outerProductOp.isMasked()) { |
| 221 | auto maskOp = outerProductOp.getMaskingOp(); |
| 222 | mask = maskOp.getMask(); |
| 223 | rootOp = maskOp; |
| 224 | rewriter.setInsertionPoint(rootOp); |
| 225 | } |
| 226 | |
| 227 | if (!isSupportedMaskOp(mask)) |
| 228 | return rewriter.notifyMatchFailure(arg&: outerProductOp, |
| 229 | msg: kMatchFailureUnsupportedMaskOp); |
| 230 | |
| 231 | ValueRange accSMETiles = adaptor.getAcc(); |
| 232 | auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType()); |
| 233 | VectorType sliceType = VectorType::Builder(smeTileType).dropDim(pos: 0); |
| 234 | |
| 235 | SmallVector<Value> resultSMETiles; |
| 236 | for (auto [index, smeTile] : llvm::enumerate( |
| 237 | First: decomposeToSMETiles(builder&: rewriter, type: vectorType, smeTileType))) { |
| 238 | |
| 239 | auto smeMask = extractSMEMask(builder&: rewriter, loc, mask, smeTile); |
| 240 | auto lhs = rewriter.create<vector::ScalableExtractOp>( |
| 241 | location: loc, args&: sliceType, args: outerProductOp.getLhs(), args&: smeTile.row); |
| 242 | auto rhs = rewriter.create<vector::ScalableExtractOp>( |
| 243 | location: loc, args&: sliceType, args: outerProductOp.getRhs(), args&: smeTile.col); |
| 244 | auto smeOuterProduct = rewriter.create<vector::OuterProductOp>( |
| 245 | location: loc, args&: smeTileType, args&: lhs, args&: rhs, |
| 246 | args: !accSMETiles.empty() ? accSMETiles[index] : Value{}, |
| 247 | args: outerProductOp.getKind()); |
| 248 | |
| 249 | auto maskedOuterProduct = |
| 250 | vector::maskOperation(builder&: rewriter, maskableOp: smeOuterProduct, mask: smeMask); |
| 251 | resultSMETiles.push_back(Elt: maskedOuterProduct->getResult(idx: 0)); |
| 252 | } |
| 253 | |
| 254 | rewriter.replaceOpWithMultiple(op: rootOp, newValues: {resultSMETiles}); |
| 255 | return success(); |
| 256 | } |
| 257 | }; |
| 258 | |
| 259 | // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to |
| 260 | // get the help of the type conversion), but doing so results in the type |
| 261 | // conversion adding target materializations in the `vector.mask` region |
| 262 | // (invalid). This pattern matches on `vector.mask` then calls into the |
| 263 | // `vector.outerproduct` pattern to work around this issue. |
| 264 | struct LegalizeMaskedVectorOuterProductOpsByDecomposition |
| 265 | : public OpConversionPattern<vector::MaskOp> { |
| 266 | using OpConversionPattern::OpConversionPattern; |
| 267 | |
| 268 | LogicalResult |
| 269 | matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, |
| 270 | ConversionPatternRewriter &rewriter) const override { |
| 271 | if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>( |
| 272 | Val: maskOp.getMaskableOp())) { |
| 273 | LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), |
| 274 | getContext()); |
| 275 | return static_cast<RewritePattern &>(pattern).matchAndRewrite( |
| 276 | op: outerProductOp, rewriter); |
| 277 | } |
| 278 | return failure(); |
| 279 | } |
| 280 | }; |
| 281 | |
| 282 | /// Legalize `vector.transfer_read` operations to fit within SME tiles by |
| 283 | /// decomposing them into tile-sized operations. |
| 284 | struct LegalizeTransferReadOpsByDecomposition |
| 285 | : public OpConversionPattern<vector::TransferReadOp> { |
| 286 | using OpConversionPattern::OpConversionPattern; |
| 287 | |
| 288 | LogicalResult |
| 289 | matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, |
| 290 | ConversionPatternRewriter &rewriter) const override { |
| 291 | auto vectorType = readOp.getVectorType(); |
| 292 | if (!isMultipleOfSMETileVectorType(vType: vectorType)) |
| 293 | return rewriter.notifyMatchFailure(arg&: readOp, |
| 294 | msg: kMatchFailureNotSMETileTypeMultiple); |
| 295 | |
| 296 | auto mask = readOp.getMask(); |
| 297 | if (!isSupportedMaskOp(mask)) |
| 298 | return rewriter.notifyMatchFailure(arg&: readOp, |
| 299 | msg: kMatchFailureUnsupportedMaskOp); |
| 300 | |
| 301 | auto permutationMap = readOp.getPermutationMap(); |
| 302 | if (!permutationMap.isPermutation()) |
| 303 | return rewriter.notifyMatchFailure(arg&: readOp, |
| 304 | msg: kMatchFailureNonPermutationMap); |
| 305 | |
| 306 | // Note: For 2D vector types the only non-identity permutation is a simple |
| 307 | // transpose [1, 0]. |
| 308 | bool transposed = !permutationMap.isIdentity(); |
| 309 | |
| 310 | auto loc = readOp.getLoc(); |
| 311 | auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType()); |
| 312 | |
| 313 | SmallVector<Value> resultSMETiles; |
| 314 | for (SMESubTile smeTile : |
| 315 | decomposeToSMETiles(builder&: rewriter, type: vectorType, smeTileType, transposeIndices: transposed)) { |
| 316 | auto smeMask = extractSMEMask(builder&: rewriter, loc, mask, smeTile); |
| 317 | auto smeRead = rewriter.create<vector::TransferReadOp>( |
| 318 | location: loc, args&: smeTileType, args: readOp.getBase(), |
| 319 | args: getSMESubTileIndices(builder&: rewriter, loc, indices: readOp.getIndices(), smeTile), |
| 320 | args: readOp.getPermutationMapAttr(), args: readOp.getPadding(), args&: smeMask, |
| 321 | args: readOp.getInBoundsAttr()); |
| 322 | resultSMETiles.push_back(Elt: smeRead); |
| 323 | } |
| 324 | |
| 325 | rewriter.replaceOpWithMultiple(op: readOp, newValues: {resultSMETiles}); |
| 326 | return success(); |
| 327 | } |
| 328 | }; |
| 329 | |
| 330 | /// Legalize `vector.transfer_write` operations to fit within SME tiles by |
| 331 | /// decomposing them into tile-sized operations. |
| 332 | struct LegalizeTransferWriteOpsByDecomposition |
| 333 | : public OpConversionPattern<vector::TransferWriteOp> { |
| 334 | using OpConversionPattern::OpConversionPattern; |
| 335 | |
| 336 | LogicalResult |
| 337 | matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, |
| 338 | ConversionPatternRewriter &rewriter) const override { |
| 339 | auto vectorType = writeOp.getVectorType(); |
| 340 | if (!isMultipleOfSMETileVectorType(vType: vectorType)) |
| 341 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 342 | msg: kMatchFailureNotSMETileTypeMultiple); |
| 343 | |
| 344 | auto mask = writeOp.getMask(); |
| 345 | if (!isSupportedMaskOp(mask)) |
| 346 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 347 | msg: kMatchFailureUnsupportedMaskOp); |
| 348 | |
| 349 | auto permutationMap = writeOp.getPermutationMap(); |
| 350 | if (!permutationMap.isPermutation()) |
| 351 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 352 | msg: kMatchFailureNonPermutationMap); |
| 353 | |
| 354 | // Note: For 2D vector types the only non-identity permutation is a simple |
| 355 | // transpose [1, 0]. |
| 356 | bool transposed = !permutationMap.isIdentity(); |
| 357 | |
| 358 | auto loc = writeOp.getLoc(); |
| 359 | auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType()); |
| 360 | auto inputSMETiles = adaptor.getValueToStore(); |
| 361 | |
| 362 | Value destTensorOrMemref = writeOp.getBase(); |
| 363 | for (auto [index, smeTile] : llvm::enumerate(First: decomposeToSMETiles( |
| 364 | builder&: rewriter, type: vectorType, smeTileType, transposeIndices: transposed))) { |
| 365 | auto smeMask = extractSMEMask(builder&: rewriter, loc, mask, smeTile); |
| 366 | auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
| 367 | location: loc, args: inputSMETiles[index], args&: destTensorOrMemref, |
| 368 | args: getSMESubTileIndices(builder&: rewriter, loc, indices: writeOp.getIndices(), smeTile), |
| 369 | args: writeOp.getPermutationMapAttr(), args&: smeMask, args: writeOp.getInBoundsAttr()); |
| 370 | if (writeOp.hasPureTensorSemantics()) |
| 371 | destTensorOrMemref = smeWrite.getResult(); |
| 372 | } |
| 373 | |
| 374 | if (writeOp.hasPureTensorSemantics()) |
| 375 | rewriter.replaceOp(op: writeOp, newValues: destTensorOrMemref); |
| 376 | else |
| 377 | rewriter.eraseOp(op: writeOp); |
| 378 | |
| 379 | return success(); |
| 380 | } |
| 381 | }; |
| 382 | |
| 383 | /// Legalize a multi-tile transfer_write as a single store loop. This is done as |
| 384 | /// part of type decomposition as at this level we know each tile write is |
| 385 | /// disjoint, but that information is lost after decomposition (without analysis |
| 386 | /// to reconstruct it). |
| 387 | /// |
| 388 | /// Example (pseudo-MLIR): |
| 389 | /// |
| 390 | /// ``` |
| 391 | /// vector.transfer_write %vector, %dest[%y, %x], %mask |
| 392 | /// : vector<[16]x[8]xi16>, memref<?x?xi16> |
| 393 | /// ``` |
| 394 | /// Is rewritten to: |
| 395 | /// ``` |
| 396 | /// scf.for %slice_idx = %c0 to %c8_vscale step %c1 { |
| 397 | /// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐ |
| 398 | /// : vector<[8]xi1> from vector<[16]x[8]xi1> | |
| 399 | /// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile |
| 400 | /// : vector<[8]xi16> from vector<[8]x[8]xi16> | |
| 401 | /// vector.transfer_write %upper_slice, | |
| 402 | /// %dest[%slice_idx + %y, %x], %upper_slice_mask | |
| 403 | /// : vector<[8]xi16>, memref<?x?xi16> ┘ |
| 404 | /// %lower_slice_idx = %slice_idx + %c8_vscale ─┐ |
| 405 | /// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] | |
| 406 | /// : vector<[8]xi1> from vector<[16]x[8]xi1> | |
| 407 | /// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower |
| 408 | /// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile |
| 409 | /// vector.transfer_write %lower_slice, | |
| 410 | /// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask | |
| 411 | /// : vector<[8]xi16>, memref<?x?xi16> ┘ |
| 412 | /// } |
| 413 | /// ``` |
| 414 | struct LegalizeMultiTileTransferWriteAsStoreLoop |
| 415 | : public OpConversionPattern<vector::TransferWriteOp> { |
| 416 | using OpConversionPattern::OpConversionPattern; |
| 417 | |
| 418 | LogicalResult |
| 419 | matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, |
| 420 | ConversionPatternRewriter &rewriter) const override { |
| 421 | if (writeOp.hasPureTensorSemantics()) |
| 422 | return rewriter.notifyMatchFailure( |
| 423 | arg&: writeOp, msg: "TODO: tensor semantics are unsupported" ); |
| 424 | |
| 425 | auto permutationMap = writeOp.getPermutationMap(); |
| 426 | if (!permutationMap.isPermutation()) |
| 427 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 428 | msg: kMatchFailureNonPermutationMap); |
| 429 | |
| 430 | bool transposed = !permutationMap.isIdentity(); |
| 431 | if (transposed) |
| 432 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 433 | msg: "TODO: transpose unsupported" ); |
| 434 | |
| 435 | auto vectorType = writeOp.getVectorType(); |
| 436 | if (!isMultipleOfSMETileVectorType(vType: vectorType)) |
| 437 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 438 | msg: kMatchFailureNotSMETileTypeMultiple); |
| 439 | |
| 440 | // Note: We also disallow masks where any dimension is > 16 because that |
| 441 | // prevents the masking from being lowered to use arm_sve.psel. |
| 442 | auto mask = writeOp.getMask(); |
| 443 | if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(idx: 0) > 16 || |
| 444 | vectorType.getDimSize(idx: 1) > 16))) |
| 445 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 446 | msg: kMatchFailureUnsupportedMaskOp); |
| 447 | |
| 448 | auto loc = writeOp.getLoc(); |
| 449 | auto createVscaleMultiple = |
| 450 | vector::makeVscaleConstantBuilder(rewriter, loc); |
| 451 | |
| 452 | // Get SME tile and slice types. |
| 453 | auto smeTileType = getSMETileTypeForElement(elementType: vectorType.getElementType()); |
| 454 | auto minTileSlices = smeTileType.getDimSize(idx: 0); |
| 455 | VectorType sliceMaskType = |
| 456 | VectorType::get(shape: minTileSlices, elementType: rewriter.getI1Type(), scalableDims: true); |
| 457 | |
| 458 | // Create loop over all tile slices. |
| 459 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 460 | auto upperBound = createVscaleMultiple(minTileSlices); |
| 461 | auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 462 | auto storeLoop = |
| 463 | rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: upperBound, args&: step); |
| 464 | rewriter.setInsertionPointToStart(storeLoop.getBody()); |
| 465 | |
| 466 | // For each sub-tile of the multi-tile `vectorType`. |
| 467 | auto inputSMETiles = adaptor.getValueToStore(); |
| 468 | auto tileSliceIndex = storeLoop.getInductionVar(); |
| 469 | for (auto [index, smeTile] : llvm::enumerate( |
| 470 | First: decomposeToSMETiles(builder&: rewriter, type: vectorType, smeTileType))) { |
| 471 | // The coordinates of the tile within `vectorType`. |
| 472 | auto tileRow = createVscaleMultiple(smeTile.row); |
| 473 | auto tileCol = createVscaleMultiple(smeTile.col); |
| 474 | |
| 475 | // The current slice of `vectorType` we are processing. |
| 476 | auto sliceIndex = |
| 477 | rewriter.create<arith::AddIOp>(location: loc, args&: tileRow, args&: tileSliceIndex); |
| 478 | |
| 479 | // Where in the destination memref the current slice will be stored. |
| 480 | auto storeRow = rewriter.create<arith::AddIOp>(location: loc, args&: sliceIndex, |
| 481 | args: writeOp.getIndices()[0]); |
| 482 | auto storeCol = |
| 483 | rewriter.create<arith::AddIOp>(location: loc, args&: tileCol, args: writeOp.getIndices()[1]); |
| 484 | |
| 485 | // Extract the mask for the current slice. |
| 486 | Value sliceMask = nullptr; |
| 487 | if (mask) { |
| 488 | sliceMask = rewriter.create<vector::ExtractOp>( |
| 489 | location: loc, args&: mask, args: OpFoldResult(sliceIndex)); |
| 490 | if (sliceMaskType != sliceMask.getType()) |
| 491 | sliceMask = rewriter.create<vector::ScalableExtractOp>( |
| 492 | location: loc, args&: sliceMaskType, args&: sliceMask, args&: smeTile.col); |
| 493 | } |
| 494 | |
| 495 | // Extract and store the current slice. |
| 496 | Value tile = inputSMETiles[index]; |
| 497 | auto slice = |
| 498 | rewriter.create<vector::ExtractOp>(location: loc, args&: tile, args&: tileSliceIndex); |
| 499 | rewriter.create<vector::TransferWriteOp>( |
| 500 | location: loc, args&: slice, args: writeOp.getBase(), args: ValueRange{storeRow, storeCol}, |
| 501 | args: AffineMapAttr::get(value: writeOp.getPermutationMap().dropResult(pos: 0)), |
| 502 | args&: sliceMask, |
| 503 | args: rewriter.getBoolArrayAttr( |
| 504 | values: ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front())); |
| 505 | } |
| 506 | |
| 507 | rewriter.eraseOp(op: writeOp); |
| 508 | return success(); |
| 509 | } |
| 510 | }; |
| 511 | |
| 512 | //===----------------------------------------------------------------------===// |
| 513 | // ArmSME-specific fixup canonicalizations/folds |
| 514 | //===----------------------------------------------------------------------===// |
| 515 | |
| 516 | /// Folds an extract from a 3D `vector.create_mask` (which is a vector of |
| 517 | /// SME-like masks), into a compare and a 2D `vector.create_mask`. This is |
| 518 | /// necessary for the mask to be lowered to ArmSME. |
| 519 | /// |
| 520 | /// Example: |
| 521 | /// |
| 522 | /// BEFORE: |
| 523 | /// ```mlir |
| 524 | /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1> |
| 525 | /// %subMask = vector.extract %mask[2] |
| 526 | /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> |
| 527 | /// ``` |
| 528 | /// |
| 529 | /// AFTER: |
| 530 | /// ```mlir |
| 531 | /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index |
| 532 | /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index |
| 533 | /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1> |
| 534 | /// ``` |
| 535 | struct |
| 536 | : public OpRewritePattern<vector::ExtractOp> { |
| 537 | using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; |
| 538 | |
| 539 | LogicalResult matchAndRewrite(vector::ExtractOp , |
| 540 | PatternRewriter &rewriter) const override { |
| 541 | auto loc = extractOp.getLoc(); |
| 542 | auto createMaskOp = |
| 543 | extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); |
| 544 | if (!createMaskOp) |
| 545 | return rewriter.notifyMatchFailure( |
| 546 | arg&: extractOp, msg: "extract not from vector.create_mask op" ); |
| 547 | |
| 548 | VectorType = |
| 549 | llvm::dyn_cast<VectorType>(Val: extractOp.getResult().getType()); |
| 550 | if (!extractedMaskType) |
| 551 | return rewriter.notifyMatchFailure(arg&: extractOp, |
| 552 | msg: "extracted type is not a vector type" ); |
| 553 | |
| 554 | auto numScalable = extractedMaskType.getNumScalableDims(); |
| 555 | if (numScalable != 2) |
| 556 | return rewriter.notifyMatchFailure( |
| 557 | arg&: extractOp, msg: "expected extracted type to be an SME-like mask" ); |
| 558 | |
| 559 | // TODO: Support multiple extraction indices. |
| 560 | if (extractOp.getStaticPosition().size() != 1) |
| 561 | return rewriter.notifyMatchFailure( |
| 562 | arg&: extractOp, msg: "only a single extraction index is supported" ); |
| 563 | |
| 564 | auto frontMaskDim = createMaskOp.getOperand(i: 0); |
| 565 | if (frontMaskDim.getDefiningOp<arith::ConstantOp>()) |
| 566 | return rewriter.notifyMatchFailure( |
| 567 | arg&: extractOp, |
| 568 | msg: "constant vector.create_masks dims should be folded elsewhere" ); |
| 569 | |
| 570 | auto zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 571 | auto = getValueOrCreateConstantIndexOp( |
| 572 | b&: rewriter, loc, ofr: extractOp.getMixedPosition()[0]); |
| 573 | auto = rewriter.create<arith::CmpIOp>( |
| 574 | location: loc, args: rewriter.getI1Type(), args: arith::CmpIPredicate::slt, args&: extractionIndex, |
| 575 | args&: frontMaskDim); |
| 576 | auto newMaskFrontDim = rewriter.create<arith::SelectOp>( |
| 577 | location: loc, args&: extractionInTrueRegion, args: createMaskOp.getOperand(i: 1), args&: zero); |
| 578 | |
| 579 | rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( |
| 580 | op: extractOp, args&: extractedMaskType, |
| 581 | args: ValueRange{newMaskFrontDim, createMaskOp.getOperand(i: 2)}); |
| 582 | return success(); |
| 583 | } |
| 584 | }; |
| 585 | |
| 586 | /// A vector type where no fixed dimension comes after a scalable dimension. |
| 587 | bool isLegalVectorType(VectorType vType) { |
| 588 | bool seenFixedDim = false; |
| 589 | for (bool scalableFlag : llvm::reverse(C: vType.getScalableDims())) { |
| 590 | seenFixedDim |= !scalableFlag; |
| 591 | if (seenFixedDim && scalableFlag) |
| 592 | return false; |
| 593 | } |
| 594 | return true; |
| 595 | } |
| 596 | |
| 597 | /// Lifts an illegal vector.transpose and vector.transfer_read to a |
| 598 | /// memref.subview + memref.transpose, followed by a legal read. |
| 599 | /// |
| 600 | /// 'Illegal' here means a leading scalable dimension and a fixed trailing |
| 601 | /// dimension, which has no valid lowering. |
| 602 | /// |
| 603 | /// The memref.transpose is metadata-only transpose that produces a strided |
| 604 | /// memref, which eventually becomes a loop reading individual elements. |
| 605 | /// |
| 606 | /// Example: |
| 607 | /// |
| 608 | /// BEFORE: |
| 609 | /// ```mlir |
| 610 | /// %illegalRead = vector.transfer_read %memref[%a, %b] |
| 611 | /// : memref<?x?xf32>, vector<[8]x4xf32> |
| 612 | /// %legalType = vector.transpose %illegalRead, [1, 0] |
| 613 | /// : vector<[8]x4xf32> to vector<4x[8]xf32> |
| 614 | /// ``` |
| 615 | /// |
| 616 | /// AFTER: |
| 617 | /// ```mlir |
| 618 | /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] |
| 619 | /// : memref<?x?xf32> to memref<?x?xf32> |
| 620 | /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) |
| 621 | /// : memref<?x?xf32> to memref<?x?xf32> |
| 622 | /// %legalType = vector.transfer_read %transpose[%c0, %c0] |
| 623 | /// : memref<?x?xf32>, vector<4x[8]xf32> |
| 624 | /// ``` |
| 625 | struct LiftIllegalVectorTransposeToMemory |
| 626 | : public OpRewritePattern<vector::TransposeOp> { |
| 627 | using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; |
| 628 | |
| 629 | static Value getExtensionSource(Operation *op) { |
| 630 | if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(Val: op)) |
| 631 | return op->getOperand(idx: 0); |
| 632 | return {}; |
| 633 | } |
| 634 | |
| 635 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, |
| 636 | PatternRewriter &rewriter) const override { |
| 637 | auto sourceType = transposeOp.getSourceVectorType(); |
| 638 | auto resultType = transposeOp.getResultVectorType(); |
| 639 | if (isLegalVectorType(vType: sourceType) || !isLegalVectorType(vType: resultType)) |
| 640 | return rewriter.notifyMatchFailure(arg&: transposeOp, |
| 641 | msg: kMatchFailureNotIllegalToLegal); |
| 642 | |
| 643 | // Look through extend for transfer_read. |
| 644 | Value maybeRead = transposeOp.getVector(); |
| 645 | auto *transposeSourceOp = maybeRead.getDefiningOp(); |
| 646 | Operation *extendOp = nullptr; |
| 647 | if (Value extendSource = getExtensionSource(op: transposeSourceOp)) { |
| 648 | maybeRead = extendSource; |
| 649 | extendOp = transposeSourceOp; |
| 650 | } |
| 651 | |
| 652 | auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>(); |
| 653 | if (!illegalRead) |
| 654 | return rewriter.notifyMatchFailure( |
| 655 | arg&: transposeOp, |
| 656 | msg: "expected source to be (possibly extended) transfer_read" ); |
| 657 | |
| 658 | if (!illegalRead.getPermutationMap().isIdentity()) |
| 659 | return rewriter.notifyMatchFailure( |
| 660 | arg&: illegalRead, msg: "expected read to have identity permutation map" ); |
| 661 | |
| 662 | auto loc = transposeOp.getLoc(); |
| 663 | auto zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 664 | auto one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 665 | |
| 666 | // Create a subview that matches the size of the illegal read vector type. |
| 667 | auto readType = illegalRead.getVectorType(); |
| 668 | auto readSizes = llvm::map_to_vector( |
| 669 | C: llvm::zip_equal(t: readType.getShape(), u: readType.getScalableDims()), |
| 670 | F: [&](auto dim) -> Value { |
| 671 | auto [size, isScalable] = dim; |
| 672 | auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size); |
| 673 | if (!isScalable) |
| 674 | return dimSize; |
| 675 | auto vscale = rewriter.create<vector::VectorScaleOp>(location: loc); |
| 676 | return rewriter.create<arith::MulIOp>(loc, vscale, dimSize); |
| 677 | }); |
| 678 | SmallVector<Value> strides(readType.getRank(), Value(one)); |
| 679 | auto readSubview = rewriter.create<memref::SubViewOp>( |
| 680 | location: loc, args: illegalRead.getBase(), args: illegalRead.getIndices(), args&: readSizes, |
| 681 | args&: strides); |
| 682 | |
| 683 | // Apply the transpose to all values/attributes of the transfer_read: |
| 684 | // - The mask |
| 685 | Value mask = illegalRead.getMask(); |
| 686 | if (mask) { |
| 687 | // Note: The transpose for the mask should fold into the |
| 688 | // vector.create_mask/constant_mask op, which will then become legal. |
| 689 | mask = rewriter.create<vector::TransposeOp>(location: loc, args&: mask, |
| 690 | args: transposeOp.getPermutation()); |
| 691 | } |
| 692 | // - The source memref |
| 693 | mlir::AffineMap transposeMap = AffineMap::getPermutationMap( |
| 694 | permutation: transposeOp.getPermutation(), context: getContext()); |
| 695 | auto transposedSubview = rewriter.create<memref::TransposeOp>( |
| 696 | location: loc, args&: readSubview, args: AffineMapAttr::get(value: transposeMap)); |
| 697 | ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); |
| 698 | // - The `in_bounds` attribute |
| 699 | if (inBoundsAttr) { |
| 700 | SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(), |
| 701 | inBoundsAttr.end()); |
| 702 | applyPermutationToVector(inVec&: inBoundsValues, permutation: transposeOp.getPermutation()); |
| 703 | inBoundsAttr = rewriter.getArrayAttr(value: inBoundsValues); |
| 704 | } |
| 705 | |
| 706 | VectorType legalReadType = resultType.clone(elementType: readType.getElementType()); |
| 707 | // Note: The indices are all zero as the subview is already offset. |
| 708 | SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero); |
| 709 | auto legalRead = rewriter.create<vector::TransferReadOp>( |
| 710 | location: loc, args&: legalReadType, args&: transposedSubview, args&: readIndices, |
| 711 | args: illegalRead.getPermutationMapAttr(), args: illegalRead.getPadding(), args&: mask, |
| 712 | args&: inBoundsAttr); |
| 713 | |
| 714 | // Replace the transpose with the new read, extending the result if |
| 715 | // necessary. |
| 716 | rewriter.replaceOp(op: transposeOp, newOp: [&]() -> Operation * { |
| 717 | if (extendOp) |
| 718 | return rewriter.create(loc, opName: extendOp->getName().getIdentifier(), |
| 719 | operands: Value(legalRead), types: resultType); |
| 720 | return legalRead; |
| 721 | }()); |
| 722 | |
| 723 | return success(); |
| 724 | } |
| 725 | }; |
| 726 | |
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
| 730 | /// |
| 731 | /// Example: |
| 732 | /// |
| 733 | /// BEFORE: |
| 734 | /// ```mlir |
| 735 | /// %transpose = vector.transpose %vec, [1, 0] |
| 736 | /// : vector<2x[4]xf32> to vector<[4]x2xf32> |
| 737 | /// vector.transfer_write %transpose, %dest[%y, %x] |
| 738 | /// : vector<[4]x2xf32>, memref<?x?xf32> |
| 739 | /// ``` |
| 740 | /// |
| 741 | /// AFTER: |
| 742 | /// ```mlir |
| 743 | /// %0 = arm_sme.get_tile : vector<[4]x[4]xf32> |
| 744 | /// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32> |
| 745 | /// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32> |
| 746 | /// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32> |
| 747 | /// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32> |
| 748 | /// %c4_vscale = arith.muli %vscale, %c4 : index |
| 749 | /// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1> |
| 750 | /// vector.transfer_write %4, %dest[%y, %x], %mask |
| 751 | /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} |
| 752 | /// : vector<[4]x[4]xf32>, memref<?x?xf32> |
| 753 | /// ``` |
| 754 | /// |
| 755 | /// Values larger than a single tile are supported via decomposition. |
| 756 | struct LowerIllegalTransposeStoreViaZA |
| 757 | : public OpRewritePattern<vector::TransferWriteOp> { |
| 758 | using OpRewritePattern::OpRewritePattern; |
| 759 | |
| 760 | LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, |
| 761 | PatternRewriter &rewriter) const override { |
| 762 | if (!isSupportedMaskOp(mask: writeOp.getMask())) |
| 763 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 764 | msg: kMatchFailureUnsupportedMaskOp); |
| 765 | |
| 766 | auto permutationMap = writeOp.getPermutationMap(); |
| 767 | if (!permutationMap.isIdentity()) |
| 768 | return rewriter.notifyMatchFailure(arg&: writeOp, |
| 769 | msg: kMatchFailureNonPermutationMap); |
| 770 | |
| 771 | auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>(); |
| 772 | if (!transposeOp) |
| 773 | return failure(); |
| 774 | |
| 775 | auto sourceType = transposeOp.getSourceVectorType(); |
| 776 | auto resultType = transposeOp.getResultVectorType(); |
| 777 | |
| 778 | if (resultType.getRank() != 2) |
| 779 | return rewriter.notifyMatchFailure(arg&: transposeOp, msg: "TransposeOp not rank 2" ); |
| 780 | |
| 781 | if (!isLegalVectorType(vType: sourceType) || isLegalVectorType(vType: resultType)) |
| 782 | return rewriter.notifyMatchFailure( |
| 783 | arg&: transposeOp, msg: "not illegal/unsupported SVE transpose" ); |
| 784 | |
| 785 | auto smeTileType = getSMETileTypeForElement(elementType: resultType.getElementType()); |
| 786 | VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(pos: 0); |
| 787 | |
| 788 | if (sourceType.getDimSize(idx: 0) <= 1 || |
| 789 | sourceType.getDimSize(idx: 1) % smeSliceType.getDimSize(idx: 0) != 0) |
| 790 | return rewriter.notifyMatchFailure(arg&: writeOp, msg: "unsupported source shape" ); |
| 791 | |
| 792 | auto loc = writeOp.getLoc(); |
| 793 | auto createVscaleMultiple = |
| 794 | vector::makeVscaleConstantBuilder(rewriter, loc); |
| 795 | |
| 796 | auto transposeMap = AffineMapAttr::get( |
| 797 | value: AffineMap::getPermutationMap(permutation: ArrayRef<int64_t>{1, 0}, context: getContext())); |
| 798 | |
| 799 | // Note: We need to use `get_tile` as there's no vector-level `undef`. |
| 800 | Value undefTile = rewriter.create<arm_sme::GetTileOp>(location: loc, args&: smeTileType); |
| 801 | Value destTensorOrMemref = writeOp.getBase(); |
| 802 | auto numSlicesPerTile = |
| 803 | std::min(a: sourceType.getDimSize(idx: 0), b: smeTileType.getDimSize(idx: 0)); |
| 804 | auto numSlices = |
| 805 | rewriter.create<arith::ConstantIndexOp>(location: loc, args&: numSlicesPerTile); |
| 806 | for (auto [index, smeTile] : llvm::enumerate( |
| 807 | First: decomposeToSMETiles(builder&: rewriter, type: sourceType, smeTileType))) { |
| 808 | // 1. _Deliberately_ drop a scalable dimension and insert a fixed number |
| 809 | // of slices from the source type into the SME tile. Without checking |
| 810 | // vscale (and emitting multiple implementations) we can't make use of the |
| 811 | // rows of the tile after 1*vscale rows. |
| 812 | Value tile = undefTile; |
| 813 | for (int d = 0; d < numSlicesPerTile; ++d) { |
| 814 | Value vector = rewriter.create<vector::ExtractOp>( |
| 815 | location: loc, args: transposeOp.getVector(), |
| 816 | args: rewriter.getIndexAttr(value: d + smeTile.row)); |
| 817 | if (vector.getType() != smeSliceType) { |
| 818 | vector = rewriter.create<vector::ScalableExtractOp>( |
| 819 | location: loc, args&: smeSliceType, args&: vector, args&: smeTile.col); |
| 820 | } |
| 821 | tile = rewriter.create<vector::InsertOp>(location: loc, args&: vector, args&: tile, args&: d); |
| 822 | } |
| 823 | |
| 824 | // 2. Transpose the tile position. |
| 825 | auto transposedRow = createVscaleMultiple(smeTile.col); |
| 826 | auto transposedCol = |
| 827 | rewriter.create<arith::ConstantIndexOp>(location: loc, args&: smeTile.row); |
| 828 | |
| 829 | // 3. Compute mask for tile store. |
| 830 | Value maskRows; |
| 831 | Value maskCols; |
| 832 | if (auto mask = writeOp.getMask()) { |
| 833 | auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); |
| 834 | maskRows = rewriter.create<arith::SubIOp>(location: loc, args: createMask.getOperand(i: 0), |
| 835 | args&: transposedRow); |
| 836 | maskCols = rewriter.create<arith::SubIOp>(location: loc, args: createMask.getOperand(i: 1), |
| 837 | args&: transposedCol); |
| 838 | maskCols = rewriter.create<index::MinSOp>(location: loc, args&: maskCols, args&: numSlices); |
| 839 | } else { |
| 840 | maskRows = createVscaleMultiple(smeTileType.getDimSize(idx: 0)); |
| 841 | maskCols = numSlices; |
| 842 | } |
| 843 | auto subMask = rewriter.create<vector::CreateMaskOp>( |
| 844 | location: loc, args: smeTileType.clone(elementType: rewriter.getI1Type()), |
| 845 | args: ValueRange{maskRows, maskCols}); |
| 846 | |
| 847 | // 4. Emit a transposed tile write. |
| 848 | auto writeIndices = writeOp.getIndices(); |
| 849 | Value destRow = |
| 850 | rewriter.create<arith::AddIOp>(location: loc, args&: transposedRow, args: writeIndices[0]); |
| 851 | Value destCol = |
| 852 | rewriter.create<arith::AddIOp>(location: loc, args&: transposedCol, args: writeIndices[1]); |
| 853 | auto smeWrite = rewriter.create<vector::TransferWriteOp>( |
| 854 | location: loc, args&: tile, args&: destTensorOrMemref, args: ValueRange{destRow, destCol}, |
| 855 | args&: transposeMap, args&: subMask, args: writeOp.getInBounds()); |
| 856 | |
| 857 | if (writeOp.hasPureTensorSemantics()) |
| 858 | destTensorOrMemref = smeWrite.getResult(); |
| 859 | } |
| 860 | |
| 861 | if (writeOp.hasPureTensorSemantics()) |
| 862 | rewriter.replaceOp(op: writeOp, newValues: destTensorOrMemref); |
| 863 | else |
| 864 | rewriter.eraseOp(op: writeOp); |
| 865 | |
| 866 | return success(); |
| 867 | } |
| 868 | }; |
| 869 | |
| 870 | /// Lower `vector.transfer_read` of a scalable column to `scf::for` |
| 871 | /// |
/// Lowers a "read" of a scalable column from a MemRef, for which there is no
/// suitable hardware operation, to a loop over the rows that loads one
/// element at a time.
| 875 | /// |
| 876 | /// BEFORE: |
| 877 | /// ``` |
| 878 | /// %res = vector.transfer_read %mem[%a, %b] (...) |
| 879 | /// : memref<?x?xf32>, vector<[4]x1xf32> |
| 880 | /// ``` |
| 881 | /// |
| 882 | /// AFTER: |
| 883 | /// ``` |
| 884 | /// %cst = arith.constant (...) : vector<[4]xf32> |
| 885 | /// %vscale = vector.vscale |
| 886 | /// %c4_vscale = arith.muli %vscale, %c4 : index |
| 887 | /// %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst) |
| 888 | /// -> (vector<[4]xf32>) { |
| 889 | /// |
| 890 | /// %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32> |
| 891 | /// %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32> |
| 892 | /// scf.yield %vec : vector<[4]xf32> |
| 893 | /// } |
| 894 | /// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32> |
| 895 | /// ``` |
| 896 | /// |
| 897 | /// TODO: This transformation isn't specific to SME - move it to the SVE |
| 898 | /// dialect. |
| 899 | /// TODO: Check the in_bounds attribute and generate vector.maskedload if |
| 900 | /// required. |
| 901 | struct LowerColumnTransferReadToLoops |
| 902 | : public OpRewritePattern<vector::TransferReadOp> { |
| 903 | using OpRewritePattern::OpRewritePattern; |
| 904 | |
| 905 | LogicalResult matchAndRewrite(vector::TransferReadOp readOp, |
| 906 | PatternRewriter &rewriter) const override { |
| 907 | // NOTE: This is a fairly low-level transformation, so we shouldn't be |
| 908 | // adding support for Tensors without good rationale. |
| 909 | if (readOp.hasPureTensorSemantics()) |
| 910 | return rewriter.notifyMatchFailure( |
| 911 | arg&: readOp, msg: "Tensor semantics are unsupported (either bufferize or " |
| 912 | "extend this pattern)" ); |
| 913 | |
| 914 | auto resType = readOp.getVectorType(); |
| 915 | |
| 916 | if (resType.getRank() != 2) |
| 917 | return rewriter.notifyMatchFailure(arg&: readOp, |
| 918 | msg: "Only 2D vectors are supported!" ); |
| 919 | |
| 920 | if (resType.getShape()[1] != 1) |
| 921 | return rewriter.notifyMatchFailure( |
| 922 | arg&: readOp, msg: "The trailing output dim is != 1 (not supported ATM)" ); |
| 923 | |
| 924 | if (!resType.getScalableDims()[0] || resType.getScalableDims()[1]) |
| 925 | return rewriter.notifyMatchFailure( |
| 926 | arg&: readOp, msg: "Expected the leading dim to be scalable and the trailing " |
| 927 | "dim to be fixed." ); |
| 928 | |
| 929 | // Create new result type - similar to the original vector with the |
| 930 | // trailing unit dim collapsed. |
| 931 | int64_t numRows = resType.getShape()[0]; |
| 932 | VectorType newResType = VectorType::get(shape: numRows, elementType: resType.getElementType(), |
| 933 | /*scalableDims=*/{true}); |
| 934 | |
| 935 | // Create a loop over all rows and load one element at a time. |
| 936 | auto loc = readOp.getLoc(); |
| 937 | auto lowerBound = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 938 | auto createVscaleMultiple = |
| 939 | vector::makeVscaleConstantBuilder(rewriter, loc); |
| 940 | auto upperBound = createVscaleMultiple(numRows); |
| 941 | auto step = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
| 942 | Value init = rewriter.create<arith::ConstantOp>( |
| 943 | location: loc, args&: newResType, args: DenseElementsAttr::get(type: newResType, value: 0.0f)); |
| 944 | |
| 945 | scf::ForOp loadLoop; |
| 946 | { |
| 947 | OpBuilder::InsertionGuard g(rewriter); |
| 948 | loadLoop = rewriter.create<scf::ForOp>(location: loc, args&: lowerBound, args&: upperBound, args&: step, |
| 949 | args: ValueRange{init}); |
| 950 | rewriter.setInsertionPointToStart(loadLoop.getBody()); |
| 951 | |
| 952 | auto tileSliceIndex = loadLoop.getInductionVar(); |
| 953 | |
| 954 | auto idx0 = rewriter.create<arith::AddIOp>(location: loc, args&: tileSliceIndex, |
| 955 | args: readOp.getIndices()[0]); |
| 956 | auto idx1 = readOp.getIndices()[1]; |
| 957 | |
| 958 | Value scalar = rewriter.create<memref::LoadOp>( |
| 959 | location: loc, args: readOp.getBase(), args: SmallVector<Value>({idx0, idx1})); |
| 960 | |
| 961 | Operation *updateInit = rewriter.create<vector::InsertOp>( |
| 962 | location: loc, args&: scalar, args: loadLoop.getRegionIterArg(index: 0), args&: tileSliceIndex); |
| 963 | |
| 964 | rewriter.create<scf::YieldOp>(location: loc, args: updateInit->getResult(idx: 0)); |
| 965 | } |
| 966 | |
| 967 | // The read operation has been "legalized", but since the original result |
| 968 | // type was a 2D vector, we need to cast before returning the result. This |
| 969 | // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a |
| 970 | // no-op). |
| 971 | auto sc = rewriter.create<vector::ShapeCastOp>( |
| 972 | location: loc, args: readOp.getResult().getType(), args: loadLoop.getResult(i: 0)); |
| 973 | |
| 974 | rewriter.replaceOp(op: readOp, newOp: sc); |
| 975 | |
| 976 | return success(); |
| 977 | } |
| 978 | }; |
| 979 | |
| 980 | struct VectorLegalizationPass |
| 981 | : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { |
| 982 | void runOnOperation() override { |
| 983 | auto *context = &getContext(); |
| 984 | TypeConverter converter; |
| 985 | RewritePatternSet patterns(context); |
| 986 | converter.addConversion(callback: [](Type type) { return type; }); |
| 987 | converter.addConversion( |
| 988 | callback: [](VectorType vectorType, |
| 989 | SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { |
| 990 | if (!isMultipleOfSMETileVectorType(vType: vectorType)) |
| 991 | return std::nullopt; |
| 992 | auto smeTileCount = getNumberOfSMETilesForVectorType(type: vectorType); |
| 993 | auto smeTileType = |
| 994 | getSMETileTypeForElement(elementType: vectorType.getElementType()); |
| 995 | types = SmallVector<Type>(smeTileCount, smeTileType); |
| 996 | return success(); |
| 997 | }); |
| 998 | |
| 999 | // Apply preprocessing patterns. |
| 1000 | RewritePatternSet rewritePatterns(context); |
| 1001 | rewritePatterns |
| 1002 | .add<FoldExtractFromVectorOfSMELikeCreateMasks, |
| 1003 | LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory, |
| 1004 | LowerIllegalTransposeStoreViaZA>(arg&: context); |
| 1005 | if (failed( |
| 1006 | Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(rewritePatterns)))) |
| 1007 | return signalPassFailure(); |
| 1008 | |
| 1009 | // Note: These two patterns are added with a high benefit to ensure: |
| 1010 | // - Masked outer products are handled before unmasked ones |
| 1011 | // - Multi-tile writes are lowered as a store loop (if possible) |
| 1012 | patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition, |
| 1013 | LegalizeMultiTileTransferWriteAsStoreLoop>(arg&: converter, args&: context, |
| 1014 | /*benefit=*/args: 1024); |
| 1015 | patterns.add<LegalizeArithConstantOpsByDecomposition, |
| 1016 | LegalizeVectorOuterProductOpsByDecomposition, |
| 1017 | LegalizeTransferReadOpsByDecomposition, |
| 1018 | LegalizeTransferWriteOpsByDecomposition>(arg&: converter, args&: context); |
| 1019 | populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
| 1020 | converter); |
| 1021 | populateCallOpTypeConversionPattern(patterns, converter); |
| 1022 | populateReturnOpTypeConversionPattern(patterns, converter); |
| 1023 | scf::populateSCFStructuralTypeConversions(typeConverter: converter, patterns); |
| 1024 | |
| 1025 | ConversionTarget target(getContext()); |
| 1026 | target.markUnknownOpDynamicallyLegal( |
| 1027 | fn: [&](Operation *op) { return converter.isLegal(op); }); |
| 1028 | target.addDynamicallyLegalOp<func::FuncOp>(callback: [&](func::FuncOp op) { |
| 1029 | return converter.isSignatureLegal(ty: op.getFunctionType()); |
| 1030 | }); |
| 1031 | if (failed(Result: applyPartialConversion(op: getOperation(), target, |
| 1032 | patterns: std::move(patterns)))) |
| 1033 | return signalPassFailure(); |
| 1034 | } |
| 1035 | }; |
| 1036 | |
| 1037 | } // namespace |
| 1038 | |
| 1039 | std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() { |
| 1040 | return std::make_unique<VectorLegalizationPass>(); |
| 1041 | } |
| 1042 | |