| 1 | //===- Vectorization.cpp - Implementation of linalg Vectorization ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements the linalg dialect Vectorization transformations. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | #include "mlir/Dialect/Affine/Utils.h" |
| 13 | |
| 14 | #include "mlir/Analysis/SliceAnalysis.h" |
| 15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 18 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 19 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 20 | #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| 21 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 22 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 23 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| 24 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 25 | #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" |
| 26 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 27 | #include "mlir/IR/AffineExpr.h" |
| 28 | #include "mlir/IR/Builders.h" |
| 29 | #include "mlir/IR/BuiltinTypeInterfaces.h" |
| 30 | #include "mlir/IR/BuiltinTypes.h" |
| 31 | #include "mlir/IR/OpDefinition.h" |
| 32 | #include "mlir/IR/PatternMatch.h" |
| 33 | #include "mlir/IR/Value.h" |
| 34 | #include "mlir/Support/LLVM.h" |
| 35 | #include "mlir/Transforms/RegionUtils.h" |
| 36 | #include "llvm/ADT/STLExtras.h" |
| 37 | #include "llvm/ADT/Sequence.h" |
| 38 | #include "llvm/ADT/SmallVector.h" |
| 39 | #include "llvm/ADT/TypeSwitch.h" |
| 40 | #include "llvm/Support/Debug.h" |
| 41 | #include "llvm/Support/MathExtras.h" |
| 42 | #include "llvm/Support/raw_ostream.h" |
| 43 | #include <optional> |
| 44 | |
| 45 | using namespace mlir; |
| 46 | using namespace mlir::linalg; |
| 47 | |
| 48 | #define DEBUG_TYPE "linalg-vectorization" |
| 49 | |
| 50 | #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") |
| 51 | #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") |
| 52 | |
| 53 | /// Try to vectorize `convOp` as a convolution. |
| 54 | static FailureOr<Operation *> |
| 55 | vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, |
| 56 | ArrayRef<int64_t> inputVecSizes = {}, |
| 57 | ArrayRef<bool> inputVecScalableFlags = {}, |
| 58 | bool flatten1DDepthwiseConv = false); |
| 59 | |
| 60 | /// Vectorize tensor::InsertSliceOp with: |
| 61 | /// * vector::TransferReadOp + vector::TransferWriteOp |
| 62 | /// The vector sizes are either: |
| 63 | /// * user-provided in `inputVectorSizes`, or |
| 64 | /// * inferred from the static dims in the input and output tensors. |
| 65 | /// Bails out if: |
| 66 | /// * vector sizes are not user-provided, and |
| 67 | /// * at least one dim is dynamic (in both the input and output tensors). |
| 68 | /// |
| 69 | /// Before: |
| 70 | /// !t_in_type = tensor<1x2x3xf32> |
| 71 | /// !t_out_type = tensor<9x8x7x1x2x3xf32> |
| 72 | /// !v_type = vector<1x2x3xf32> |
| 73 | /// %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type |
| 74 | /// into !t_out_type |
| 75 | /// After: |
| 76 | /// %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type |
| 77 | /// %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type |
| 78 | static LogicalResult |
| 79 | vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, |
| 80 | ArrayRef<int64_t> inputVectorSizes, |
| 81 | SmallVectorImpl<Value> &newResults); |
| 82 | |
| 83 | /// Returns the effective Pad value for the input op, provided it's a scalar. |
| 84 | /// |
| 85 | /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If |
| 86 | /// this Op performs padding, retrieve the padding value provided that it's |
| 87 | /// a scalar and static/fixed for all the padded values. Returns an empty value |
| 88 | /// otherwise. |
| 89 | static Value getStaticPadVal(Operation *op); |
| 90 | |
| 91 | /// Return the unique instance of OpType in `block` if it is indeed unique. |
| 92 | /// Return null if none or more than 1 instances exist. |
| 93 | template <typename OpType> |
| 94 | static OpType getSingleOpOfType(Block &block) { |
| 95 | OpType res; |
| 96 | block.walk([&](OpType op) { |
| 97 | if (res) { |
| 98 | res = nullptr; |
| 99 | return WalkResult::interrupt(); |
| 100 | } |
| 101 | res = op; |
| 102 | return WalkResult::advance(); |
| 103 | }); |
| 104 | return res; |
| 105 | } |
| 106 | |
| 107 | /// Helper function to extract the input slices after filter is unrolled along |
| 108 | /// kw. |
| 109 | static SmallVector<Value> |
| 110 | (RewriterBase &rewriter, Location loc, Value input, |
| 111 | int64_t nSize, int64_t wSize, int64_t cSize, |
| 112 | int64_t kwSize, int strideW, int dilationW, |
| 113 | int64_t wSizeStep, bool isSingleChanneled) { |
| 114 | SmallVector<Value> result; |
| 115 | if (isSingleChanneled) { |
| 116 | // Extract input slice of size {wSizeStep} @ [w + kw] for non-channeled |
| 117 | // convolution. |
| 118 | SmallVector<int64_t> sizes = {wSizeStep}; |
| 119 | SmallVector<int64_t> strides = {1}; |
| 120 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 121 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 122 | result.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>( |
| 123 | location: loc, args&: input, /*offsets=*/args: ArrayRef<int64_t>{w + kw}, args&: sizes, args&: strides)); |
| 124 | } |
| 125 | } |
| 126 | } else { |
| 127 | // Extract lhs slice of size {n, wSizeStep, c} @ [0, sw * w + dw * kw, 0] |
| 128 | // for channeled convolution. |
| 129 | SmallVector<int64_t> sizes = {nSize, wSizeStep, cSize}; |
| 130 | SmallVector<int64_t> strides = {1, 1, 1}; |
| 131 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 132 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 133 | result.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>( |
| 134 | location: loc, args&: input, |
| 135 | /*offsets=*/args: ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
| 136 | args&: sizes, args&: strides)); |
| 137 | } |
| 138 | } |
| 139 | } |
| 140 | return result; |
| 141 | } |
| 142 | |
| 143 | /// Helper function to extract the filter slices after filter is unrolled along |
| 144 | /// kw. |
| 145 | static SmallVector<Value> (RewriterBase &rewriter, |
| 146 | Location loc, Value filter, |
| 147 | int64_t kwSize) { |
| 148 | SmallVector<Value> result; |
| 149 | // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for |
| 150 | // non-chanelled convolution] @ [kw]. |
| 151 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 152 | result.push_back(Elt: rewriter.create<vector::ExtractOp>( |
| 153 | location: loc, args&: filter, /*offsets=*/args: ArrayRef<int64_t>{kw})); |
| 154 | } |
| 155 | return result; |
| 156 | } |
| 157 | |
| 158 | /// Helper function to extract the result slices after filter is unrolled along |
| 159 | /// kw. |
| 160 | static SmallVector<Value> |
| 161 | (RewriterBase &rewriter, Location loc, Value res, |
| 162 | int64_t nSize, int64_t wSize, int64_t fSize, |
| 163 | int64_t wSizeStep, bool isSingleChanneled) { |
| 164 | SmallVector<Value> result; |
| 165 | if (isSingleChanneled) { |
| 166 | // Extract res slice: {wSizeStep} @ [w] for non-channeled convolution. |
| 167 | SmallVector<int64_t> sizes = {wSizeStep}; |
| 168 | SmallVector<int64_t> strides = {1}; |
| 169 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 170 | result.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>( |
| 171 | location: loc, args&: res, /*offsets=*/args: ArrayRef<int64_t>{w}, args&: sizes, args&: strides)); |
| 172 | } |
| 173 | } else { |
| 174 | // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
| 175 | // convolution. |
| 176 | SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize}; |
| 177 | SmallVector<int64_t> strides = {1, 1, 1}; |
| 178 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 179 | result.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>( |
| 180 | location: loc, args&: res, /*offsets=*/args: ArrayRef<int64_t>{0, w, 0}, args&: sizes, args&: strides)); |
| 181 | } |
| 182 | } |
| 183 | return result; |
| 184 | } |
| 185 | |
| 186 | /// Helper function to insert the computed result slices. |
| 187 | static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, |
| 188 | Value res, int64_t wSize, int64_t wSizeStep, |
| 189 | SmallVectorImpl<Value> &resVals, |
| 190 | bool isSingleChanneled) { |
| 191 | |
| 192 | if (isSingleChanneled) { |
| 193 | // Write back res slice: {wSizeStep} @ [w] for non-channeled convolution. |
| 194 | // This does not depend on kw. |
| 195 | SmallVector<int64_t> strides = {1}; |
| 196 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 197 | res = rewriter.create<vector::InsertStridedSliceOp>( |
| 198 | location: loc, args&: resVals[w], args&: res, /*offsets=*/args: ArrayRef<int64_t>{w}, args&: strides); |
| 199 | } |
| 200 | } else { |
| 201 | // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled |
| 202 | // convolution. This does not depend on kw. |
| 203 | SmallVector<int64_t> strides = {1, 1, 1}; |
| 204 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 205 | res = rewriter.create<vector::InsertStridedSliceOp>( |
| 206 | location: loc, args&: resVals[w], args&: res, /*offsets=*/args: ArrayRef<int64_t>{0, w, 0}, |
| 207 | args&: strides); |
| 208 | } |
| 209 | } |
| 210 | return res; |
| 211 | } |
| 212 | |
| 213 | /// Contains the vectorization state and related methods used across the |
| 214 | /// vectorization process of a given operation. |
| 215 | struct VectorizationState { |
| 216 | VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {} |
| 217 | |
| 218 | /// Initializes the vectorization state, including the computation of the |
| 219 | /// canonical vector shape for vectorization. |
| 220 | LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp, |
| 221 | ArrayRef<int64_t> inputVectorSizes, |
| 222 | ArrayRef<bool> inputScalableVecDims); |
| 223 | |
| 224 | /// Returns the canonical vector shape used to vectorize the iteration space. |
| 225 | ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; } |
| 226 | |
| 227 | /// Returns the vector dimensions that are scalable in the canonical vector |
| 228 | /// shape. |
| 229 | ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; } |
| 230 | |
| 231 | /// Returns a vector type of the provided `elementType` with the canonical |
| 232 | /// vector shape and the corresponding fixed/scalable dimensions bit. If |
| 233 | /// `dimPermutation` is provided, the canonical vector dimensions are permuted |
| 234 | /// accordingly. |
| 235 | VectorType getCanonicalVecType( |
| 236 | Type elementType, |
| 237 | std::optional<AffineMap> dimPermutation = std::nullopt) const { |
| 238 | SmallVector<int64_t> vectorShape; |
| 239 | SmallVector<bool> scalableDims; |
| 240 | if (dimPermutation.has_value()) { |
| 241 | vectorShape = |
| 242 | applyPermutationMap<int64_t>(map: *dimPermutation, source: canonicalVecShape); |
| 243 | scalableDims = |
| 244 | applyPermutationMap<bool>(map: *dimPermutation, source: scalableVecDims); |
| 245 | } else { |
| 246 | vectorShape.append(in_start: canonicalVecShape.begin(), in_end: canonicalVecShape.end()); |
| 247 | scalableDims.append(in_start: scalableVecDims.begin(), in_end: scalableVecDims.end()); |
| 248 | } |
| 249 | |
| 250 | return VectorType::get(shape: vectorShape, elementType, scalableDims); |
| 251 | } |
| 252 | |
| 253 | /// Masks an operation with the canonical vector mask if the operation needs |
| 254 | /// masking. Returns the masked operation or the original operation if masking |
| 255 | /// is not needed. If provided, the canonical mask for this operation is |
| 256 | /// permuted using `maybeIndexingMap`. |
| 257 | Operation * |
| 258 | maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
| 259 | std::optional<AffineMap> maybeIndexingMap = std::nullopt); |
| 260 | |
| 261 | private: |
| 262 | /// Initializes the iteration space static sizes using the Linalg op |
| 263 | /// information. This may become more complicated in the future. |
| 264 | void initIterSpaceStaticSizes(LinalgOp linalgOp) { |
| 265 | iterSpaceStaticSizes.append(RHS: linalgOp.getStaticLoopRanges()); |
| 266 | } |
| 267 | |
| 268 | /// Generates 'arith.constant' and 'tensor/memref.dim' operations for |
| 269 | /// all the static and dynamic dimensions of the iteration space to be |
| 270 | /// vectorized and store them in `iterSpaceValueSizes`. |
| 271 | LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
| 272 | LinalgOp linalgOp); |
| 273 | |
| 274 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
| 275 | /// canonical vector iteration space. If `maybeMaskingMap` the mask is |
| 276 | /// permuted using that permutation map. If a new mask is created, it will be |
| 277 | /// cached for future users. |
| 278 | Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask, |
| 279 | LinalgOp linalgOp, |
| 280 | std::optional<AffineMap> maybeMaskingMap); |
| 281 | |
| 282 | /// Check whether this permutation map can be used for masking. At the |
| 283 | /// moment we only make sure that there are no broadcast dimensions, but this |
| 284 | /// might change if indexing maps evolve. |
| 285 | bool isValidMaskingMap(AffineMap maskingMap) { |
| 286 | return maskingMap.getBroadcastDims().size() == 0; |
| 287 | } |
| 288 | |
| 289 | /// Turn the input indexing map into a valid masking map. |
| 290 | /// |
| 291 | /// The input indexing map may contain "zero" results, e.g.: |
| 292 | /// (d0, d1, d2, d3) -> (d2, d1, d0, 0) |
| 293 | /// Applying such maps to canonical vector shapes like this one: |
| 294 | /// (1, 16, 16, 4) |
| 295 | /// would yield an invalid vector shape like this: |
| 296 | /// (16, 16, 1, 0) |
| 297 | /// Instead, drop the broadcasting dims that make no sense for masking perm. |
| 298 | /// maps: |
| 299 | /// (d0, d1, d2, d3) -> (d2, d1, d0) |
| 300 | /// This way, the corresponding vector/mask type will be: |
| 301 | /// vector<16x16x1xty> |
| 302 | /// rather than this invalid Vector type: |
| 303 | /// vector<16x16x1x0xty> |
| 304 | AffineMap getMaskingMapFromIndexingMap(AffineMap &indexingMap) { |
| 305 | return indexingMap.dropZeroResults(); |
| 306 | } |
| 307 | |
| 308 | // Holds the compile-time static sizes of the iteration space to vectorize. |
| 309 | // Dynamic dimensions are represented using ShapedType::kDynamic. |
| 310 | SmallVector<int64_t> iterSpaceStaticSizes; |
| 311 | |
| 312 | /// Holds the value sizes of the iteration space to vectorize. Static |
| 313 | /// dimensions are represented by 'arith.constant' and dynamic |
| 314 | /// dimensions by 'tensor/memref.dim'. |
| 315 | SmallVector<Value> iterSpaceValueSizes; |
| 316 | |
| 317 | /// Holds the canonical vector shape used to vectorize the iteration space. |
| 318 | SmallVector<int64_t> canonicalVecShape; |
| 319 | |
| 320 | /// Holds the vector dimensions that are scalable in the canonical vector |
| 321 | /// shape. |
| 322 | SmallVector<bool> scalableVecDims; |
| 323 | |
| 324 | /// Holds the active masks for permutations of the canonical vector iteration |
| 325 | /// space. |
| 326 | DenseMap<AffineMap, Value> activeMaskCache; |
| 327 | |
| 328 | /// Global vectorization guard for the incoming rewriter. It's initialized |
| 329 | /// when the vectorization state is initialized. |
| 330 | OpBuilder::InsertionGuard rewriterGuard; |
| 331 | }; |
| 332 | |
| 333 | LogicalResult |
| 334 | VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, |
| 335 | LinalgOp linalgOp) { |
| 336 | // TODO: Support 0-d vectors. |
| 337 | for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { |
| 338 | if (ShapedType::isStatic(dValue: iterSpaceStaticSizes[vecDim])) { |
| 339 | // Create constant index op for static dimensions. |
| 340 | iterSpaceValueSizes.push_back(Elt: rewriter.create<arith::ConstantIndexOp>( |
| 341 | location: linalgOp.getLoc(), args&: iterSpaceStaticSizes[vecDim])); |
| 342 | continue; |
| 343 | } |
| 344 | |
| 345 | // Find an operand defined on this dimension of the iteration space to |
| 346 | // extract the runtime dimension size. |
| 347 | Value operand; |
| 348 | unsigned operandDimPos; |
| 349 | if (failed(Result: linalgOp.mapIterationSpaceDimToOperandDim(dimPos: vecDim, operand, |
| 350 | operandDimPos))) |
| 351 | return failure(); |
| 352 | |
| 353 | Value dynamicDim = linalgOp.hasPureTensorSemantics() |
| 354 | ? (Value)rewriter.create<tensor::DimOp>( |
| 355 | location: linalgOp.getLoc(), args&: operand, args&: operandDimPos) |
| 356 | : (Value)rewriter.create<memref::DimOp>( |
| 357 | location: linalgOp.getLoc(), args&: operand, args&: operandDimPos); |
| 358 | iterSpaceValueSizes.push_back(Elt: dynamicDim); |
| 359 | } |
| 360 | |
| 361 | return success(); |
| 362 | } |
| 363 | |
| 364 | /// Initializes the vectorization state, including the computation of the |
| 365 | /// canonical vector shape for vectorization. |
| 366 | // TODO: Move this to the constructor when we can remove the failure cases. |
| 367 | LogicalResult |
| 368 | VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp, |
| 369 | ArrayRef<int64_t> inputVectorSizes, |
| 370 | ArrayRef<bool> inputScalableVecDims) { |
| 371 | // Initialize the insertion point. |
| 372 | rewriter.setInsertionPoint(linalgOp); |
| 373 | |
| 374 | if (!inputVectorSizes.empty()) { |
| 375 | // Get the canonical vector shape from the input vector sizes provided. This |
| 376 | // path should be taken to vectorize code with dynamic shapes and when using |
| 377 | // vector sizes greater than the iteration space sizes. |
| 378 | canonicalVecShape.append(in_start: inputVectorSizes.begin(), in_end: inputVectorSizes.end()); |
| 379 | scalableVecDims.append(in_start: inputScalableVecDims.begin(), |
| 380 | in_end: inputScalableVecDims.end()); |
| 381 | } else { |
| 382 | // Compute the canonical vector shape from the operation shape. If there are |
| 383 | // dynamic shapes, the operation won't be vectorized. We assume all the |
| 384 | // vector dimensions are fixed. |
| 385 | canonicalVecShape = linalgOp.getStaticLoopRanges(); |
| 386 | scalableVecDims.append(NumInputs: linalgOp.getNumLoops(), Elt: false); |
| 387 | } |
| 388 | |
| 389 | LDBG("Canonical vector shape: " ); |
| 390 | LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); |
| 391 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 392 | LDBG("Scalable vector dims: " ); |
| 393 | LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); |
| 394 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 395 | |
| 396 | if (ShapedType::isDynamicShape(dSizes: canonicalVecShape)) |
| 397 | return failure(); |
| 398 | |
| 399 | // Initialize iteration space static sizes. |
| 400 | initIterSpaceStaticSizes(linalgOp); |
| 401 | |
| 402 | // Generate 'arith.constant' and 'tensor/memref.dim' operations for |
| 403 | // all the static and dynamic dimensions of the iteration space, needed to |
| 404 | // compute a mask during vectorization. |
| 405 | if (failed(Result: precomputeIterSpaceValueSizes(rewriter, linalgOp))) |
| 406 | return failure(); |
| 407 | |
| 408 | return success(); |
| 409 | } |
| 410 | |
| 411 | /// Create or retrieve an existing mask value to mask `opToMask` in the |
| 412 | /// canonical vector iteration space. If `maybeMaskingMap` the mask is permuted |
| 413 | /// using that permutation map. If a new mask is created, it will be cached for |
| 414 | /// future users. |
| 415 | Value VectorizationState::getOrCreateMaskFor( |
| 416 | RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, |
| 417 | std::optional<AffineMap> maybeMaskingMap) { |
| 418 | |
| 419 | assert((!maybeMaskingMap || isValidMaskingMap(*maybeMaskingMap)) && |
| 420 | "Ill-formed masking map." ); |
| 421 | |
| 422 | // No mask is needed if the operation is not maskable. |
| 423 | auto maskableOp = dyn_cast<vector::MaskableOpInterface>(Val: opToMask); |
| 424 | if (!maskableOp) |
| 425 | return Value(); |
| 426 | |
| 427 | assert(!maskableOp.isMasked() && |
| 428 | "Masking an operation that is already masked" ); |
| 429 | |
| 430 | // If no masking map was provided, use an identity map with the loop dims. |
| 431 | assert((!maybeMaskingMap || *maybeMaskingMap) && |
| 432 | "Unexpected null mask permutation map" ); |
| 433 | AffineMap maskingMap = |
| 434 | maybeMaskingMap ? *maybeMaskingMap |
| 435 | : AffineMap::getMultiDimIdentityMap( |
| 436 | numDims: linalgOp.getNumLoops(), context: rewriter.getContext()); |
| 437 | |
| 438 | LDBG("Masking map: " << maskingMap << "\n" ); |
| 439 | |
| 440 | // Return the active mask for the masking map of this operation if it was |
| 441 | // already created. |
| 442 | auto activeMaskIt = activeMaskCache.find(Val: maskingMap); |
| 443 | if (activeMaskIt != activeMaskCache.end()) { |
| 444 | Value mask = activeMaskIt->second; |
| 445 | LDBG("Reusing mask: " << mask << "\n" ); |
| 446 | return mask; |
| 447 | } |
| 448 | |
| 449 | // Compute permuted projection of the iteration space to be masked and the |
| 450 | // corresponding mask shape. If the resulting iteration space dimensions are |
| 451 | // static and identical to the mask shape, masking is not needed for this |
| 452 | // operation. |
| 453 | // TODO: Improve this check. Only projected permutation indexing maps are |
| 454 | // supported. |
| 455 | SmallVector<int64_t> permutedStaticSizes = |
| 456 | applyPermutationMap<int64_t>(map: maskingMap, source: iterSpaceStaticSizes); |
| 457 | auto maskType = getCanonicalVecType(elementType: rewriter.getI1Type(), dimPermutation: maskingMap); |
| 458 | auto maskShape = maskType.getShape(); |
| 459 | |
| 460 | LDBG("Mask shape: " ); |
| 461 | LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); |
| 462 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 463 | |
| 464 | if (permutedStaticSizes == maskShape) { |
| 465 | LDBG("Masking is not needed for masking map: " << maskingMap << "\n" ); |
| 466 | activeMaskCache[maskingMap] = Value(); |
| 467 | return Value(); |
| 468 | } |
| 469 | |
| 470 | // Permute the iteration space value sizes to compute the mask upper bounds. |
| 471 | SmallVector<Value> upperBounds = |
| 472 | applyPermutationMap(map: maskingMap, source: ArrayRef<Value>(iterSpaceValueSizes)); |
| 473 | assert(!maskShape.empty() && !upperBounds.empty() && |
| 474 | "Masked 0-d vectors are not supported yet" ); |
| 475 | |
| 476 | // Create the mask based on the dimension values. |
| 477 | Value mask = rewriter.create<vector::CreateMaskOp>(location: linalgOp.getLoc(), |
| 478 | args&: maskType, args&: upperBounds); |
| 479 | LDBG("Creating new mask: " << mask << "\n" ); |
| 480 | activeMaskCache[maskingMap] = mask; |
| 481 | return mask; |
| 482 | } |
| 483 | |
| 484 | Operation * |
| 485 | VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, |
| 486 | LinalgOp linalgOp, |
| 487 | std::optional<AffineMap> maybeIndexingMap) { |
| 488 | LDBG("Trying to mask: " << *opToMask << "\n" ); |
| 489 | |
| 490 | std::optional<AffineMap> maybeMaskingMap = std::nullopt; |
| 491 | if (maybeIndexingMap) |
| 492 | maybeMaskingMap = getMaskingMapFromIndexingMap(indexingMap&: *maybeIndexingMap); |
| 493 | |
| 494 | // Create or retrieve mask for this operation. |
| 495 | Value mask = |
| 496 | getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap); |
| 497 | |
| 498 | if (!mask) { |
| 499 | LDBG("No mask required\n" ); |
| 500 | return opToMask; |
| 501 | } |
| 502 | |
| 503 | // Wrap the operation with a new `vector.mask` and update D-U chain. |
| 504 | assert(opToMask && "Expected a valid operation to mask" ); |
| 505 | auto maskOp = cast<vector::MaskOp>( |
| 506 | Val: mlir::vector::maskOperation(builder&: rewriter, maskableOp: opToMask, mask)); |
| 507 | Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); |
| 508 | |
| 509 | for (auto [resIdx, resVal] : llvm::enumerate(First: opToMask->getResults())) |
| 510 | rewriter.replaceAllUsesExcept(from: resVal, to: maskOp.getResult(i: resIdx), |
| 511 | exceptedUser: maskOpTerminator); |
| 512 | |
| 513 | LDBG("Masked operation: " << *maskOp << "\n" ); |
| 514 | return maskOp; |
| 515 | } |
| 516 | |
| 517 | /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a |
| 518 | /// projectedPermutation, compress the unused dimensions to serve as a |
| 519 | /// permutation_map for a vector transfer operation. |
| 520 | /// For example, given a linalg op such as: |
| 521 | /// |
| 522 | /// ``` |
| 523 | /// %0 = linalg.generic { |
| 524 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>, |
| 525 | /// indexing_maps = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3)> |
| 526 | /// } |
| 527 | /// ins(%0 : tensor<2x3x4xf32>) |
| 528 | /// outs(%1 : tensor<5x6xf32>) |
| 529 | /// ``` |
| 530 | /// |
| 531 | /// the iteration domain size of the linalg op is 3x5x4x6x2. The first affine |
| 532 | /// map is reindexed to `affine_map<(d0, d1, d2) -> (d2, d0, d1)>`, the second |
| 533 | /// affine map is reindexed to `affine_map<(d0, d1) -> (d0, d1)>`. |
| 534 | static AffineMap reindexIndexingMap(AffineMap map) { |
| 535 | assert(map.isProjectedPermutation(/*allowZeroInResults=*/true) && |
| 536 | "expected projected permutation" ); |
| 537 | auto res = compressUnusedDims(map); |
| 538 | assert(res.getNumDims() == |
| 539 | (res.getNumResults() - res.getNumOfZeroResults()) && |
| 540 | "expected reindexed map with same number of dims and results" ); |
| 541 | return res; |
| 542 | } |
| 543 | |
/// Describes the order in which a 1-D convolution traverses its input.
enum class Conv1DOpOrder {
  /// Non-channeled 1-D convolution operation.
  W,
  /// Input traversed in (n, c, w) order.
  Ncw,
  /// Input traversed in (n, w, c) order.
  Nwc
};
| 550 | |
/// Outcome of vectorizing a single operation through a vectorization hook.
/// In certain specific cases, such as terminators, the vectorized result is
/// deliberately not propagated as a replacement.
enum VectorizationHookStatus {
  /// The op failed to vectorize.
  Failure = 0,
  /// The op was vectorized and the custom hook took care of replacement.
  NoReplace = 1,
  /// The op was vectorized into a new op whose results will replace the
  /// original op's results.
  NewOp = 2
  // TODO: support values if Op vectorized to Many-Ops whose results we need to
  // aggregate for replacement.
};
| 565 | /// VectorizationHookResult contains the vectorized op returned from a |
| 566 | /// CustomVectorizationHook. This is an internal implementation detail of |
| 567 | /// linalg vectorization, not to be confused with VectorizationResult. |
| 568 | struct VectorizationHookResult { |
| 569 | /// Return status from vectorizing the current op. |
| 570 | enum VectorizationHookStatus status = VectorizationHookStatus::Failure; |
| 571 | /// New vectorized operation to replace the current op. |
| 572 | /// Replacement behavior is specified by `status`. |
| 573 | Operation *newOp; |
| 574 | }; |
| 575 | |
| 576 | std::optional<vector::CombiningKind> |
| 577 | mlir::linalg::getCombinerOpKind(Operation *combinerOp) { |
| 578 | using ::mlir::vector::CombiningKind; |
| 579 | |
| 580 | if (!combinerOp) |
| 581 | return std::nullopt; |
| 582 | return llvm::TypeSwitch<Operation *, std::optional<CombiningKind>>(combinerOp) |
| 583 | .Case<arith::AddIOp, arith::AddFOp>( |
| 584 | caseFn: [&](auto op) { return CombiningKind::ADD; }) |
| 585 | .Case<arith::AndIOp>(caseFn: [&](auto op) { return CombiningKind::AND; }) |
| 586 | .Case<arith::MaxSIOp>(caseFn: [&](auto op) { return CombiningKind::MAXSI; }) |
| 587 | .Case<arith::MaxUIOp>(caseFn: [&](auto op) { return CombiningKind::MAXUI; }) |
| 588 | .Case<arith::MaximumFOp>(caseFn: [&](auto op) { return CombiningKind::MAXIMUMF; }) |
| 589 | .Case<arith::MaxNumFOp>(caseFn: [&](auto op) { return CombiningKind::MAXNUMF; }) |
| 590 | .Case<arith::MinSIOp>(caseFn: [&](auto op) { return CombiningKind::MINSI; }) |
| 591 | .Case<arith::MinUIOp>(caseFn: [&](auto op) { return CombiningKind::MINUI; }) |
| 592 | .Case<arith::MinimumFOp>(caseFn: [&](auto op) { return CombiningKind::MINIMUMF; }) |
| 593 | .Case<arith::MinNumFOp>(caseFn: [&](auto op) { return CombiningKind::MINNUMF; }) |
| 594 | .Case<arith::MulIOp, arith::MulFOp>( |
| 595 | caseFn: [&](auto op) { return CombiningKind::MUL; }) |
| 596 | .Case<arith::OrIOp>(caseFn: [&](auto op) { return CombiningKind::OR; }) |
| 597 | .Case<arith::XOrIOp>(caseFn: [&](auto op) { return CombiningKind::XOR; }) |
| 598 | .Default(defaultFn: [&](auto op) { return std::nullopt; }); |
| 599 | } |
| 600 | |
| 601 | /// Check whether `outputOperand` is a reduction with a single combiner |
| 602 | /// operation. Return the combiner operation of the reduction. Return |
| 603 | /// nullptr otherwise. Multiple reduction operations would impose an |
| 604 | /// ordering between reduction dimensions and is currently unsupported in |
| 605 | /// Linalg. This limitation is motivated by the fact that e.g. min(max(X)) != |
| 606 | /// max(min(X)) |
| 607 | // TODO: use in LinalgOp verification, there is a circular dependency atm. |
| 608 | static Operation *matchLinalgReduction(OpOperand *outputOperand) { |
| 609 | auto linalgOp = cast<LinalgOp>(Val: outputOperand->getOwner()); |
| 610 | unsigned outputPos = |
| 611 | outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs(); |
| 612 | // Only single combiner operations are supported for now. |
| 613 | SmallVector<Operation *, 4> combinerOps; |
| 614 | if (!matchReduction(iterCarriedArgs: linalgOp.getRegionOutputArgs(), redPos: outputPos, combinerOps) || |
| 615 | combinerOps.size() != 1) |
| 616 | return nullptr; |
| 617 | |
| 618 | // Return the combiner operation. |
| 619 | return combinerOps[0]; |
| 620 | } |
| 621 | |
| 622 | /// Broadcast `value` to a vector of `shape` if possible. Return value |
| 623 | /// otherwise. |
| 624 | static Value broadcastIfNeeded(OpBuilder &b, Value value, Type dstType) { |
| 625 | auto dstVecType = dyn_cast<VectorType>(Val&: dstType); |
| 626 | // If no shape to broadcast to, just return `value`. |
| 627 | if (dstVecType.getRank() == 0) |
| 628 | return value; |
| 629 | if (vector::isBroadcastableTo(srcType: value.getType(), dstVectorType: dstVecType) != |
| 630 | vector::BroadcastableToResult::Success) |
| 631 | return value; |
| 632 | Location loc = b.getInsertionPoint()->getLoc(); |
| 633 | return b.createOrFold<vector::BroadcastOp>(location: loc, args&: dstVecType, args&: value); |
| 634 | } |
| 635 | |
| 636 | /// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This |
| 637 | /// assumes that `reductionOp` has two operands and one of them is the reduction |
| 638 | /// initial value.buildMultiDimReduce |
| 639 | // Note: this is a true builder that notifies the OpBuilder listener. |
| 640 | // TODO: Consider moving as a static helper on the ReduceOp. |
| 641 | static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, |
| 642 | Value valueToReduce, Value acc, |
| 643 | ArrayRef<bool> dimsToMask) { |
| 644 | auto maybeKind = getCombinerOpKind(combinerOp: reduceOp); |
| 645 | assert(maybeKind && "Failed precondition: could not get reduction kind" ); |
| 646 | return b.create<vector::MultiDimReductionOp>( |
| 647 | location: reduceOp->getLoc(), args&: valueToReduce, args&: acc, args&: dimsToMask, args&: *maybeKind); |
| 648 | } |
| 649 | |
| 650 | static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) { |
| 651 | return llvm::to_vector( |
| 652 | Range: llvm::map_range(C: linalgOp.getIteratorTypesArray(), F: isReductionIterator)); |
| 653 | } |
| 654 | |
| 655 | /// Check if `op` is a linalg.reduce or a linalg.generic that has at least one |
| 656 | /// reduction iterator. |
| 657 | static bool hasReductionIterator(LinalgOp &op) { |
| 658 | return isa<linalg::ReduceOp>(Val: op) || |
| 659 | (isa<linalg::GenericOp>(Val: op) && |
| 660 | llvm::any_of(Range: op.getIteratorTypesArray(), P: isReductionIterator)); |
| 661 | } |
| 662 | |
| 663 | /// Build a vector.transfer_write of `value` into `outputOperand` at indices set |
| 664 | /// to all `0`; where `outputOperand` is an output operand of the LinalgOp |
| 665 | /// currently being vectorized. If `dest` has null rank, build an memref.store. |
| 666 | /// Return the produced value or null if no value is produced. |
| 667 | // Note: this is a true builder that notifies the OpBuilder listener. |
| 668 | // TODO: Consider moving as a static helper on the ReduceOp. |
| 669 | static Value buildVectorWrite(RewriterBase &rewriter, Value value, |
| 670 | OpOperand *outputOperand, |
| 671 | VectorizationState &state) { |
| 672 | Location loc = value.getLoc(); |
| 673 | auto linalgOp = cast<LinalgOp>(Val: outputOperand->getOwner()); |
| 674 | AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(opOperand: outputOperand); |
| 675 | |
| 676 | // Compute the vector type of the value to store. This type should be an |
| 677 | // identity or projection of the canonical vector type without any permutation |
| 678 | // applied, given that any permutation in a transfer write happens as part of |
| 679 | // the write itself. |
| 680 | AffineMap vectorTypeMap = AffineMap::getFilteredIdentityMap( |
| 681 | ctx: opOperandMap.getContext(), numDims: opOperandMap.getNumInputs(), |
| 682 | keepDimFilter: [&](AffineDimExpr dimExpr) -> bool { |
| 683 | return llvm::is_contained(Range: opOperandMap.getResults(), Element: dimExpr); |
| 684 | }); |
| 685 | auto vectorType = state.getCanonicalVecType( |
| 686 | elementType: getElementTypeOrSelf(type: outputOperand->get().getType()), dimPermutation: vectorTypeMap); |
| 687 | |
| 688 | Operation *write; |
| 689 | if (vectorType.getRank() > 0) { |
| 690 | AffineMap writeMap = inversePermutation(map: reindexIndexingMap(map: opOperandMap)); |
| 691 | SmallVector<Value> indices(linalgOp.getRank(opOperand: outputOperand), |
| 692 | rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
| 693 | value = broadcastIfNeeded(b&: rewriter, value, dstType: vectorType); |
| 694 | assert(value.getType() == vectorType && "Incorrect type" ); |
| 695 | write = rewriter.create<vector::TransferWriteOp>( |
| 696 | location: loc, args&: value, args: outputOperand->get(), args&: indices, args&: writeMap); |
| 697 | } else { |
| 698 | // 0-d case is still special: do not invert the reindexing writeMap. |
| 699 | if (!isa<VectorType>(Val: value.getType())) |
| 700 | value = rewriter.create<vector::BroadcastOp>(location: loc, args&: vectorType, args&: value); |
| 701 | assert(value.getType() == vectorType && "Incorrect type" ); |
| 702 | write = rewriter.create<vector::TransferWriteOp>( |
| 703 | location: loc, args&: value, args: outputOperand->get(), args: ValueRange{}); |
| 704 | } |
| 705 | |
| 706 | write = state.maskOperation(rewriter, opToMask: write, linalgOp, maybeIndexingMap: opOperandMap); |
| 707 | |
| 708 | // If masked, set in-bounds to true. Masking guarantees that the access will |
| 709 | // be in-bounds. |
| 710 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(Val: write)) { |
| 711 | auto maskedWriteOp = cast<vector::TransferWriteOp>(Val: maskOp.getMaskableOp()); |
| 712 | SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true); |
| 713 | maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(values: inBounds)); |
| 714 | } |
| 715 | |
| 716 | LDBG("vectorized op: " << *write << "\n" ); |
| 717 | if (!write->getResults().empty()) |
| 718 | return write->getResult(idx: 0); |
| 719 | return Value(); |
| 720 | } |
| 721 | |
| 722 | // Custom vectorization precondition function type. This is intented to be used |
| 723 | // with CustomVectorizationHook. Returns success if the corresponding custom |
| 724 | // hook can vectorize the op. |
| 725 | using CustomVectorizationPrecondition = |
| 726 | std::function<LogicalResult(Operation *, bool)>; |
| 727 | |
| 728 | // Custom vectorization function type. Produce a vector form of Operation* |
| 729 | // assuming all its vectorized operands are already in the IRMapping. |
| 730 | // Return nullptr if the Operation cannot be vectorized. |
| 731 | using CustomVectorizationHook = |
| 732 | std::function<VectorizationHookResult(Operation *, const IRMapping &)>; |
| 733 | |
| 734 | /// Helper function to vectorize the terminator of a `linalgOp`. New result |
| 735 | /// vector values are appended to `newResults`. Return |
| 736 | /// VectorizationHookStatus::NoReplace to signal the vectorization algorithm |
| 737 | /// that it should not try to map produced operations and instead return the |
| 738 | /// results using the `newResults` vector making them available to the |
| 739 | /// vectorization algorithm for RAUW. This function is meant to be used as a |
| 740 | /// CustomVectorizationHook. |
| 741 | static VectorizationHookResult |
| 742 | vectorizeLinalgYield(RewriterBase &rewriter, Operation *op, |
| 743 | const IRMapping &bvm, VectorizationState &state, |
| 744 | LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) { |
| 745 | auto yieldOp = dyn_cast<linalg::YieldOp>(Val: op); |
| 746 | if (!yieldOp) |
| 747 | return VectorizationHookResult{.status: VectorizationHookStatus::Failure, .newOp: nullptr}; |
| 748 | for (const auto &output : llvm::enumerate(First: yieldOp.getValues())) { |
| 749 | // TODO: Scan for an opportunity for reuse. |
| 750 | // TODO: use a map. |
| 751 | Value vectorValue = bvm.lookup(from: output.value()); |
| 752 | Value newResult = |
| 753 | buildVectorWrite(rewriter, value: vectorValue, |
| 754 | outputOperand: linalgOp.getDpsInitOperand(i: output.index()), state); |
| 755 | if (newResult) |
| 756 | newResults.push_back(Elt: newResult); |
| 757 | } |
| 758 | |
| 759 | return VectorizationHookResult{.status: VectorizationHookStatus::NoReplace, .newOp: nullptr}; |
| 760 | } |
| 761 | |
| 762 | /// Helper function to vectorize the index operations of a `linalgOp`. Return |
| 763 | /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it |
| 764 | /// should map the produced operations. This function is meant to be used as a |
| 765 | /// CustomVectorizationHook. |
| 766 | static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, |
| 767 | VectorizationState &state, |
| 768 | Operation *op, |
| 769 | LinalgOp linalgOp) { |
| 770 | IndexOp indexOp = dyn_cast<linalg::IndexOp>(Val: op); |
| 771 | if (!indexOp) |
| 772 | return VectorizationHookResult{.status: VectorizationHookStatus::Failure, .newOp: nullptr}; |
| 773 | auto loc = indexOp.getLoc(); |
| 774 | // Compute the static loop sizes of the index op. |
| 775 | ArrayRef<int64_t> targetShape = state.getCanonicalVecShape(); |
| 776 | auto dim = indexOp.getDim(); |
| 777 | // Compute a one-dimensional index vector for the index op dimension. |
| 778 | auto indexVectorType = |
| 779 | VectorType::get(shape: {targetShape[dim]}, elementType: rewriter.getIndexType(), |
| 780 | scalableDims: state.getScalableVecDims()[dim]); |
| 781 | auto indexSteps = rewriter.create<vector::StepOp>(location: loc, args&: indexVectorType); |
| 782 | // Return the one-dimensional index vector if it lives in the trailing |
| 783 | // dimension of the iteration space since the vectorization algorithm in this |
| 784 | // case can handle the broadcast. |
| 785 | if (dim == targetShape.size() - 1) |
| 786 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, .newOp: indexSteps}; |
| 787 | // Otherwise permute the targetShape to move the index dimension last, |
| 788 | // broadcast the one-dimensional index vector to the permuted shape, and |
| 789 | // finally transpose the broadcasted index vector to undo the permutation. |
| 790 | auto permPattern = |
| 791 | llvm::to_vector(Range: llvm::seq<unsigned>(Begin: 0, End: targetShape.size())); |
| 792 | std::swap(a&: permPattern[dim], b&: permPattern.back()); |
| 793 | auto permMap = |
| 794 | AffineMap::getPermutationMap(permutation: permPattern, context: linalgOp.getContext()); |
| 795 | |
| 796 | auto broadCastOp = rewriter.create<vector::BroadcastOp>( |
| 797 | location: loc, args: state.getCanonicalVecType(elementType: rewriter.getIndexType(), dimPermutation: permMap), |
| 798 | args&: indexSteps); |
| 799 | SmallVector<int64_t> transposition = |
| 800 | llvm::to_vector<16>(Range: llvm::seq<int64_t>(Begin: 0, End: linalgOp.getNumLoops())); |
| 801 | std::swap(a&: transposition.back(), b&: transposition[dim]); |
| 802 | auto transposeOp = |
| 803 | rewriter.create<vector::TransposeOp>(location: loc, args&: broadCastOp, args&: transposition); |
| 804 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, .newOp: transposeOp}; |
| 805 | } |
| 806 | |
| 807 | /// Helper function to check if the tensor.extract can be vectorized by the |
| 808 | /// custom hook vectorizeTensorExtract. |
| 809 | static LogicalResult |
| 810 | (Operation *op, bool ) { |
| 811 | tensor::ExtractOp = dyn_cast<tensor::ExtractOp>(Val: op); |
| 812 | if (!extractOp) |
| 813 | return failure(); |
| 814 | |
| 815 | if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract) |
| 816 | return failure(); |
| 817 | |
| 818 | // Check the index type, but only for non 0-d tensors (for which we do need |
| 819 | // access indices). |
| 820 | if (not extractOp.getIndices().empty()) { |
| 821 | if (!VectorType::isValidElementType(t: extractOp.getIndices()[0].getType())) |
| 822 | return failure(); |
| 823 | } |
| 824 | |
| 825 | if (!llvm::all_of(Range: extractOp->getResultTypes(), |
| 826 | P: VectorType::isValidElementType)) { |
| 827 | return failure(); |
| 828 | } |
| 829 | |
| 830 | return success(); |
| 831 | } |
| 832 | |
| 833 | /// Calculates the offsets (`$index_vec`) for `vector.gather` operations |
| 834 | /// generated from `tensor.extract`. The offset is calculated as follows |
| 835 | /// (example using scalar values): |
| 836 | /// |
| 837 | /// offset = extractOp.indices[0] |
| 838 | /// for (i = 1; i < numIndices; i++) |
| 839 | /// offset = extractOp.dimSize[i] * offset + extractOp.indices[i]; |
| 840 | /// |
| 841 | /// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to: |
| 842 | /// offset = ( ( 1 ) * 80 + 2 ) * 15 + 3 |
| 843 | static Value (RewriterBase &rewriter, |
| 844 | VectorizationState &state, |
| 845 | tensor::ExtractOp , |
| 846 | const IRMapping &bvm) { |
| 847 | // The vector of indices for GatherOp should be shaped as the output vector. |
| 848 | auto indexVecType = state.getCanonicalVecType(elementType: rewriter.getIndexType()); |
| 849 | auto loc = extractOp.getLoc(); |
| 850 | |
| 851 | Value offset = broadcastIfNeeded( |
| 852 | b&: rewriter, value: bvm.lookup(from: extractOp.getIndices()[0]), dstType: indexVecType); |
| 853 | |
| 854 | const size_t numIndices = extractOp.getIndices().size(); |
| 855 | for (size_t i = 1; i < numIndices; i++) { |
| 856 | Value dimIdx = rewriter.create<arith::ConstantIndexOp>(location: loc, args&: i); |
| 857 | |
| 858 | auto dimSize = broadcastIfNeeded( |
| 859 | b&: rewriter, |
| 860 | value: rewriter.create<tensor::DimOp>(location: loc, args: extractOp.getTensor(), args&: dimIdx), |
| 861 | dstType: indexVecType); |
| 862 | |
| 863 | offset = rewriter.create<arith::MulIOp>(location: loc, args&: offset, args&: dimSize); |
| 864 | |
| 865 | auto = broadcastIfNeeded( |
| 866 | b&: rewriter, value: bvm.lookup(from: extractOp.getIndices()[i]), dstType: indexVecType); |
| 867 | |
| 868 | offset = rewriter.create<arith::AddIOp>(location: loc, args&: extractOpIndex, args&: offset); |
| 869 | } |
| 870 | |
| 871 | return offset; |
| 872 | } |
| 873 | |
/// Memory access pattern used for the vectorized form of a `tensor.extract`.
enum VectorMemoryAccessKind {
  /// A single element is read and broadcast (lowered to vector.transfer_read).
  ScalarBroadcast,
  /// Consecutive elements are read (lowered to vector.transfer_read).
  Contiguous,
  /// Anything else (lowered to vector.gather).
  Gather
};
| 875 | |
| 876 | /// Find the index of the trailing non-unit dim in linalgOp. This hook is used |
| 877 | /// when checking whether `tensor.extract` Op (within a `linalg.generic` Op) |
| 878 | /// represents a contiguous load operation. |
| 879 | /// |
| 880 | /// Note that when calling this hook, it is assumed that the output vector is |
| 881 | /// effectively 1D. Other cases (i.e. reading n-D vectors) should've been |
| 882 | /// labelled as a gather load before entering this method. |
| 883 | /// |
| 884 | /// Following on from the above, it is assumed that: |
| 885 | /// * for statically shaped loops, when no masks are used, only one dim is != |
| 886 | /// 1 (that's what the shape of the output vector is based on). |
| 887 | /// * for dynamically shaped loops, there might be more non-unit dims |
| 888 | /// as the output vector type is user-specified. |
| 889 | /// |
| 890 | /// TODO: Statically shaped loops + vector masking |
| 891 | static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) { |
| 892 | SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges(); |
| 893 | assert( |
| 894 | (linalgOp.hasDynamicShape() || |
| 895 | llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) == 1) && |
| 896 | "For statically shaped Linalg Ops, only one " |
| 897 | "non-unit loop dim is expected" ); |
| 898 | assert(loopRanges.size() != 0 && "Empty loops, nothing to analyse." ); |
| 899 | |
| 900 | size_t idx = loopRanges.size() - 1; |
| 901 | for (; idx != 0; idx--) |
| 902 | if (loopRanges[idx] != 1) |
| 903 | break; |
| 904 | |
| 905 | return idx; |
| 906 | } |
| 907 | |
| 908 | /// Checks whether `val` can be used for calculating a loop invariant index. |
| 909 | static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, |
| 910 | VectorType resType) { |
| 911 | |
| 912 | assert(((llvm::count_if(resType.getShape(), |
| 913 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
| 914 | "n-D vectors are not yet supported" ); |
| 915 | |
| 916 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
| 917 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
| 918 | // tricky. Just bail out in the latter case. |
| 919 | // TODO: We could try analysing the corresponding affine map here. |
| 920 | auto *block = linalgOp.getBlock(); |
| 921 | if (isa<BlockArgument>(Val: val)) |
| 922 | return !llvm::is_contained(Range: block->getArguments(), Element: val); |
| 923 | |
| 924 | Operation *defOp = val.getDefiningOp(); |
| 925 | assert(defOp && "This is neither a block argument nor an operation result" ); |
| 926 | |
| 927 | // IndexOp is loop invariant as long as its result remains constant across |
| 928 | // iterations. Note that for dynamic shapes, the corresponding dim will also |
| 929 | // be conservatively treated as != 1. |
| 930 | if (auto indexOp = dyn_cast<linalg::IndexOp>(Val: defOp)) { |
| 931 | return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1; |
| 932 | } |
| 933 | |
| 934 | auto *ancestor = block->findAncestorOpInBlock(op&: *defOp); |
| 935 | |
| 936 | // Values define outside `linalgOp` are loop invariant. |
| 937 | if (!ancestor) |
| 938 | return true; |
| 939 | |
| 940 | // Values defined inside `linalgOp`, which are constant, are loop invariant. |
| 941 | if (isa<arith::ConstantOp>(Val: ancestor)) |
| 942 | return true; |
| 943 | |
| 944 | bool result = true; |
| 945 | for (auto op : ancestor->getOperands()) |
| 946 | result &= isLoopInvariantIdx(linalgOp, val&: op, resType); |
| 947 | |
| 948 | return result; |
| 949 | } |
| 950 | |
| 951 | /// Check whether `val` could be used for calculating the trailing index for a |
| 952 | /// contiguous load operation. |
| 953 | /// |
| 954 | /// There are currently 3 types of values that are allowed here: |
| 955 | /// 1. loop-invariant values, |
| 956 | /// 2. values that increment by 1 with every loop iteration, |
| 957 | /// 3. results of basic arithmetic operations (linear and continuous) |
| 958 | /// involving 1., 2. and 3. |
| 959 | /// This method returns True if indeed only such values are used in calculating |
| 960 | /// `val.` |
| 961 | /// |
| 962 | /// Additionally, the trailing index for a contiguous load operation should |
| 963 | /// increment by 1 with every loop iteration, i.e. be based on: |
| 964 | /// * `linalg.index <dim>` , |
| 965 | /// where <dim> is the trailing non-unit dim of the iteration space (this way, |
| 966 | /// `linalg.index <dim>` increments by 1 with every loop iteration). |
| 967 | /// `foundIndexOp` is updated to `true` when such Op is found. |
| 968 | static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, |
| 969 | bool &foundIndexOp, VectorType resType) { |
| 970 | |
| 971 | assert(((llvm::count_if(resType.getShape(), |
| 972 | [](int64_t dimSize) { return dimSize > 1; }) == 1)) && |
| 973 | "n-D vectors are not yet supported" ); |
| 974 | |
| 975 | // Blocks outside _this_ linalg.generic are effectively loop invariant. |
| 976 | // However, analysing block arguments for _this_ linalg.generic Op is a bit |
| 977 | // tricky. Just bail out in the latter case. |
| 978 | // TODO: We could try analysing the corresponding affine map here. |
| 979 | auto *block = linalgOp.getBlock(); |
| 980 | if (isa<BlockArgument>(Val: val)) |
| 981 | return !llvm::is_contained(Range: block->getArguments(), Element: val); |
| 982 | |
| 983 | Operation *defOp = val.getDefiningOp(); |
| 984 | assert(defOp && "This is neither a block argument nor an operation result" ); |
| 985 | |
| 986 | if (auto indexOp = dyn_cast<linalg::IndexOp>(Val: defOp)) { |
| 987 | auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp); |
| 988 | |
| 989 | foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne); |
| 990 | return true; |
| 991 | } |
| 992 | |
| 993 | auto *ancestor = block->findAncestorOpInBlock(op&: *defOp); |
| 994 | |
| 995 | if (!ancestor) |
| 996 | return false; |
| 997 | |
| 998 | // Conservatively reject Ops that could lead to indices with stride other |
| 999 | // than 1. |
| 1000 | if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(Val: ancestor)) |
| 1001 | return false; |
| 1002 | |
| 1003 | bool result = false; |
| 1004 | for (auto op : ancestor->getOperands()) |
| 1005 | result |= isContiguousLoadIdx(linalgOp, val&: op, foundIndexOp, resType); |
| 1006 | |
| 1007 | return result; |
| 1008 | } |
| 1009 | |
| 1010 | /// Infer the memory access pattern for the input ExtractOp |
| 1011 | /// |
| 1012 | /// Based on the ExtratOp result shape and the access indices, decides whether |
| 1013 | /// this Op corresponds to a contiguous load (including a broadcast of a scalar) |
| 1014 | /// or a gather load. When analysing the ExtractOp indices (to identify |
| 1015 | /// contiguous laods), this method looks for "loop" invariant indices (e.g. |
| 1016 | /// block arguments) and indices that change linearly (e.g. via `linalg.index` |
| 1017 | /// Op). |
| 1018 | /// |
| 1019 | /// Note that it is always safe to use gather load operations for contiguous |
| 1020 | /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume |
| 1021 | /// that `extractOp` is a gather load. |
| 1022 | static VectorMemoryAccessKind |
| 1023 | (tensor::ExtractOp , |
| 1024 | LinalgOp &linalgOp, VectorType resType) { |
| 1025 | |
| 1026 | auto inputShape = cast<ShapedType>(Val: extractOp.getTensor().getType()); |
| 1027 | |
| 1028 | // 0. Is this a 0-D vector? If yes then this is a scalar broadcast. |
| 1029 | if (inputShape.getShape().empty()) |
| 1030 | return VectorMemoryAccessKind::ScalarBroadcast; |
| 1031 | |
| 1032 | // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false |
| 1033 | // otherwise. |
| 1034 | bool isOutput1DVector = |
| 1035 | (llvm::count_if(Range: resType.getShape(), |
| 1036 | P: [](int64_t dimSize) { return dimSize > 1; }) == 1); |
| 1037 | // 1. Assume that it's a gather load when reading non-1D vector. |
| 1038 | if (!isOutput1DVector) |
| 1039 | return VectorMemoryAccessKind::Gather; |
| 1040 | |
| 1041 | bool leadingIdxsLoopInvariant = true; |
| 1042 | |
| 1043 | // 2. Analyze the leading indices of `extractOp`. |
| 1044 | // Look at the way each index is calculated and decide whether it is suitable |
| 1045 | // for a contiguous load, i.e. whether it's loop invariant. If not, it's a |
| 1046 | // gather load. |
| 1047 | auto indices = extractOp.getIndices(); |
| 1048 | auto leadIndices = indices.drop_back(n: 1); |
| 1049 | |
| 1050 | for (auto [i, indexVal] : llvm::enumerate(First&: leadIndices)) { |
| 1051 | if (inputShape.getShape()[i] == 1) |
| 1052 | continue; |
| 1053 | |
| 1054 | leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, val&: indexVal, resType); |
| 1055 | } |
| 1056 | |
| 1057 | if (!leadingIdxsLoopInvariant) { |
| 1058 | LDBG("Found gather load: " << extractOp); |
| 1059 | return VectorMemoryAccessKind::Gather; |
| 1060 | } |
| 1061 | |
| 1062 | // 3. Analyze the trailing index for `extractOp`. |
| 1063 | // At this point we know that the leading indices are loop invariant. This |
| 1064 | // means that is potentially a scalar or a contiguous load. We can decide |
| 1065 | // based on the trailing idx. |
| 1066 | auto = indices.back(); |
| 1067 | |
| 1068 | // 3a. Scalar broadcast load |
| 1069 | // If the trailing index is loop invariant then this is a scalar load. |
| 1070 | if (leadingIdxsLoopInvariant && |
| 1071 | isLoopInvariantIdx(linalgOp, val&: extractOpTrailingIdx, resType)) { |
| 1072 | LDBG("Found scalar broadcast load: " << extractOp); |
| 1073 | |
| 1074 | return VectorMemoryAccessKind::ScalarBroadcast; |
| 1075 | } |
| 1076 | |
| 1077 | // 3b. Contiguous loads |
| 1078 | // The trailing `extractOp` index should increment with every loop iteration. |
| 1079 | // This effectively means that it must be based on the trailing loop index. |
| 1080 | // This is what the following bool captures. |
| 1081 | bool foundIndexOp = false; |
| 1082 | bool isContiguousLoad = isContiguousLoadIdx(linalgOp, val&: extractOpTrailingIdx, |
| 1083 | foundIndexOp, resType); |
| 1084 | // TODO: Support generating contiguous loads for column vectors - that will |
| 1085 | // require adding a permutation map to tranfer_read Ops. |
| 1086 | bool isRowVector = resType.getShape().back() != 1; |
| 1087 | isContiguousLoad &= (foundIndexOp && isRowVector); |
| 1088 | |
| 1089 | if (isContiguousLoad) { |
| 1090 | LDBG("Found contigous load: " << extractOp); |
| 1091 | return VectorMemoryAccessKind::Contiguous; |
| 1092 | } |
| 1093 | |
| 1094 | // 4. Fallback case - gather load. |
| 1095 | LDBG("Found gather load: " << extractOp); |
| 1096 | return VectorMemoryAccessKind::Gather; |
| 1097 | } |
| 1098 | |
| 1099 | /// Helper function to vectorize the tensor.extract operations. Returns |
| 1100 | /// VectorizationHookStatus::NewOp to signal the vectorization algorithm that it |
| 1101 | /// should map the produced operations. This function is meant to be used as a |
| 1102 | /// CustomVectorizationHook. |
| 1103 | static VectorizationHookResult |
| 1104 | (RewriterBase &rewriter, VectorizationState &state, |
| 1105 | Operation *op, LinalgOp linalgOp, const IRMapping &bvm) { |
| 1106 | tensor::ExtractOp = dyn_cast<tensor::ExtractOp>(Val: op); |
| 1107 | if (!extractOp) |
| 1108 | return VectorizationHookResult{.status: VectorizationHookStatus::Failure, .newOp: nullptr}; |
| 1109 | auto loc = extractOp.getLoc(); |
| 1110 | |
| 1111 | // Compute the static loop sizes of the extract op. |
| 1112 | auto resultType = state.getCanonicalVecType(elementType: extractOp.getResult().getType()); |
| 1113 | auto maskConstantOp = rewriter.create<arith::ConstantOp>( |
| 1114 | location: loc, |
| 1115 | args: DenseIntElementsAttr::get(type: state.getCanonicalVecType(elementType: rewriter.getI1Type()), |
| 1116 | /*value=*/arg: true)); |
| 1117 | auto passThruConstantOp = |
| 1118 | rewriter.create<arith::ConstantOp>(location: loc, args: rewriter.getZeroAttr(type: resultType)); |
| 1119 | |
| 1120 | // Base indices are currently set to 0. We will need to re-visit if more |
| 1121 | // generic scenarios are to be supported. |
| 1122 | SmallVector<Value> baseIndices( |
| 1123 | extractOp.getIndices().size(), |
| 1124 | rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
| 1125 | |
| 1126 | VectorMemoryAccessKind memAccessKind = |
| 1127 | getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resType: resultType); |
| 1128 | |
| 1129 | // 1. Handle gather access |
| 1130 | if (memAccessKind == VectorMemoryAccessKind::Gather) { |
| 1131 | Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); |
| 1132 | |
| 1133 | // Generate the gather load |
| 1134 | Operation *gatherOp = rewriter.create<vector::GatherOp>( |
| 1135 | location: loc, args&: resultType, args: extractOp.getTensor(), args&: baseIndices, args&: offset, |
| 1136 | args&: maskConstantOp, args&: passThruConstantOp); |
| 1137 | gatherOp = state.maskOperation(rewriter, opToMask: gatherOp, linalgOp); |
| 1138 | |
| 1139 | LDBG("Vectorised as gather load: " << extractOp << "\n" ); |
| 1140 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, .newOp: gatherOp}; |
| 1141 | } |
| 1142 | |
| 1143 | // 2. Handle: |
| 1144 | // a. scalar loads + broadcast, |
| 1145 | // b. contiguous loads. |
| 1146 | // Both cases use vector.transfer_read. |
| 1147 | |
| 1148 | // Collect indices for `vector.transfer_read`. At this point, the indices will |
| 1149 | // either be scalars or would have been broadcast to vectors matching the |
| 1150 | // result type. For indices that are vectors, there are two options: |
| 1151 | // * for non-trailing indices, all elements are identical (contiguous |
| 1152 | // loads are identified by looking for non-trailing indices that are |
| 1153 | // invariant with respect to the corresponding linalg.generic), or |
| 1154 | // * for trailing indices, the index vector will contain values with stride |
| 1155 | // one, but for `vector.transfer_read` only the first (i.e. 0th) index is |
| 1156 | // needed. |
| 1157 | // This means that |
| 1158 | // * for scalar indices - just re-use it, |
| 1159 | // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom |
| 1160 | // (0th) element and use that. |
| 1161 | SmallVector<Value> transferReadIdxs; |
| 1162 | for (size_t i = 0; i < extractOp.getIndices().size(); i++) { |
| 1163 | Value idx = bvm.lookup(from: extractOp.getIndices()[i]); |
| 1164 | if (idx.getType().isIndex()) { |
| 1165 | transferReadIdxs.push_back(Elt: idx); |
| 1166 | continue; |
| 1167 | } |
| 1168 | |
| 1169 | auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>( |
| 1170 | location: loc, |
| 1171 | args: VectorType::get(shape: resultType.getShape().back(), elementType: rewriter.getIndexType(), |
| 1172 | scalableDims: resultType.getScalableDims().back()), |
| 1173 | args&: idx); |
| 1174 | transferReadIdxs.push_back( |
| 1175 | Elt: rewriter.create<vector::ExtractOp>(location: loc, args&: indexAs1dVector, args: 0)); |
| 1176 | } |
| 1177 | |
| 1178 | // `tensor.extract_element` is always in-bounds, hence the following holds. |
| 1179 | auto dstRank = resultType.getRank(); |
| 1180 | auto srcRank = extractOp.getTensor().getType().getRank(); |
| 1181 | SmallVector<bool> inBounds(dstRank, true); |
| 1182 | |
| 1183 | // 2a. Handle scalar broadcast access. |
| 1184 | if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) { |
| 1185 | MLIRContext *ctx = rewriter.getContext(); |
| 1186 | SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(constant: 0, context: ctx)); |
| 1187 | auto permutationMap = AffineMap::get(dimCount: srcRank, symbolCount: 0, results: exprs, context: ctx); |
| 1188 | |
| 1189 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
| 1190 | location: loc, args&: resultType, args: extractOp.getTensor(), args&: transferReadIdxs, |
| 1191 | /*padding=*/args: std::nullopt, args&: permutationMap, args&: inBounds); |
| 1192 | |
| 1193 | // Mask this broadcasting xfer_read here rather than relying on the generic |
| 1194 | // path (the generic path assumes identity masking map, which wouldn't be |
| 1195 | // valid here). |
| 1196 | SmallVector<int64_t> readMaskShape = {1}; |
| 1197 | auto readMaskType = VectorType::get(shape: readMaskShape, elementType: rewriter.getI1Type()); |
| 1198 | auto allTrue = rewriter.create<vector::ConstantMaskOp>( |
| 1199 | location: loc, args&: readMaskType, args: vector::ConstantMaskKind::AllTrue); |
| 1200 | auto *maskedReadOp = |
| 1201 | mlir::vector::maskOperation(builder&: rewriter, maskableOp: transferReadOp, mask: allTrue); |
| 1202 | |
| 1203 | LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n" ); |
| 1204 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, |
| 1205 | .newOp: maskedReadOp}; |
| 1206 | } |
| 1207 | |
| 1208 | // 2b. Handle contiguous access. |
| 1209 | auto permutationMap = AffineMap::getMinorIdentityMap( |
| 1210 | dims: srcRank, results: std::min(a: dstRank, b: srcRank), context: rewriter.getContext()); |
| 1211 | |
| 1212 | int32_t rankDiff = dstRank - srcRank; |
| 1213 | // When dstRank > srcRank, broadcast the source tensor to the unitary leading |
| 1214 | // dims so that the ranks match. This is done by extending the map with 0s. |
| 1215 | // For example, for dstRank = 3, srcRank = 2, the following map created |
| 1216 | // above: |
| 1217 | // (d0, d1) --> (d0, d1) |
| 1218 | // is extended as: |
| 1219 | // (d0, d1) --> (0, d0, d1) |
| 1220 | while (rankDiff > 0) { |
| 1221 | permutationMap = permutationMap.insertResult( |
| 1222 | expr: mlir::getAffineConstantExpr(constant: 0, context: rewriter.getContext()), pos: 0); |
| 1223 | rankDiff--; |
| 1224 | } |
| 1225 | |
| 1226 | auto transferReadOp = rewriter.create<vector::TransferReadOp>( |
| 1227 | location: loc, args&: resultType, args: extractOp.getTensor(), args&: transferReadIdxs, |
| 1228 | /*padding=*/args: std::nullopt, args&: permutationMap, args&: inBounds); |
| 1229 | |
| 1230 | LDBG("Vectorised as contiguous load: " << extractOp); |
| 1231 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, |
| 1232 | .newOp: transferReadOp}; |
| 1233 | } |
| 1234 | |
| 1235 | /// Emit reduction operations if the shapes of the value to reduce is different |
| 1236 | /// that the result shape. |
| 1237 | // Note: this is a true builder that notifies the OpBuilder listener. |
| 1238 | // TODO: Consider moving as a static helper on the ReduceOp. |
| 1239 | static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, |
| 1240 | Value reduceValue, Value initialValue, |
| 1241 | const IRMapping &bvm) { |
| 1242 | Value reduceVec = bvm.lookup(from: reduceValue); |
| 1243 | Value outputVec = bvm.lookup(from: initialValue); |
| 1244 | auto reduceType = dyn_cast<VectorType>(Val: reduceVec.getType()); |
| 1245 | auto outputType = dyn_cast<VectorType>(Val: outputVec.getType()); |
| 1246 | // Reduce only if needed as the value may already have been reduce for |
| 1247 | // contraction vectorization. |
| 1248 | if (!reduceType || |
| 1249 | (outputType && reduceType.getShape() == outputType.getShape())) |
| 1250 | return nullptr; |
| 1251 | SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp); |
| 1252 | return buildMultiDimReduce(b, reduceOp: op, valueToReduce: reduceVec, acc: outputVec, dimsToMask); |
| 1253 | } |
| 1254 | |
| 1255 | /// Generic vectorization for a single operation `op`, given already vectorized |
| 1256 | /// operands carried by `bvm`. Vectorization occurs as follows: |
| 1257 | /// 1. Try to apply any of the `customVectorizationHooks` and return its |
| 1258 | /// result on success. |
| 1259 | /// 2. Clone any constant in the current scope without vectorization: each |
| 1260 | /// consumer of the constant will later determine the shape to which the |
| 1261 | /// constant needs to be broadcast to. |
| 1262 | /// 3. Fail on any remaining non `ElementwiseMappable` op. It is the purpose |
| 1263 | /// of the `customVectorizationHooks` to cover such cases. |
| 1264 | /// 4. Clone `op` in vector form to a vector of shape prescribed by the first |
| 1265 | /// operand of maximal rank. Other operands have smaller rank and are |
| 1266 | /// broadcast accordingly. It is assumed this broadcast is always legal, |
| 1267 | /// otherwise, it means one of the `customVectorizationHooks` is incorrect. |
| 1268 | /// |
| 1269 | /// This function assumes all operands of `op` have been vectorized and are in |
| 1270 | /// the `bvm` mapping. As a consequence, this function is meant to be called on |
| 1271 | /// a topologically-sorted list of ops. |
| 1272 | /// This function does not update `bvm` but returns a VectorizationHookStatus |
| 1273 | /// that instructs the caller what `bvm` update needs to occur. |
| 1274 | static VectorizationHookResult |
| 1275 | vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, |
| 1276 | LinalgOp linalgOp, Operation *op, const IRMapping &bvm, |
| 1277 | ArrayRef<CustomVectorizationHook> customVectorizationHooks) { |
| 1278 | LDBG("vectorize op " << *op << "\n" ); |
| 1279 | |
| 1280 | // 1. Try to apply any CustomVectorizationHook. |
| 1281 | if (!customVectorizationHooks.empty()) { |
| 1282 | for (auto &customFunc : customVectorizationHooks) { |
| 1283 | VectorizationHookResult result = customFunc(op, bvm); |
| 1284 | if (result.status == VectorizationHookStatus::Failure) |
| 1285 | continue; |
| 1286 | return result; |
| 1287 | } |
| 1288 | } |
| 1289 | |
| 1290 | // 2. Constant ops don't get vectorized but rather broadcasted at their users. |
| 1291 | // Clone so that the constant is not confined to the linalgOp block . |
| 1292 | if (isa<arith::ConstantOp, func::ConstantOp>(Val: op)) |
| 1293 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, |
| 1294 | .newOp: rewriter.clone(op&: *op)}; |
| 1295 | |
| 1296 | // 3. Only ElementwiseMappable are allowed in the generic vectorization. |
| 1297 | if (!OpTrait::hasElementwiseMappableTraits(op)) |
| 1298 | return VectorizationHookResult{.status: VectorizationHookStatus::Failure, .newOp: nullptr}; |
| 1299 | |
| 1300 | // 4 . Check if the operation is a reduction. |
| 1301 | SmallVector<std::pair<Value, Value>> reductionOperands; |
| 1302 | for (Value operand : op->getOperands()) { |
| 1303 | auto blockArg = dyn_cast<BlockArgument>(Val&: operand); |
| 1304 | if (!blockArg || blockArg.getOwner() != linalgOp.getBlock() || |
| 1305 | blockArg.getArgNumber() < linalgOp.getNumDpsInputs()) |
| 1306 | continue; |
| 1307 | SmallVector<Operation *> reductionOps; |
| 1308 | Value reduceValue = matchReduction( |
| 1309 | iterCarriedArgs: linalgOp.getRegionOutputArgs(), |
| 1310 | redPos: blockArg.getArgNumber() - linalgOp.getNumDpsInputs(), combinerOps&: reductionOps); |
| 1311 | if (!reduceValue) |
| 1312 | continue; |
| 1313 | reductionOperands.push_back(Elt: std::make_pair(x&: reduceValue, y&: operand)); |
| 1314 | } |
| 1315 | if (!reductionOperands.empty()) { |
| 1316 | assert(reductionOperands.size() == 1); |
| 1317 | Operation *reduceOp = |
| 1318 | reduceIfNeeded(b&: rewriter, linalgOp, op, reduceValue: reductionOperands[0].first, |
| 1319 | initialValue: reductionOperands[0].second, bvm); |
| 1320 | if (reduceOp) |
| 1321 | return VectorizationHookResult{.status: VectorizationHookStatus::NewOp, .newOp: reduceOp}; |
| 1322 | } |
| 1323 | |
| 1324 | // 5. Generic vectorization path for ElementwiseMappable ops. |
| 1325 | // a. Get the first max ranked shape. |
| 1326 | VectorType firstMaxRankedType; |
| 1327 | for (Value operand : op->getOperands()) { |
| 1328 | auto vecOperand = bvm.lookup(from: operand); |
| 1329 | assert(vecOperand && "Vector operand couldn't be found" ); |
| 1330 | |
| 1331 | auto vecType = dyn_cast<VectorType>(Val: vecOperand.getType()); |
| 1332 | if (vecType && (!firstMaxRankedType || |
| 1333 | firstMaxRankedType.getRank() < vecType.getRank())) |
| 1334 | firstMaxRankedType = vecType; |
| 1335 | } |
| 1336 | // b. Broadcast each op if needed. |
| 1337 | SmallVector<Value> vecOperands; |
| 1338 | for (Value scalarOperand : op->getOperands()) { |
| 1339 | Value vecOperand = bvm.lookup(from: scalarOperand); |
| 1340 | assert(vecOperand && "Vector operand couldn't be found" ); |
| 1341 | |
| 1342 | if (firstMaxRankedType) { |
| 1343 | auto vecType = VectorType::get(shape: firstMaxRankedType.getShape(), |
| 1344 | elementType: getElementTypeOrSelf(type: vecOperand.getType()), |
| 1345 | scalableDims: firstMaxRankedType.getScalableDims()); |
| 1346 | vecOperands.push_back(Elt: broadcastIfNeeded(b&: rewriter, value: vecOperand, dstType: vecType)); |
| 1347 | } else { |
| 1348 | vecOperands.push_back(Elt: vecOperand); |
| 1349 | } |
| 1350 | } |
| 1351 | // c. for elementwise, the result is the vector with the firstMaxRankedShape |
| 1352 | SmallVector<Type> resultTypes; |
| 1353 | for (Type resultType : op->getResultTypes()) { |
| 1354 | resultTypes.push_back( |
| 1355 | Elt: firstMaxRankedType |
| 1356 | ? VectorType::get(shape: firstMaxRankedType.getShape(), elementType: resultType, |
| 1357 | scalableDims: firstMaxRankedType.getScalableDims()) |
| 1358 | : resultType); |
| 1359 | } |
| 1360 | // d. Build and return the new op. |
| 1361 | return VectorizationHookResult{ |
| 1362 | .status: VectorizationHookStatus::NewOp, |
| 1363 | .newOp: rewriter.create(loc: op->getLoc(), opName: op->getName().getIdentifier(), operands: vecOperands, |
| 1364 | types: resultTypes, attributes: op->getAttrs())}; |
| 1365 | } |
| 1366 | |
| 1367 | /// Generic vectorization function that rewrites the body of a `linalgOp` into |
| 1368 | /// vector form. Generic vectorization proceeds as follows: |
| 1369 | /// 1. Verify the `linalgOp` has one non-empty region. |
| 1370 | /// 2. Values defined above the region are mapped to themselves and will be |
| 1371 | /// broadcasted on a per-need basis by their consumers. |
| 1372 | /// 3. Each region argument is vectorized into a vector.transfer_read (or 0-d |
| 1373 | /// load). |
| 1374 | /// TODO: Reuse opportunities for RAR dependencies. |
| 1375 | /// 4a. Register CustomVectorizationHook for YieldOp to capture the results. |
| 1376 | /// 4rewriter. Register CustomVectorizationHook for IndexOp to access the |
| 1377 | /// iteration indices. |
| 1378 | /// 5. Iteratively call vectorizeOneOp on the region operations. |
| 1379 | /// |
| 1380 | /// When `broadcastToMaximalCommonShape` is set to true, eager broadcasting is |
| 1381 | /// performed to the maximal common vector size implied by the `linalgOp` |
| 1382 | /// iteration space. This eager broadcasting is introduced in the |
| 1383 | /// permutation_map of the vector.transfer_read operations. The eager |
| 1384 | /// broadcasting makes it trivial to determine where broadcast, transposes and |
| 1385 | /// reductions should occur, without any bookkeeping. The tradeoff is that, in |
| 1386 | /// the absence of good canonicalizations, the amount of work increases. |
| 1387 | /// This is not deemed a problem as we expect canonicalizations and foldings to |
| 1388 | /// aggressively clean up the useless work. |
| 1389 | static LogicalResult |
| 1390 | vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, |
| 1391 | LinalgOp linalgOp, |
| 1392 | SmallVectorImpl<Value> &newResults) { |
| 1393 | LDBG("Vectorizing operation as linalg generic\n" ); |
| 1394 | Block *block = linalgOp.getBlock(); |
| 1395 | |
| 1396 | // 2. Values defined above the region can only be broadcast for now. Make them |
| 1397 | // map to themselves. |
| 1398 | IRMapping bvm; |
| 1399 | SetVector<Value> valuesSet; |
| 1400 | mlir::getUsedValuesDefinedAbove(regions: linalgOp->getRegion(index: 0), values&: valuesSet); |
| 1401 | bvm.map(from: valuesSet.getArrayRef(), to: valuesSet.getArrayRef()); |
| 1402 | |
| 1403 | if (linalgOp.getNumDpsInits() == 0) |
| 1404 | return failure(); |
| 1405 | |
| 1406 | // 3. Turn all BBArgs into vector.transfer_read / load. |
| 1407 | Location loc = linalgOp.getLoc(); |
| 1408 | Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 1409 | for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { |
| 1410 | BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); |
| 1411 | if (linalgOp.isScalar(opOperand)) { |
| 1412 | bvm.map(from: bbarg, to: opOperand->get()); |
| 1413 | continue; |
| 1414 | } |
| 1415 | |
| 1416 | // 3.a. Convert the indexing map for this input/output to a transfer read |
| 1417 | // permutation map and masking map. |
| 1418 | AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); |
| 1419 | |
| 1420 | AffineMap readMap; |
| 1421 | VectorType readType; |
| 1422 | Type elemType = getElementTypeOrSelf(val: opOperand->get()); |
| 1423 | if (linalgOp.isDpsInput(opOperand)) { |
| 1424 | // 3.a.i. For input reads we use the canonical vector shape. |
| 1425 | readMap = inverseAndBroadcastProjectedPermutation(map: indexingMap); |
| 1426 | readType = state.getCanonicalVecType(elementType: elemType); |
| 1427 | } else { |
| 1428 | // 3.a.ii. For output reads (iteration-carried dependence, e.g., |
| 1429 | // reductions), the vector shape is computed by mapping the canonical |
| 1430 | // vector shape to the output domain and back to the canonical domain. |
| 1431 | readMap = inversePermutation(map: reindexIndexingMap(map: indexingMap)); |
| 1432 | readType = |
| 1433 | state.getCanonicalVecType(elementType: elemType, dimPermutation: readMap.compose(map: indexingMap)); |
| 1434 | } |
| 1435 | |
| 1436 | SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero); |
| 1437 | |
| 1438 | Operation *read = rewriter.create<vector::TransferReadOp>( |
| 1439 | location: loc, args&: readType, args: opOperand->get(), args&: indices, |
| 1440 | /*padding=*/args: std::nullopt, args&: readMap); |
| 1441 | read = state.maskOperation(rewriter, opToMask: read, linalgOp, maybeIndexingMap: indexingMap); |
| 1442 | Value readValue = read->getResult(idx: 0); |
| 1443 | |
| 1444 | // 3.b. If masked, set in-bounds to true. Masking guarantees that the access |
| 1445 | // will be in-bounds. |
| 1446 | if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(Val: read)) { |
| 1447 | SmallVector<bool> inBounds(readType.getRank(), true); |
| 1448 | cast<vector::TransferReadOp>(Val: maskOp.getMaskableOp()) |
| 1449 | .setInBoundsAttr(rewriter.getBoolArrayAttr(values: inBounds)); |
| 1450 | } |
| 1451 | |
| 1452 | // 3.c. Not all ops support 0-d vectors, extract the scalar for now. |
| 1453 | // TODO: remove this. |
| 1454 | if (readType.getRank() == 0) |
| 1455 | readValue = rewriter.create<vector::ExtractOp>(location: loc, args&: readValue, |
| 1456 | args: ArrayRef<int64_t>()); |
| 1457 | |
| 1458 | LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue |
| 1459 | << "\n" ); |
| 1460 | bvm.map(from: bbarg, to: readValue); |
| 1461 | bvm.map(from: opOperand->get(), to: readValue); |
| 1462 | } |
| 1463 | |
| 1464 | SmallVector<CustomVectorizationHook> hooks; |
| 1465 | // 4a. Register CustomVectorizationHook for yieldOp. |
| 1466 | CustomVectorizationHook vectorizeYield = |
| 1467 | [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult { |
| 1468 | return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults); |
| 1469 | }; |
| 1470 | hooks.push_back(Elt: vectorizeYield); |
| 1471 | |
| 1472 | // 4b. Register CustomVectorizationHook for indexOp. |
| 1473 | CustomVectorizationHook vectorizeIndex = |
| 1474 | [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult { |
| 1475 | return vectorizeLinalgIndex(rewriter, state, op, linalgOp); |
| 1476 | }; |
| 1477 | hooks.push_back(Elt: vectorizeIndex); |
| 1478 | |
| 1479 | // 4c. Register CustomVectorizationHook for extractOp. |
| 1480 | CustomVectorizationHook = |
| 1481 | [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult { |
| 1482 | return vectorizeTensorExtract(rewriter, state, op, linalgOp, bvm); |
| 1483 | }; |
| 1484 | hooks.push_back(Elt: vectorizeExtract); |
| 1485 | |
| 1486 | // 5. Iteratively call `vectorizeOneOp` to each op in the slice. |
| 1487 | for (Operation &op : block->getOperations()) { |
| 1488 | VectorizationHookResult result = |
| 1489 | vectorizeOneOp(rewriter, state, linalgOp, op: &op, bvm, customVectorizationHooks: hooks); |
| 1490 | if (result.status == VectorizationHookStatus::Failure) { |
| 1491 | LDBG("failed to vectorize: " << op << "\n" ); |
| 1492 | return failure(); |
| 1493 | } |
| 1494 | if (result.status == VectorizationHookStatus::NewOp) { |
| 1495 | Operation *maybeMaskedOp = |
| 1496 | state.maskOperation(rewriter, opToMask: result.newOp, linalgOp); |
| 1497 | LDBG("New vector op: " << *maybeMaskedOp << "\n" ); |
| 1498 | bvm.map(from: op.getResults(), to: maybeMaskedOp->getResults()); |
| 1499 | } |
| 1500 | } |
| 1501 | |
| 1502 | return success(); |
| 1503 | } |
| 1504 | |
| 1505 | /// Given a linalg::PackOp, return the `dest` shape before any packing |
| 1506 | /// permutations. |
| 1507 | static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp, |
| 1508 | ArrayRef<int64_t> destShape) { |
| 1509 | return applyPermutation(input: destShape, permutation: linalg::getPackInverseDestPerm(packOp)); |
| 1510 | } |
| 1511 | |
| 1512 | /// Determines whether a mask for xfer_write is trivially "all true" |
| 1513 | /// |
| 1514 | /// Given all the inputs required to generate a mask (mask sizes and shapes), |
| 1515 | /// and an xfer_write operation (write indices and the destination tensor |
| 1516 | /// shape), determines whether the corresponding mask would be trivially |
| 1517 | /// foldable (i.e., trivially "all true"). |
| 1518 | /// |
| 1519 | /// Use this method to avoid generating spurious masks and relaying on |
| 1520 | /// vectorization post-processing to remove them. |
| 1521 | /// |
| 1522 | /// Pre-conditions for a mask to be trivially foldable: |
| 1523 | /// * All involved shapes (mask + destination tensor) are static. |
| 1524 | /// * All write indices are constant. |
| 1525 | /// * All mask sizes are constant (including `arith.constant`). |
| 1526 | /// |
| 1527 | /// If the pre-conditions are met, the method checks for each destination |
| 1528 | /// dimension `d`: |
| 1529 | /// (1) destDimSize[rankDiff + d] <= maskShape[d] |
| 1530 | /// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d] |
| 1531 | /// |
| 1532 | /// rankDiff = rank(dest) - rank(mask). |
| 1533 | /// |
| 1534 | /// This method takes a conservative view: it may return false even if the mask |
| 1535 | /// is technically foldable. |
| 1536 | /// |
| 1537 | /// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape |
| 1538 | /// of the dest tensor): |
| 1539 | /// %c0 = arith.constant 0 : index |
| 1540 | /// %mask = vector.create_mask 5, 1 |
| 1541 | /// vector.mask %mask { |
| 1542 | /// vector.transfer_write %vecToStore_1, %dest{[%c0, %c0] |
| 1543 | /// {in_bounds = [true, true]} |
| 1544 | /// : vector<5x1xi32>, tensor<5x1xi32> |
| 1545 | /// } |
| 1546 | /// |
| 1547 | /// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape, |
| 1548 | /// mask is required to avoid out-of-bounds write): |
| 1549 | /// %c0 = arith.constant 0 : index |
| 1550 | /// %mask = vector.create_mask 5, 1 |
| 1551 | /// vector.mask %mask { |
| 1552 | /// vector.transfer_write %vecToStore_2, %dest[%c0, %c0] |
| 1553 | /// {in_bounds = [true, true]} |
| 1554 | /// : vector<8x1xi32>, tensor<5x1xi32> |
| 1555 | /// } |
| 1556 | /// |
| 1557 | /// TODO: Re-use in createReadOrMaskedRead |
| 1558 | static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes, |
| 1559 | SmallVector<Value> &writeIdxs, |
| 1560 | ArrayRef<int64_t> destShape, |
| 1561 | ArrayRef<int64_t> maskShape) { |
| 1562 | // Masking is unavoidable in the case of dynamic tensors. |
| 1563 | if (ShapedType::isDynamicShape(dSizes: destShape)) |
| 1564 | return false; |
| 1565 | |
| 1566 | // Collect all constant mask sizes. |
| 1567 | SmallVector<int64_t, 4> cstMaskSizes; |
| 1568 | for (auto [i, dimSize] : llvm::enumerate(First&: maskSizes)) { |
| 1569 | if (auto intSize = getConstantIntValue(ofr: dimSize)) { |
| 1570 | cstMaskSizes.push_back(Elt: *intSize); |
| 1571 | } |
| 1572 | } |
| 1573 | |
| 1574 | // If any of the mask sizes is non-constant, bail out. |
| 1575 | if (cstMaskSizes.size() != maskShape.size()) |
| 1576 | return false; |
| 1577 | |
| 1578 | // Collect all constant write indices. |
| 1579 | SmallVector<int64_t, 4> cstWriteIdxs; |
| 1580 | for (auto [i, idx] : llvm::enumerate(First&: writeIdxs)) { |
| 1581 | APSInt intVal; |
| 1582 | if (matchPattern(value: idx, pattern: m_ConstantInt(bind_value: &intVal))) { |
| 1583 | cstWriteIdxs.push_back(Elt: intVal.getSExtValue()); |
| 1584 | } |
| 1585 | } |
| 1586 | |
| 1587 | // If any of the write indices is non-constant, bail out. |
| 1588 | if (cstWriteIdxs.size() != destShape.size()) |
| 1589 | return false; |
| 1590 | |
| 1591 | // Go over all destination dims and check (1) and (2). Take into account that: |
| 1592 | // * The number of mask sizes will match the rank of the vector to store. |
| 1593 | // This could be lower than the rank of the destination tensor. |
| 1594 | // * Mask sizes could be larger than the corresponding mask shape (hence |
| 1595 | // `clamp`). |
| 1596 | // TODO: The 2nd item should be rejected by the verifier. |
| 1597 | int64_t rankDiff = destShape.size() - cstMaskSizes.size(); |
| 1598 | for (auto [i, idx] : llvm::enumerate(First&: cstMaskSizes)) { |
| 1599 | if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] || |
| 1600 | /*(2)*/ destShape[rankDiff + i] < |
| 1601 | (std::clamp(val: cstMaskSizes[i], lo: int64_t(0), hi: maskShape[i]) + |
| 1602 | cstWriteIdxs[i])) |
| 1603 | return false; |
| 1604 | } |
| 1605 | |
| 1606 | return true; |
| 1607 | } |
| 1608 | |
| 1609 | /// Creates an optionally masked TransferWriteOp |
| 1610 | /// |
| 1611 | /// Generates the following operation: |
| 1612 | /// %res = vector.transfer_write %vecToStore into %dest |
| 1613 | /// |
| 1614 | /// If shape(vecToStore) != shape(dest), masking is used to ensure correctness: |
| 1615 | /// |
| 1616 | /// %mask = vector.create_mask(%destShape) : %vecToStoreShape |
| 1617 | /// %res = vector.mask %mask { |
| 1618 | /// vector.transfer_write %vecToStore into %dest |
| 1619 | /// } |
| 1620 | /// |
| 1621 | /// The mask shape is identical to `vecToStore` (with the element type == |
| 1622 | /// i1), and the mask values are based on the shape of the `dest` tensor. |
| 1623 | /// |
| 1624 | /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute |
| 1625 | /// is used instead of masking: |
| 1626 | /// |
| 1627 | /// %write = vector.transfer_write %vecToStore into %dest |
| 1628 | /// in_bounds_flags = (...) |
| 1629 | /// %res = vector.transfer_write %input into %dest |
| 1630 | /// {in_bounds = in_bounds_flags} |
| 1631 | /// |
| 1632 | /// Finally, `writeIndices` specifies the offsets to use. If empty, all indices |
| 1633 | /// are set to 0. |
| 1634 | static Operation * |
| 1635 | createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, |
| 1636 | Value dest, SmallVector<Value> writeIndices = {}, |
| 1637 | bool useInBoundsInsteadOfMasking = false) { |
| 1638 | |
| 1639 | ShapedType destType = cast<ShapedType>(Val: dest.getType()); |
| 1640 | int64_t destRank = destType.getRank(); |
| 1641 | auto destShape = destType.getShape(); |
| 1642 | |
| 1643 | VectorType vecToStoreType = cast<VectorType>(Val: vecToStore.getType()); |
| 1644 | int64_t vecToStoreRank = vecToStoreType.getRank(); |
| 1645 | auto vecToStoreShape = vecToStoreType.getShape(); |
| 1646 | |
| 1647 | // Compute the in_bounds attribute |
| 1648 | SmallVector<bool> inBoundsVal(vecToStoreRank, true); |
| 1649 | if (useInBoundsInsteadOfMasking) { |
| 1650 | // Update the inBounds attribute. |
| 1651 | // FIXME: This computation is too weak - it ignores the write indices. |
| 1652 | for (unsigned i = 0; i < vecToStoreRank; i++) |
| 1653 | inBoundsVal[i] = |
| 1654 | (destShape[destRank - vecToStoreRank + i] >= vecToStoreShape[i]) && |
| 1655 | ShapedType::isStatic(dValue: destShape[destRank - vecToStoreRank + i]); |
| 1656 | } |
| 1657 | |
| 1658 | // If missing, initialize the write indices to 0. |
| 1659 | assert((writeIndices.empty() || |
| 1660 | writeIndices.size() == static_cast<size_t>(destRank)) && |
| 1661 | "Invalid number of write indices!" ); |
| 1662 | if (writeIndices.empty()) { |
| 1663 | auto zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 1664 | writeIndices.assign(NumElts: destRank, Elt: zero); |
| 1665 | } |
| 1666 | |
| 1667 | // Generate the xfer_write Op |
| 1668 | Operation *write = |
| 1669 | builder.create<vector::TransferWriteOp>(location: loc, |
| 1670 | /*vector=*/args&: vecToStore, |
| 1671 | /*source=*/args&: dest, |
| 1672 | /*indices=*/args&: writeIndices, |
| 1673 | /*inBounds=*/args&: inBoundsVal); |
| 1674 | |
| 1675 | // If masking is disabled, exit. |
| 1676 | if (useInBoundsInsteadOfMasking) |
| 1677 | return write; |
| 1678 | |
| 1679 | // Check if masking is needed. If not, exit. |
| 1680 | if (llvm::equal(LRange&: vecToStoreShape, RRange: destShape.take_back(N: vecToStoreRank))) |
| 1681 | return write; |
| 1682 | |
| 1683 | // Compute the mask and mask the write Op. |
| 1684 | auto writeMaskType = VectorType::get(shape: vecToStoreShape, elementType: builder.getI1Type()); |
| 1685 | |
| 1686 | SmallVector<OpFoldResult> destSizes = |
| 1687 | tensor::getMixedSizes(builder, loc, value: dest); |
| 1688 | SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank, |
| 1689 | destSizes.end()); |
| 1690 | |
| 1691 | if (isMaskTriviallyFoldable(maskSizes, writeIdxs&: writeIndices, destShape, |
| 1692 | maskShape: vecToStoreShape)) |
| 1693 | return write; |
| 1694 | |
| 1695 | Value maskForWrite = |
| 1696 | builder.createOrFold<vector::CreateMaskOp>(location: loc, args&: writeMaskType, args&: maskSizes); |
| 1697 | return mlir::vector::maskOperation(builder, maskableOp: write, mask: maskForWrite); |
| 1698 | } |
| 1699 | |
| 1700 | /// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant |
| 1701 | /// padding value and (3) input vector sizes into: |
| 1702 | /// |
| 1703 | /// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds |
| 1704 | /// |
| 1705 | /// As in the following example: |
| 1706 | /// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2] |
| 1707 | /// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32> |
| 1708 | /// |
| 1709 | /// This pack would be vectorized to: |
| 1710 | /// |
| 1711 | /// %load = vector.mask %mask { |
| 1712 | /// vector.transfer_read %arg0[%c0, %c0, %c0], %cst |
| 1713 | /// {in_bounds = [true, true, true]} : |
| 1714 | /// tensor<32x7x16xf32>, vector<32x8x16xf32> |
| 1715 | /// } : vector<32x8x16xi1> -> vector<32x8x16xf32> |
| 1716 | /// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32> |
| 1717 | /// to vector<32x4x2x1x16xf32> |
| 1718 | /// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2] |
| 1719 | /// : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> |
| 1720 | /// %write = vector.transfer_write %transpose, |
| 1721 | /// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0] |
| 1722 | /// {in_bounds = [true, true, true, true, true]} |
| 1723 | /// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> |
| 1724 | /// |
| 1725 | /// If the (3) input vector sizes are not provided, the vector sizes are |
| 1726 | /// determined by the result tensor shape and the `in_bounds` |
| 1727 | /// attribute is used instead of masking to mark out-of-bounds accesses. |
| 1728 | /// |
| 1729 | /// NOTE: The input vector sizes specify the dimensions corresponding to the |
| 1730 | /// outer dimensions of the output tensor. The remaining dimensions are |
| 1731 | /// computed based on, e.g., the static inner tiles. |
| 1732 | /// Supporting dynamic inner tiles will require the user to specify the |
| 1733 | /// missing vector sizes. This is left as a TODO. |
| 1734 | static LogicalResult |
| 1735 | vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, |
| 1736 | ArrayRef<int64_t> inputVectorSizes, |
| 1737 | SmallVectorImpl<Value> &newResults) { |
| 1738 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 1739 | OpBuilder::InsertionGuard g(rewriter); |
| 1740 | rewriter.setInsertionPoint(packOp); |
| 1741 | |
| 1742 | Location loc = packOp.getLoc(); |
| 1743 | auto padValue = packOp.getPaddingValue(); |
| 1744 | if (!padValue) { |
| 1745 | padValue = rewriter.create<arith::ConstantOp>( |
| 1746 | location: loc, args: rewriter.getZeroAttr(type: packOp.getSourceType().getElementType())); |
| 1747 | } |
| 1748 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
| 1749 | LogicalResult status = |
| 1750 | cast<ReifyRankedShapedTypeOpInterface>(Val: packOp.getOperation()) |
| 1751 | .reifyResultShapes(builder&: rewriter, reifiedReturnShapes); |
| 1752 | (void)status; // prevent unused variable warning on non-assert builds. |
| 1753 | assert(succeeded(status) && "failed to reify result shapes" ); |
| 1754 | |
| 1755 | // If the input vector sizes are not provided, then the vector sizes are |
| 1756 | // determined by the result tensor shape. In case the vector sizes aren't |
| 1757 | // provided, we update the inBounds attribute instead of masking. |
| 1758 | bool useInBoundsInsteadOfMasking = false; |
| 1759 | if (inputVectorSizes.empty()) { |
| 1760 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
| 1761 | inputVectorSizes = resultTensorShape.take_front(N: packOp.getSourceRank()); |
| 1762 | useInBoundsInsteadOfMasking = true; |
| 1763 | } |
| 1764 | |
| 1765 | // Create masked TransferReadOp. |
| 1766 | SmallVector<int64_t> inputShape(inputVectorSizes); |
| 1767 | auto innerTiles = packOp.getStaticInnerTiles(); |
| 1768 | auto innerDimsPos = packOp.getInnerDimsPos(); |
| 1769 | auto outerDimsPerm = packOp.getOuterDimsPerm(); |
| 1770 | if (!outerDimsPerm.empty()) |
| 1771 | applyPermutationToVector(inVec&: inputShape, |
| 1772 | permutation: invertPermutationVector(permutation: outerDimsPerm)); |
| 1773 | for (auto [idx, size] : enumerate(First&: innerTiles)) |
| 1774 | inputShape[innerDimsPos[idx]] *= size; |
| 1775 | auto maskedRead = vector::createReadOrMaskedRead( |
| 1776 | builder&: rewriter, loc, source: packOp.getSource(), inputVectorSizes: inputShape, padValue, |
| 1777 | useInBoundsInsteadOfMasking); |
| 1778 | |
| 1779 | // Create ShapeCastOp. |
| 1780 | SmallVector<int64_t> destShape(inputVectorSizes); |
| 1781 | destShape.append(in_start: innerTiles.begin(), in_end: innerTiles.end()); |
| 1782 | auto tiledPackType = VectorType::get(shape: getTiledPackShape(packOp, destShape), |
| 1783 | elementType: packOp.getDestType().getElementType()); |
| 1784 | auto shapeCastOp = |
| 1785 | rewriter.create<vector::ShapeCastOp>(location: loc, args&: tiledPackType, args&: maskedRead); |
| 1786 | |
| 1787 | // Create TransposeOp. |
| 1788 | auto destPermutation = |
| 1789 | invertPermutationVector(permutation: getPackInverseDestPerm(packOp)); |
| 1790 | auto transposeOp = rewriter.create<vector::TransposeOp>( |
| 1791 | location: loc, args: shapeCastOp.getResult(), args&: destPermutation); |
| 1792 | |
| 1793 | // Create TransferWriteOp. |
| 1794 | Value dest = rewriter.create<tensor::EmptyOp>( |
| 1795 | location: loc, args&: reifiedReturnShapes[0], |
| 1796 | args: transposeOp.getResult().getType().getElementType()); |
| 1797 | Operation *write = |
| 1798 | createWriteOrMaskedWrite(builder&: rewriter, loc, vecToStore: transposeOp.getResult(), dest); |
| 1799 | newResults.push_back(Elt: write->getResult(idx: 0)); |
| 1800 | return success(); |
| 1801 | } |
| 1802 | |
| 1803 | /// Vectorize a `linalg::UnPackOp` to these 4 Ops: |
| 1804 | /// Vector::TransferReadOp - Reads a vector from the source tensor |
| 1805 | /// vector::TransposeOp - Transpose the Source tensor |
| 1806 | /// ShapeCastOp - Reshape the data based on the target. |
| 1807 | /// vector::TransferWriteOp. - Write the result vector back to the destination |
| 1808 | /// tensor. |
| 1809 | /// If the vector sizes are not provided: |
| 1810 | /// * the vector sizes are determined by the input operand and attributes, |
| 1811 | /// * update the inBounds attribute instead of masking. |
| 1812 | static LogicalResult |
| 1813 | vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, |
| 1814 | ArrayRef<int64_t> inputVectorSizes, |
| 1815 | SmallVectorImpl<Value> &newResults) { |
| 1816 | |
| 1817 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 1818 | OpBuilder::InsertionGuard g(rewriter); |
| 1819 | rewriter.setInsertionPoint(unpackOp); |
| 1820 | |
| 1821 | RankedTensorType unpackTensorType = unpackOp.getSourceType(); |
| 1822 | |
| 1823 | ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos(); |
| 1824 | ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles(); |
| 1825 | ArrayRef<int64_t> sourceShape = unpackTensorType.getShape(); |
| 1826 | bool useInBoundsInsteadOfMasking = false; |
| 1827 | ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); |
| 1828 | |
| 1829 | auto destSize = unpackOp.getDestRank(); |
| 1830 | |
| 1831 | if (!inputVectorSizes.empty()) |
| 1832 | assert(inputVectorSizes.size() == destSize && |
| 1833 | "Incorrect number of input vector sizes" ); |
| 1834 | |
| 1835 | // vectorSizes is the shape of the vector that will be used to do final |
| 1836 | // write on the destination tensor. It is set like this: Let's say the |
| 1837 | // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. |
| 1838 | // Thus: |
| 1839 | // 1. vectorSizes = sourceShape.take_front(N) |
| 1840 | // 2. if outer_dims_perms is present: do that permutation on vectorSizes. |
| 1841 | // 3. multiply all the locations in vectorSize pointed by innerDimPos by the |
| 1842 | // innerTiles attribute value. |
| 1843 | SmallVector<int64_t> vectorSizes(inputVectorSizes); |
| 1844 | if (vectorSizes.empty()) { |
| 1845 | llvm::append_range(C&: vectorSizes, R: sourceShape.take_front(N: destSize)); |
| 1846 | if (!outerDimsPerm.empty()) |
| 1847 | applyPermutationToVector(inVec&: vectorSizes, permutation: outerDimsPerm); |
| 1848 | for (auto [i, pos] : llvm::enumerate(First&: innerDimPos)) |
| 1849 | vectorSizes[pos] *= innerTiles[i]; |
| 1850 | |
| 1851 | useInBoundsInsteadOfMasking = true; |
| 1852 | } |
| 1853 | |
| 1854 | // readVectorSizes is the size of tensor used to read and apply mask. It is |
| 1855 | // set like this: Let's say the vectorSize (VS) array is size 'N' and |
| 1856 | // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of |
| 1857 | // size M-N |
| 1858 | // Thus: |
| 1859 | // - initially: readVectorSizes = vectorInputSizes |
| 1860 | // - Divide all the readMaskShape locations pointed by innerDimPos |
| 1861 | // by the innerTileSize attribute value. |
| 1862 | // - if outer_dims_perms is present: do that permutation on readVectorSizes. |
| 1863 | // - Append the remaining shape from SS |
| 1864 | // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16> |
| 1865 | // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512, |
| 1866 | // 128] and outer_dims_perm is [1, 0] then read shape is: |
| 1867 | // ReadVectorSizes(initial): [512, 128] |
| 1868 | // Final Value(after innerDim Adjustment): [512/32, 128/16] |
| 1869 | // = [16, 8] |
| 1870 | // After applying outer_dims_perm: [8, 16] |
| 1871 | // After appending the rest of the sourceShape: [8, 16, 32, 16] |
| 1872 | |
| 1873 | SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end()); |
| 1874 | |
| 1875 | for (auto [index, size] : enumerate(First&: innerTiles)) { |
| 1876 | readVectorSizes[innerDimPos[index]] = |
| 1877 | llvm::divideCeil(Numerator: readVectorSizes[innerDimPos[index]], Denominator: size); |
| 1878 | } |
| 1879 | if (!outerDimsPerm.empty()) { |
| 1880 | applyPermutationToVector(inVec&: readVectorSizes, permutation: outerDimsPerm); |
| 1881 | } |
| 1882 | readVectorSizes.append(in_start: sourceShape.begin() + vectorSizes.size(), |
| 1883 | in_end: sourceShape.end()); |
| 1884 | |
| 1885 | ReifiedRankedShapedTypeDims reifiedRetShapes; |
| 1886 | LogicalResult status = |
| 1887 | cast<ReifyRankedShapedTypeOpInterface>(Val: unpackOp.getOperation()) |
| 1888 | .reifyResultShapes(builder&: rewriter, reifiedReturnShapes&: reifiedRetShapes); |
| 1889 | if (status.failed()) { |
| 1890 | LDBG("Unable to reify result shapes of " << unpackOp); |
| 1891 | return failure(); |
| 1892 | } |
| 1893 | Location loc = unpackOp->getLoc(); |
| 1894 | |
| 1895 | auto padValue = rewriter.create<arith::ConstantOp>( |
| 1896 | location: loc, args: rewriter.getZeroAttr(type: unpackOp.getSourceType().getElementType())); |
| 1897 | |
| 1898 | // Read result, mask if necessary. If transferReadOp shape is not equal |
| 1899 | // to shape of source, then a mask is necessary. |
| 1900 | Value readResult = vector::createReadOrMaskedRead( |
| 1901 | builder&: rewriter, loc, source: unpackOp.getSource(), inputVectorSizes: readVectorSizes, padValue, |
| 1902 | /*useInBoundsInsteadOfMasking=*/false); |
| 1903 | |
| 1904 | PackingMetadata packMetadata; |
| 1905 | SmallVector<int64_t> lastDimToInsertPosPerm = |
| 1906 | getUnPackInverseSrcPerm(unpackOp, metadata&: packMetadata); |
| 1907 | ShapedType maskedOpShapedType = cast<ShapedType>(Val: readResult.getType()); |
| 1908 | SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape()); |
| 1909 | mlir::Type stripMineElemType = maskedOpShapedType.getElementType(); |
| 1910 | applyPermutationToVector(inVec&: stripMineShape, permutation: lastDimToInsertPosPerm); |
| 1911 | RankedTensorType stripMineTensorType = |
| 1912 | RankedTensorType::get(shape: stripMineShape, elementType: stripMineElemType); |
| 1913 | // Transpose the appropriate rows to match output. |
| 1914 | vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>( |
| 1915 | location: loc, args&: readResult, args&: lastDimToInsertPosPerm); |
| 1916 | |
| 1917 | // Collapse the vector to the size required by result. |
| 1918 | RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( |
| 1919 | type: stripMineTensorType, reassociation: packMetadata.reassociations); |
| 1920 | mlir::VectorType vecCollapsedType = |
| 1921 | VectorType::get(shape: collapsedType.getShape(), elementType: collapsedType.getElementType()); |
| 1922 | vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>( |
| 1923 | location: loc, args&: vecCollapsedType, args: transposeOp->getResult(idx: 0)); |
| 1924 | |
| 1925 | // writeVectorSizes had to match the shapecast shape for dynamic sizes, |
| 1926 | // otherwise the validator complains that the mask size is invalid. |
| 1927 | SmallVector<int64_t> writeVectorSizes( |
| 1928 | unpackOp.getDestType().hasStaticShape() |
| 1929 | ? vectorSizes |
| 1930 | : shapeCastOp.getResultVectorType().getShape()); |
| 1931 | Value dest = rewriter.create<tensor::EmptyOp>( |
| 1932 | location: loc, args&: reifiedRetShapes[0], |
| 1933 | args: shapeCastOp.getResult().getType().getElementType()); |
| 1934 | Operation *write = createWriteOrMaskedWrite( |
| 1935 | builder&: rewriter, loc, vecToStore: shapeCastOp.getResult(), dest, |
| 1936 | /*writeIndices=*/{}, useInBoundsInsteadOfMasking); |
| 1937 | newResults.push_back(Elt: write->getResult(idx: 0)); |
| 1938 | return success(); |
| 1939 | } |
| 1940 | |
| 1941 | /// Vectorize a `padOp` with (1) static result type, (2) constant padding value |
| 1942 | /// and (3) all-zero lowPad to |
| 1943 | /// `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`. |
| 1944 | static LogicalResult |
| 1945 | vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, |
| 1946 | ArrayRef<int64_t> inputVectorSizes, |
| 1947 | SmallVectorImpl<Value> &newResults) { |
| 1948 | auto padValue = padOp.getConstantPaddingValue(); |
| 1949 | Location loc = padOp.getLoc(); |
| 1950 | |
| 1951 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 1952 | OpBuilder::InsertionGuard g(rewriter); |
| 1953 | rewriter.setInsertionPoint(padOp); |
| 1954 | |
| 1955 | ReifiedRankedShapedTypeDims reifiedReturnShapes; |
| 1956 | LogicalResult status = |
| 1957 | cast<ReifyRankedShapedTypeOpInterface>(Val: padOp.getOperation()) |
| 1958 | .reifyResultShapes(builder&: rewriter, reifiedReturnShapes); |
| 1959 | (void)status; // prevent unused variable warning on non-assert builds |
| 1960 | assert(succeeded(status) && "failed to reify result shapes" ); |
| 1961 | auto maskedRead = vector::createReadOrMaskedRead( |
| 1962 | builder&: rewriter, loc, source: padOp.getSource(), inputVectorSizes, padValue, |
| 1963 | /*useInBoundsInsteadOfMasking=*/false); |
| 1964 | |
| 1965 | // Create Xfer write Op |
| 1966 | Value dest = rewriter.create<tensor::EmptyOp>( |
| 1967 | location: loc, args&: reifiedReturnShapes[0], args: padOp.getResultType().getElementType()); |
| 1968 | Operation *write = createWriteOrMaskedWrite(builder&: rewriter, loc, vecToStore: maskedRead, dest); |
| 1969 | newResults.push_back(Elt: write->getResult(idx: 0)); |
| 1970 | return success(); |
| 1971 | } |
| 1972 | |
| 1973 | // TODO: probably need some extra checks for reduction followed by consumer |
| 1974 | // ops that may not commute (e.g. linear reduction + non-linear instructions). |
| 1975 | static LogicalResult reductionPreconditions(LinalgOp op) { |
| 1976 | if (llvm::none_of(Range: op.getIteratorTypesArray(), P: isReductionIterator)) { |
| 1977 | LDBG("reduction precondition failed: no reduction iterator\n" ); |
| 1978 | return failure(); |
| 1979 | } |
| 1980 | for (OpOperand &opOperand : op.getDpsInitsMutable()) { |
| 1981 | AffineMap indexingMap = op.getMatchingIndexingMap(opOperand: &opOperand); |
| 1982 | if (indexingMap.isPermutation()) |
| 1983 | continue; |
| 1984 | |
| 1985 | Operation *reduceOp = matchLinalgReduction(outputOperand: &opOperand); |
| 1986 | if (!reduceOp || !getCombinerOpKind(combinerOp: reduceOp)) { |
| 1987 | LDBG("reduction precondition failed: reduction detection failed\n" ); |
| 1988 | return failure(); |
| 1989 | } |
| 1990 | } |
| 1991 | return success(); |
| 1992 | } |
| 1993 | |
| 1994 | static LogicalResult |
| 1995 | vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, |
| 1996 | bool flatten1DDepthwiseConv) { |
| 1997 | if (flatten1DDepthwiseConv) { |
| 1998 | LDBG("Vectorization of flattened convs with dynamic shapes is not " |
| 1999 | "supported\n" ); |
| 2000 | return failure(); |
| 2001 | } |
| 2002 | |
| 2003 | if (!isa<linalg::DepthwiseConv1DNwcWcOp>(Val: conv)) { |
| 2004 | LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n" ); |
| 2005 | return failure(); |
| 2006 | } |
| 2007 | |
| 2008 | // Support dynamic shapes in 1D depthwise convolution, but only in the |
| 2009 | // _channel_ dimension. |
| 2010 | Value lhs = conv.getDpsInputOperand(i: 0)->get(); |
| 2011 | ArrayRef<int64_t> lhsShape = cast<ShapedType>(Val: lhs.getType()).getShape(); |
| 2012 | auto shapeWithoutCh = lhsShape.drop_back(N: 1); |
| 2013 | if (ShapedType::isDynamicShape(dSizes: shapeWithoutCh)) { |
| 2014 | LDBG("Dynamically-shaped op vectorization precondition failed: only " |
| 2015 | "channel dim can be dynamic\n" ); |
| 2016 | return failure(); |
| 2017 | } |
| 2018 | |
| 2019 | return success(); |
| 2020 | } |
| 2021 | |
| 2022 | static LogicalResult |
| 2023 | vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, |
| 2024 | bool flatten1DDepthwiseConv) { |
| 2025 | if (isa<ConvolutionOpInterface>(Val: op.getOperation())) |
| 2026 | return vectorizeDynamicConvOpPrecondition(conv: op, flatten1DDepthwiseConv); |
| 2027 | |
| 2028 | if (hasReductionIterator(op)) |
| 2029 | return reductionPreconditions(op); |
| 2030 | |
| 2031 | // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops, |
| 2032 | // linalg.copy ops and ops that implement ContractionOpInterface for now. |
| 2033 | if (!isElementwise(op) && |
| 2034 | !isa<linalg::GenericOp, linalg::CopyOp, linalg::ContractionOpInterface>( |
| 2035 | Val: op.getOperation())) |
| 2036 | return failure(); |
| 2037 | |
| 2038 | LDBG("Dynamically-shaped op meets vectorization pre-conditions\n" ); |
| 2039 | return success(); |
| 2040 | } |
| 2041 | |
| 2042 | /// Need to check if the inner-tiles are static/constant. |
| 2043 | static LogicalResult |
| 2044 | vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, |
| 2045 | ArrayRef<int64_t> inputVectorSizes) { |
| 2046 | |
| 2047 | if (llvm::any_of(Range: unpackOp.getInnerTiles(), P: [](OpFoldResult res) { |
| 2048 | return !getConstantIntValue(ofr: res).has_value(); |
| 2049 | })) { |
| 2050 | LDBG("Inner-tiles must be constant: " << unpackOp << "\n" ); |
| 2051 | return failure(); |
| 2052 | } |
| 2053 | ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape(); |
| 2054 | bool satisfyEmptyCond = inputVectorSizes.empty() && |
| 2055 | unpackOp.getDestType().hasStaticShape() && |
| 2056 | unpackOp.getSourceType().hasStaticShape(); |
| 2057 | if (!satisfyEmptyCond && |
| 2058 | failed(Result: vector::isValidMaskedInputVector(shape: resultShape, inputVectorSizes))) |
| 2059 | return failure(); |
| 2060 | |
| 2061 | return success(); |
| 2062 | } |
| 2063 | |
| 2064 | static LogicalResult |
| 2065 | vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, |
| 2066 | ArrayRef<int64_t> inputVectorSizes) { |
| 2067 | |
| 2068 | TypedValue<RankedTensorType> source = sliceOp.getSource(); |
| 2069 | auto sourceType = source.getType(); |
| 2070 | if (!VectorType::isValidElementType(t: sourceType.getElementType())) |
| 2071 | return failure(); |
| 2072 | |
| 2073 | // Get the pad value. |
| 2074 | // TransferReadOp (which is used to vectorize InsertSliceOp), requires a |
| 2075 | // scalar padding value. Note that: |
| 2076 | // * for in-bounds accesses, |
| 2077 | // the value is actually irrelevant. There are 2 cases in which xfer.read |
| 2078 | // accesses are known to be in-bounds: |
| 2079 | // 1. The source shape is static (output vector sizes would be based on |
| 2080 | // the source shape and hence all memory accesses would be in-bounds), |
| 2081 | // 2. Masking is used, i.e. the output vector sizes are user-provided. In |
| 2082 | // this case it is safe to assume that all memory accesses are in-bounds. |
| 2083 | // |
| 2084 | // When the value is not known and not needed, use 0. Otherwise, bail out. |
| 2085 | Value padValue = getStaticPadVal(op: sliceOp); |
| 2086 | bool isOutOfBoundsRead = |
| 2087 | !sourceType.hasStaticShape() && inputVectorSizes.empty(); |
| 2088 | |
| 2089 | if (!padValue && isOutOfBoundsRead) { |
| 2090 | LDBG("Failed to get a pad value for out-of-bounds read access\n" ); |
| 2091 | return failure(); |
| 2092 | } |
| 2093 | return success(); |
| 2094 | } |
| 2095 | |
namespace {
/// Kind of computation performed by a convolution-like op's payload. See
/// `getConvOperationKind` for how a payload is classified.
enum class ConvOperationKind {
  Conv,
  Pool,
};
} // namespace
| 2099 | |
| 2100 | static bool isCastOfBlockArgument(Operation *op) { |
| 2101 | return isa<CastOpInterface>(Val: op) && op->getNumOperands() == 1 && |
| 2102 | isa<BlockArgument>(Val: op->getOperand(idx: 0)); |
| 2103 | } |
| 2104 | |
| 2105 | // Returns the ConvOperationKind of the op using reduceOp of the generic |
| 2106 | // payload. If it is neither a convolution nor a pooling, it returns |
| 2107 | // std::nullopt. |
| 2108 | // |
| 2109 | // If (region has 2 ops (reduction + yield) or 3 ops (extension + reduction |
| 2110 | // + yield) and rhs is not used) then it is the body of a pooling |
| 2111 | // If conv, check for single `mul` predecessor. The `mul` operands must be |
| 2112 | // block arguments or extension of block arguments. |
| 2113 | // Otherwise, check for one or zero `ext` predecessor. The `ext` operands |
| 2114 | // must be block arguments or extension of block arguments. |
| 2115 | static std::optional<ConvOperationKind> |
| 2116 | getConvOperationKind(Operation *reduceOp) { |
| 2117 | int numBlockArguments = |
| 2118 | llvm::count_if(Range: reduceOp->getOperands(), P: llvm::IsaPred<BlockArgument>); |
| 2119 | |
| 2120 | switch (numBlockArguments) { |
| 2121 | case 1: { |
| 2122 | // Will be convolution if feeder is a MulOp. |
| 2123 | // A strength reduced version of MulOp for i1 type is AndOp which is also |
| 2124 | // supported. Otherwise, it can be pooling. This strength reduction logic |
| 2125 | // is in `buildBinaryFn` helper in the Linalg dialect. |
| 2126 | auto feedValIt = llvm::find_if_not(Range: reduceOp->getOperands(), |
| 2127 | P: llvm::IsaPred<BlockArgument>); |
| 2128 | assert(feedValIt != reduceOp->operand_end() && |
| 2129 | "Expected a non-block argument operand" ); |
| 2130 | Operation *feedOp = (*feedValIt).getDefiningOp(); |
| 2131 | if (isCastOfBlockArgument(op: feedOp)) { |
| 2132 | return ConvOperationKind::Pool; |
| 2133 | } |
| 2134 | |
| 2135 | if (!((isa<arith::MulIOp, arith::MulFOp>(Val: feedOp) || |
| 2136 | (isa<arith::AndIOp>(Val: feedOp) && |
| 2137 | feedOp->getResultTypes()[0].isInteger(width: 1))) && |
| 2138 | llvm::all_of(Range: feedOp->getOperands(), P: [](Value v) { |
| 2139 | if (isa<BlockArgument>(Val: v)) |
| 2140 | return true; |
| 2141 | if (Operation *op = v.getDefiningOp()) |
| 2142 | return isCastOfBlockArgument(op); |
| 2143 | return false; |
| 2144 | }))) { |
| 2145 | return std::nullopt; |
| 2146 | } |
| 2147 | |
| 2148 | return ConvOperationKind::Conv; |
| 2149 | } |
| 2150 | case 2: |
| 2151 | // Must be pooling |
| 2152 | return ConvOperationKind::Pool; |
| 2153 | default: |
| 2154 | return std::nullopt; |
| 2155 | } |
| 2156 | } |
| 2157 | |
| 2158 | static bool isSupportedPoolKind(vector::CombiningKind kind) { |
| 2159 | switch (kind) { |
| 2160 | case vector::CombiningKind::ADD: |
| 2161 | case vector::CombiningKind::MAXNUMF: |
| 2162 | case vector::CombiningKind::MAXIMUMF: |
| 2163 | case vector::CombiningKind::MAXSI: |
| 2164 | case vector::CombiningKind::MAXUI: |
| 2165 | case vector::CombiningKind::MINNUMF: |
| 2166 | case vector::CombiningKind::MINIMUMF: |
| 2167 | case vector::CombiningKind::MINSI: |
| 2168 | case vector::CombiningKind::MINUI: |
| 2169 | return true; |
| 2170 | default: |
| 2171 | return false; |
| 2172 | } |
| 2173 | } |
| 2174 | |
| 2175 | static LogicalResult vectorizeConvOpPrecondition(linalg::LinalgOp convOp) { |
| 2176 | auto getOperandType = [&](auto operand) { |
| 2177 | return dyn_cast<ShapedType>((operand->get()).getType()); |
| 2178 | }; |
| 2179 | ShapedType lhsShapedType = getOperandType(convOp.getDpsInputOperand(i: 0)); |
| 2180 | ShapedType rhsShapedType = getOperandType(convOp.getDpsInputOperand(i: 1)); |
| 2181 | ShapedType resShapedType = getOperandType(convOp.getDpsInitOperand(i: 0)); |
| 2182 | // (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR |
| 2183 | // (non-channeled convolution -> LHS and RHS both have single dimensions). |
| 2184 | // Note that this also ensures 2D and 3D convolutions are rejected. |
| 2185 | if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) && |
| 2186 | (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1)) |
| 2187 | return failure(); |
| 2188 | |
| 2189 | Operation *reduceOp = matchLinalgReduction(outputOperand: convOp.getDpsInitOperand(i: 0)); |
| 2190 | if (!reduceOp) |
| 2191 | return failure(); |
| 2192 | |
| 2193 | auto maybeOper = getConvOperationKind(reduceOp); |
| 2194 | if (!maybeOper.has_value()) |
| 2195 | return failure(); |
| 2196 | |
| 2197 | auto maybeKind = getCombinerOpKind(combinerOp: reduceOp); |
| 2198 | // Typically convolution will have a `Add` CombiningKind but for i1 type it |
| 2199 | // can get strength reduced to `OR` which is also supported. This strength |
| 2200 | // reduction logic is in `buildBinaryFn` helper in the Linalg dialect. |
| 2201 | if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD && |
| 2202 | *maybeKind != vector::CombiningKind::OR) && |
| 2203 | (*maybeOper != ConvOperationKind::Pool || |
| 2204 | !isSupportedPoolKind(kind: *maybeKind)))) { |
| 2205 | return failure(); |
| 2206 | } |
| 2207 | |
| 2208 | auto rhsRank = rhsShapedType.getRank(); |
| 2209 | if (*maybeOper == ConvOperationKind::Pool) { |
| 2210 | if (rhsRank != 1) |
| 2211 | return failure(); |
| 2212 | } else { |
| 2213 | if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3) |
| 2214 | return failure(); |
| 2215 | } |
| 2216 | |
| 2217 | return success(); |
| 2218 | } |
| 2219 | |
| 2220 | static LogicalResult vectorizeLinalgOpPrecondition( |
| 2221 | LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes, |
| 2222 | bool , bool flatten1DDepthwiseConv) { |
| 2223 | // tensor with dimension of 0 cannot be vectorized. |
| 2224 | if (llvm::any_of(Range: linalgOp->getOpOperands(), P: [&](OpOperand &operand) { |
| 2225 | return llvm::is_contained(Range: linalgOp.getShape(opOperand: &operand), Element: 0); |
| 2226 | })) |
| 2227 | return failure(); |
| 2228 | // Check API contract for input vector sizes. |
| 2229 | if (!inputVectorSizes.empty() && |
| 2230 | failed(Result: vector::isValidMaskedInputVector(shape: linalgOp.getStaticLoopRanges(), |
| 2231 | inputVectorSizes))) |
| 2232 | return failure(); |
| 2233 | |
| 2234 | if (linalgOp.hasDynamicShape() && failed(Result: vectorizeDynamicLinalgOpPrecondition( |
| 2235 | op: linalgOp, flatten1DDepthwiseConv))) { |
| 2236 | LDBG("Dynamically-shaped op failed vectorization pre-conditions\n" ); |
| 2237 | return failure(); |
| 2238 | } |
| 2239 | |
| 2240 | SmallVector<CustomVectorizationPrecondition> customPreconditions; |
| 2241 | |
| 2242 | // Register CustomVectorizationPrecondition for extractOp. |
| 2243 | customPreconditions.push_back(Elt: tensorExtractVectorizationPrecondition); |
| 2244 | |
| 2245 | // All types in the body should be a supported element type for VectorType. |
| 2246 | for (Operation &innerOp : linalgOp->getRegion(index: 0).front()) { |
| 2247 | // Check if any custom hook can vectorize the inner op. |
| 2248 | if (llvm::any_of( |
| 2249 | Range&: customPreconditions, |
| 2250 | P: [&](const CustomVectorizationPrecondition &customPrecondition) { |
| 2251 | return succeeded( |
| 2252 | Result: customPrecondition(&innerOp, vectorizeNDExtract)); |
| 2253 | })) { |
| 2254 | continue; |
| 2255 | } |
| 2256 | if (!llvm::all_of(Range: innerOp.getOperandTypes(), |
| 2257 | P: VectorType::isValidElementType)) { |
| 2258 | return failure(); |
| 2259 | } |
| 2260 | if (!llvm::all_of(Range: innerOp.getResultTypes(), |
| 2261 | P: VectorType::isValidElementType)) { |
| 2262 | return failure(); |
| 2263 | } |
| 2264 | } |
| 2265 | if (isElementwise(op: linalgOp)) |
| 2266 | return success(); |
| 2267 | |
| 2268 | // TODO: isaConvolutionOpInterface that can also infer from generic |
| 2269 | // features. But we will still need stride/dilation attributes that will be |
| 2270 | // annoying to reverse-engineer... |
| 2271 | if (isa<ConvolutionOpInterface>(Val: linalgOp.getOperation())) |
| 2272 | return vectorizeConvOpPrecondition(convOp: linalgOp); |
| 2273 | |
| 2274 | // TODO: the common vector shape is equal to the static loop sizes only when |
| 2275 | // all indexing maps are projected permutations. For convs and stencils the |
| 2276 | // logic will need to evolve. |
| 2277 | if (!allIndexingsAreProjectedPermutation(op: linalgOp)) { |
| 2278 | LDBG("precondition failed: not projected permutations\n" ); |
| 2279 | return failure(); |
| 2280 | } |
| 2281 | if (failed(Result: reductionPreconditions(op: linalgOp))) { |
| 2282 | LDBG("precondition failed: reduction preconditions\n" ); |
| 2283 | return failure(); |
| 2284 | } |
| 2285 | return success(); |
| 2286 | } |
| 2287 | |
| 2288 | static LogicalResult |
| 2289 | vectorizePackOpPrecondition(linalg::PackOp packOp, |
| 2290 | ArrayRef<int64_t> inputVectorSizes) { |
| 2291 | auto padValue = packOp.getPaddingValue(); |
| 2292 | Attribute cstAttr; |
| 2293 | if (padValue && !matchPattern(value: padValue, pattern: m_Constant(bind_value: &cstAttr))) { |
| 2294 | LDBG("pad value is not constant: " << packOp << "\n" ); |
| 2295 | return failure(); |
| 2296 | } |
| 2297 | ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); |
| 2298 | bool satisfyEmptyCond = true; |
| 2299 | if (inputVectorSizes.empty()) { |
| 2300 | if (!packOp.getDestType().hasStaticShape() || |
| 2301 | !packOp.getSourceType().hasStaticShape()) |
| 2302 | satisfyEmptyCond = false; |
| 2303 | } |
| 2304 | |
| 2305 | if (!satisfyEmptyCond && |
| 2306 | failed(Result: vector::isValidMaskedInputVector( |
| 2307 | shape: resultTensorShape.take_front(N: packOp.getSourceRank()), |
| 2308 | inputVectorSizes))) |
| 2309 | return failure(); |
| 2310 | |
| 2311 | if (llvm::any_of(Range: packOp.getInnerTiles(), P: [](OpFoldResult v) { |
| 2312 | return !getConstantIntValue(ofr: v).has_value(); |
| 2313 | })) { |
| 2314 | LDBG("inner_tiles must be constant: " << packOp << "\n" ); |
| 2315 | return failure(); |
| 2316 | } |
| 2317 | |
| 2318 | return success(); |
| 2319 | } |
| 2320 | |
| 2321 | static LogicalResult |
| 2322 | vectorizePadOpPrecondition(tensor::PadOp padOp, |
| 2323 | ArrayRef<int64_t> inputVectorSizes) { |
| 2324 | auto padValue = padOp.getConstantPaddingValue(); |
| 2325 | if (!padValue) { |
| 2326 | LDBG("pad value is not constant: " << padOp << "\n" ); |
| 2327 | return failure(); |
| 2328 | } |
| 2329 | |
| 2330 | ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape(); |
| 2331 | if (failed(Result: vector::isValidMaskedInputVector(shape: resultTensorShape, |
| 2332 | inputVectorSizes))) |
| 2333 | return failure(); |
| 2334 | |
| 2335 | // Padding with non-zero low pad values is not supported, unless the |
| 2336 | // corresponding result dim is 1 as this would require shifting the results to |
| 2337 | // the right for the low padded dims by the required amount of low padding. |
| 2338 | // However, we do support low padding if the dims being low padded have result |
| 2339 | // sizes of 1. The reason is when we have a low pad on a unit result dim, the |
| 2340 | // input size of that dimension will be dynamically zero (as the sum of the |
| 2341 | // low pad and input dim size has to be one) and hence we will create a zero |
| 2342 | // mask as the lowering logic just makes the mask one for the input dim size - |
| 2343 | // which is zero here. Hence we will load the pad value which is what we want |
| 2344 | // in this case. If the low pad is dynamically zero then the lowering is |
| 2345 | // correct as well as no shifts are necessary. |
| 2346 | if (llvm::any_of(Range: llvm::enumerate(First: padOp.getLow()), P: [&](const auto &en) { |
| 2347 | Value padValue = en.value(); |
| 2348 | unsigned pos = en.index(); |
| 2349 | std::optional<int64_t> pad = getConstantIntValue(ofr: padValue); |
| 2350 | return (!pad.has_value() || pad.value() != 0) && |
| 2351 | resultTensorShape[pos] != 1; |
| 2352 | })) { |
| 2353 | LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n" ); |
| 2354 | return failure(); |
| 2355 | } |
| 2356 | |
| 2357 | return success(); |
| 2358 | } |
| 2359 | |
| 2360 | /// Preconditions for scalable vectors. This is quite restrictive - it models |
| 2361 | /// the fact that in practice we would only make selected dimensions scalable. |
| 2362 | static LogicalResult |
| 2363 | vectorizeScalableVectorPrecondition(Operation *op, |
| 2364 | ArrayRef<int64_t> inputVectorSizes, |
| 2365 | ArrayRef<bool> inputScalableVecDims) { |
| 2366 | assert(inputVectorSizes.size() == inputScalableVecDims.size() && |
| 2367 | "Number of input vector sizes and scalable dims doesn't match" ); |
| 2368 | |
| 2369 | size_t numOfScalableDims = |
| 2370 | llvm::count_if(Range&: inputScalableVecDims, P: [](bool flag) { return flag; }); |
| 2371 | |
| 2372 | if (numOfScalableDims == 0) |
| 2373 | return success(); |
| 2374 | |
| 2375 | auto linalgOp = dyn_cast<LinalgOp>(Val: op); |
| 2376 | |
| 2377 | // Cond 1: There's been no need for scalable vectorisation of |
| 2378 | // non-linalg Ops so far |
| 2379 | if (!linalgOp) |
| 2380 | return failure(); |
| 2381 | |
| 2382 | // Cond 2: There's been no need for more than 2 scalable dims so far |
| 2383 | if (numOfScalableDims > 2) |
| 2384 | return failure(); |
| 2385 | |
| 2386 | // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that |
| 2387 | // it matches one of the supported cases: |
| 2388 | // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim |
| 2389 | // (*). |
| 2390 | // 2. Exactly 2 dims are scalable and those are the _last two adjacent_ |
| 2391 | // parallel dims. |
| 2392 | // 3. Exactly 1 reduction dim is scalable and that's the last (innermost) |
| 2393 | // dim. |
| 2394 | // The 2nd restriction above means that only Matmul-like Ops are supported |
| 2395 | // when 2 dims are scalable, e.g. : |
| 2396 | // * iterators = [parallel, parallel, reduction] |
| 2397 | // * scalable flags = [true, true, false] |
| 2398 | // |
| 2399 | // (*) Non-unit dims get folded away in practice. |
| 2400 | // TODO: Relax these conditions as good motivating examples are identified. |
| 2401 | |
| 2402 | // Find the first scalable flag. |
| 2403 | bool seenNonUnitParallel = false; |
| 2404 | auto iterators = linalgOp.getIteratorTypesArray(); |
| 2405 | SmallVector<bool> scalableFlags(inputScalableVecDims); |
| 2406 | int64_t idx = scalableFlags.size() - 1; |
| 2407 | while (!scalableFlags[idx]) { |
| 2408 | bool isNonUnitDim = (inputVectorSizes[idx] != 1); |
| 2409 | seenNonUnitParallel |= |
| 2410 | (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim); |
| 2411 | |
| 2412 | iterators.pop_back(); |
| 2413 | scalableFlags.pop_back(); |
| 2414 | --idx; |
| 2415 | } |
| 2416 | |
| 2417 | // Analyze the iterator corresponding to the first scalable dim. |
| 2418 | switch (iterators.back()) { |
| 2419 | case utils::IteratorType::reduction: { |
| 2420 | // Check 3. above is met. |
| 2421 | if (iterators.size() != inputVectorSizes.size()) { |
| 2422 | LDBG("Non-trailing reduction dim requested for scalable " |
| 2423 | "vectorization\n" ); |
| 2424 | return failure(); |
| 2425 | } |
| 2426 | if (isa<linalg::MatmulOp>(Val: op) || isa<linalg::MatmulTransposeAOp>(Val: op)) { |
| 2427 | LDBG("Scalable vectorization of the reduction dim in Matmul-like ops " |
| 2428 | "is not supported\n" ); |
| 2429 | return failure(); |
| 2430 | } |
| 2431 | break; |
| 2432 | } |
| 2433 | case utils::IteratorType::parallel: { |
| 2434 | // Check 1. and 2. above are met. |
| 2435 | if (seenNonUnitParallel) { |
| 2436 | LDBG("Inner parallel dim not requested for scalable " |
| 2437 | "vectorization\n" ); |
| 2438 | return failure(); |
| 2439 | } |
| 2440 | break; |
| 2441 | } |
| 2442 | } |
| 2443 | |
| 2444 | // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are |
| 2445 | // supported for which expect the folowing config: |
| 2446 | // * iterators = [parallel, parallel, reduction] |
| 2447 | // * scalable flags = [true, true, false] |
| 2448 | if (numOfScalableDims == 2) { |
| 2449 | // Disallow below case which breaks 3. above: |
| 2450 | // * iterators = [..., parallel, reduction] |
| 2451 | // * scalable flags = [..., true, true] |
| 2452 | if (iterators.back() == utils::IteratorType::reduction) { |
| 2453 | LDBG("Higher dim than the trailing reduction dim requested for scalable " |
| 2454 | "vectorization\n" ); |
| 2455 | return failure(); |
| 2456 | } |
| 2457 | scalableFlags.pop_back(); |
| 2458 | iterators.pop_back(); |
| 2459 | |
| 2460 | if (!scalableFlags.back() || |
| 2461 | (iterators.back() != utils::IteratorType::parallel)) |
| 2462 | return failure(); |
| 2463 | } |
| 2464 | |
| 2465 | // Check to not let go the matmul with extended semantic, through this |
| 2466 | // transform. |
| 2467 | if (linalgOp.hasUserDefinedMaps()) |
| 2468 | return failure(); |
| 2469 | |
| 2470 | // Cond 4: Only the following ops are supported in the |
| 2471 | // presence of scalable vectors |
| 2472 | return success(IsSuccess: isElementwise(op: linalgOp) || isa<linalg::MatmulOp>(Val: op) || |
| 2473 | isa<linalg::MatmulTransposeAOp>(Val: op) || |
| 2474 | isa<linalg::DepthwiseConv1DNwcWcOp>(Val: op) || |
| 2475 | isa<linalg::MatvecOp>(Val: op) || hasReductionIterator(op&: linalgOp)); |
| 2476 | } |
| 2477 | |
| 2478 | LogicalResult mlir::linalg::vectorizeOpPrecondition( |
| 2479 | Operation *op, ArrayRef<int64_t> inputVectorSizes, |
| 2480 | ArrayRef<bool> inputScalableVecDims, bool , |
| 2481 | bool flatten1DDepthwiseConv) { |
| 2482 | |
| 2483 | if (!hasVectorizationImpl(op)) |
| 2484 | return failure(); |
| 2485 | |
| 2486 | if (failed(Result: vectorizeScalableVectorPrecondition(op, inputVectorSizes, |
| 2487 | inputScalableVecDims))) |
| 2488 | return failure(); |
| 2489 | |
| 2490 | return TypeSwitch<Operation *, LogicalResult>(op) |
| 2491 | .Case<linalg::LinalgOp>(caseFn: [&](auto linalgOp) { |
| 2492 | return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes, |
| 2493 | vectorizeNDExtract, |
| 2494 | flatten1DDepthwiseConv); |
| 2495 | }) |
| 2496 | .Case<tensor::PadOp>(caseFn: [&](auto padOp) { |
| 2497 | return vectorizePadOpPrecondition(padOp, inputVectorSizes); |
| 2498 | }) |
| 2499 | .Case<linalg::PackOp>(caseFn: [&](auto packOp) { |
| 2500 | return vectorizePackOpPrecondition(packOp, inputVectorSizes); |
| 2501 | }) |
| 2502 | .Case<linalg::UnPackOp>(caseFn: [&](auto unpackOp) { |
| 2503 | return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes); |
| 2504 | }) |
| 2505 | .Case<tensor::InsertSliceOp>(caseFn: [&](auto sliceOp) { |
| 2506 | return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes); |
| 2507 | }) |
| 2508 | .Default(defaultFn: [](auto) { return failure(); }); |
| 2509 | } |
| 2510 | |
| 2511 | /// Converts affine.apply Ops to arithmetic operations. |
| 2512 | static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) { |
| 2513 | OpBuilder::InsertionGuard g(rewriter); |
| 2514 | auto toReplace = linalgOp.getBlock()->getOps<affine::AffineApplyOp>(); |
| 2515 | |
| 2516 | for (auto op : make_early_inc_range(Range&: toReplace)) { |
| 2517 | rewriter.setInsertionPoint(op); |
| 2518 | auto expanded = affine::expandAffineExpr( |
| 2519 | builder&: rewriter, loc: op->getLoc(), expr: op.getAffineMap().getResult(idx: 0), |
| 2520 | dimValues: op.getOperands().take_front(n: op.getAffineMap().getNumDims()), |
| 2521 | symbolValues: op.getOperands().take_back(n: op.getAffineMap().getNumSymbols())); |
| 2522 | rewriter.replaceOp(op, newValues: expanded); |
| 2523 | } |
| 2524 | } |
| 2525 | |
| 2526 | bool mlir::linalg::hasVectorizationImpl(Operation *op) { |
| 2527 | return isa<linalg::LinalgOp, tensor::PadOp, linalg::PackOp, linalg::UnPackOp, |
| 2528 | tensor::InsertSliceOp>(Val: op); |
| 2529 | } |
| 2530 | |
| 2531 | FailureOr<VectorizationResult> |
| 2532 | mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op, |
| 2533 | ArrayRef<int64_t> inputVectorSizes, |
| 2534 | ArrayRef<bool> inputScalableVecDims, |
| 2535 | bool , bool flatten1DDepthwiseConv) { |
| 2536 | LDBG("Attempting to vectorize:\n" << *op << "\n" ); |
| 2537 | LDBG("Input vector sizes: " ); |
| 2538 | LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); |
| 2539 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 2540 | LDBG("Input scalable vector dims: " ); |
| 2541 | LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); |
| 2542 | LLVM_DEBUG(llvm::dbgs() << "\n" ); |
| 2543 | |
| 2544 | if (failed(Result: vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims, |
| 2545 | vectorizeNDExtract, |
| 2546 | flatten1DDepthwiseConv))) { |
| 2547 | LDBG("Vectorization pre-conditions failed\n" ); |
| 2548 | return failure(); |
| 2549 | } |
| 2550 | |
| 2551 | // Initialize vectorization state. |
| 2552 | VectorizationState state(rewriter); |
| 2553 | if (auto linalgOp = dyn_cast<linalg::LinalgOp>(Val: op)) { |
| 2554 | if (failed(Result: state.initState(rewriter, linalgOp, inputVectorSizes, |
| 2555 | inputScalableVecDims))) { |
| 2556 | LDBG("Vectorization state couldn't be initialized\n" ); |
| 2557 | return failure(); |
| 2558 | } |
| 2559 | } |
| 2560 | |
| 2561 | SmallVector<Value> results; |
| 2562 | auto vectorizeResult = |
| 2563 | TypeSwitch<Operation *, LogicalResult>(op) |
| 2564 | .Case<linalg::LinalgOp>(caseFn: [&](auto linalgOp) { |
| 2565 | // TODO: isaConvolutionOpInterface that can also infer from |
| 2566 | // generic features. Will require stride/dilation attributes |
| 2567 | // inference. |
| 2568 | if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) { |
| 2569 | FailureOr<Operation *> convOr = vectorizeConvolution( |
| 2570 | rewriter, linalgOp, inputVectorSizes, inputScalableVecDims, |
| 2571 | flatten1DDepthwiseConv); |
| 2572 | if (succeeded(Result: convOr)) { |
| 2573 | llvm::append_range(C&: results, R: (*convOr)->getResults()); |
| 2574 | return success(); |
| 2575 | } |
| 2576 | |
| 2577 | LDBG("Unsupported convolution can't be vectorized.\n" ); |
| 2578 | return failure(); |
| 2579 | } |
| 2580 | |
| 2581 | LDBG("Vectorize generic by broadcasting to the canonical vector " |
| 2582 | "shape\n" ); |
| 2583 | |
| 2584 | // Pre-process before proceeding. |
| 2585 | convertAffineApply(rewriter, linalgOp); |
| 2586 | |
| 2587 | // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted |
| 2588 | // to 'OpBuilder' when it is passed over to some methods like |
| 2589 | // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we |
| 2590 | // erase an op within these methods, the actual rewriter won't be |
| 2591 | // notified and we will end up with read-after-free issues! |
| 2592 | return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results); |
| 2593 | }) |
| 2594 | .Case<tensor::PadOp>(caseFn: [&](auto padOp) { |
| 2595 | return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes, |
| 2596 | results); |
| 2597 | }) |
| 2598 | .Case<linalg::PackOp>(caseFn: [&](auto packOp) { |
| 2599 | return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, |
| 2600 | results); |
| 2601 | }) |
| 2602 | .Case<linalg::UnPackOp>(caseFn: [&](auto unpackOp) { |
| 2603 | return vectorizeAsTensorUnpackOp(rewriter, unpackOp, |
| 2604 | inputVectorSizes, results); |
| 2605 | }) |
| 2606 | .Case<tensor::InsertSliceOp>(caseFn: [&](auto sliceOp) { |
| 2607 | return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes, |
| 2608 | results); |
| 2609 | }) |
| 2610 | .Default(defaultFn: [](auto) { return failure(); }); |
| 2611 | |
| 2612 | if (failed(Result: vectorizeResult)) { |
| 2613 | LDBG("Vectorization failed\n" ); |
| 2614 | return failure(); |
| 2615 | } |
| 2616 | |
| 2617 | return VectorizationResult{.replacements: results}; |
| 2618 | } |
| 2619 | |
| 2620 | LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, |
| 2621 | memref::CopyOp copyOp) { |
| 2622 | auto srcType = cast<MemRefType>(Val: copyOp.getSource().getType()); |
| 2623 | auto dstType = cast<MemRefType>(Val: copyOp.getTarget().getType()); |
| 2624 | if (!srcType.hasStaticShape() || !dstType.hasStaticShape()) |
| 2625 | return failure(); |
| 2626 | |
| 2627 | auto srcElementType = getElementTypeOrSelf(type: srcType); |
| 2628 | auto dstElementType = getElementTypeOrSelf(type: dstType); |
| 2629 | if (!VectorType::isValidElementType(t: srcElementType) || |
| 2630 | !VectorType::isValidElementType(t: dstElementType)) |
| 2631 | return failure(); |
| 2632 | |
| 2633 | auto readType = VectorType::get(shape: srcType.getShape(), elementType: srcElementType); |
| 2634 | auto writeType = VectorType::get(shape: dstType.getShape(), elementType: dstElementType); |
| 2635 | |
| 2636 | Location loc = copyOp->getLoc(); |
| 2637 | Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 2638 | SmallVector<Value> indices(srcType.getRank(), zero); |
| 2639 | |
| 2640 | Value readValue = rewriter.create<vector::TransferReadOp>( |
| 2641 | location: loc, args&: readType, args: copyOp.getSource(), args&: indices, |
| 2642 | /*padding=*/args: std::nullopt, |
| 2643 | args: rewriter.getMultiDimIdentityMap(rank: srcType.getRank())); |
| 2644 | if (cast<VectorType>(Val: readValue.getType()).getRank() == 0) { |
| 2645 | readValue = |
| 2646 | rewriter.create<vector::ExtractOp>(location: loc, args&: readValue, args: ArrayRef<int64_t>()); |
| 2647 | readValue = rewriter.create<vector::BroadcastOp>(location: loc, args&: writeType, args&: readValue); |
| 2648 | } |
| 2649 | Operation *writeValue = rewriter.create<vector::TransferWriteOp>( |
| 2650 | location: loc, args&: readValue, args: copyOp.getTarget(), args&: indices, |
| 2651 | args: rewriter.getMultiDimIdentityMap(rank: srcType.getRank())); |
| 2652 | rewriter.replaceOp(op: copyOp, newValues: writeValue->getResults()); |
| 2653 | return success(); |
| 2654 | } |
| 2655 | |
| 2656 | //----------------------------------------------------------------------------// |
| 2657 | // Misc. vectorization patterns. |
| 2658 | //----------------------------------------------------------------------------// |
| 2659 | /// Base pattern for rewriting tensor::PadOps whose result is consumed by a |
| 2660 | /// given operation type OpTy. |
| 2661 | template <typename OpTy> |
| 2662 | struct VectorizePadOpUserPattern : public OpRewritePattern<tensor::PadOp> { |
| 2663 | using OpRewritePattern<tensor::PadOp>::OpRewritePattern; |
| 2664 | |
| 2665 | LogicalResult matchAndRewrite(tensor::PadOp padOp, |
| 2666 | PatternRewriter &rewriter) const final { |
| 2667 | bool changed = false; |
| 2668 | // Insert users in vector, because some users may be replaced/removed. |
| 2669 | for (auto *user : llvm::to_vector<4>(Range: padOp->getUsers())) |
| 2670 | if (auto op = dyn_cast<OpTy>(user)) |
| 2671 | changed |= rewriteUser(rewriter, padOp, op).succeeded(); |
| 2672 | return success(IsSuccess: changed); |
| 2673 | } |
| 2674 | |
| 2675 | protected: |
| 2676 | virtual LogicalResult rewriteUser(PatternRewriter &rewriter, |
| 2677 | tensor::PadOp padOp, OpTy op) const = 0; |
| 2678 | }; |
| 2679 | |
| 2680 | /// Rewrite use of tensor::PadOp result in TransferReadOp. E.g.: |
| 2681 | /// ``` |
| 2682 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
| 2683 | /// %r = vector.transfer_read %0[%c0, %c0], %cst |
| 2684 | /// {in_bounds = [true, true]} : tensor<17x5xf32>, vector<17x5xf32> |
| 2685 | /// ``` |
| 2686 | /// is rewritten to: |
| 2687 | /// ``` |
| 2688 | /// %r = vector.transfer_read %src[%c0, %c0], %padding |
| 2689 | /// {in_bounds = [true, true]} |
| 2690 | /// : tensor<?x?xf32>, vector<17x5xf32> |
| 2691 | /// ``` |
| 2692 | /// Note: By restricting this pattern to in-bounds TransferReadOps, we can be |
| 2693 | /// sure that the original padding value %cst was never used. |
| 2694 | /// |
| 2695 | /// This rewrite is possible if: |
| 2696 | /// - `xferOp` has no out-of-bounds dims or mask. |
| 2697 | /// - Low padding is static 0. |
| 2698 | /// - Single, scalar padding value. |
| 2699 | struct PadOpVectorizationWithTransferReadPattern |
| 2700 | : public VectorizePadOpUserPattern<vector::TransferReadOp> { |
| 2701 | using VectorizePadOpUserPattern< |
| 2702 | vector::TransferReadOp>::VectorizePadOpUserPattern; |
| 2703 | |
| 2704 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
| 2705 | vector::TransferReadOp xferOp) const override { |
| 2706 | // Low padding must be static 0. |
| 2707 | if (!padOp.hasZeroLowPad()) |
| 2708 | return failure(); |
| 2709 | // Pad value must be a constant. |
| 2710 | auto padValue = padOp.getConstantPaddingValue(); |
| 2711 | if (!padValue) |
| 2712 | return failure(); |
| 2713 | // Padding value of existing `xferOp` is unused. |
| 2714 | if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) |
| 2715 | return failure(); |
| 2716 | |
| 2717 | rewriter.modifyOpInPlace(root: xferOp, callable: [&]() { |
| 2718 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
| 2719 | xferOp->setAttr(name: xferOp.getInBoundsAttrName(), |
| 2720 | value: rewriter.getBoolArrayAttr(values: inBounds)); |
| 2721 | xferOp.getBaseMutable().assign(value: padOp.getSource()); |
| 2722 | xferOp.getPaddingMutable().assign(value: padValue); |
| 2723 | }); |
| 2724 | |
| 2725 | return success(); |
| 2726 | } |
| 2727 | }; |
| 2728 | |
| 2729 | /// Rewrite use of tensor::PadOp result in TransferWriteOp. |
| 2730 | /// This pattern rewrites TransferWriteOps that write to a padded tensor |
| 2731 | /// value, where the same amount of padding is immediately removed again after |
| 2732 | /// the write. In such cases, the TransferWriteOp can write to the non-padded |
| 2733 | /// tensor value and apply out-of-bounds masking. E.g.: |
| 2734 | /// ``` |
| 2735 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
| 2736 | /// : tensor<...> to tensor<?x?xf32> |
| 2737 | /// %1 = tensor.pad %0 ... : tensor<?x?xf32> to tensor<17x5xf32> |
| 2738 | /// %2 = vector.transfer_write %vec, %1[...] |
| 2739 | /// : vector<17x5xf32>, tensor<17x5xf32> |
| 2740 | /// %r = tensor.extract_slice %2[0, 0] [%s0, %s1] [1, 1] |
| 2741 | /// : tensor<17x5xf32> to tensor<?x?xf32> |
| 2742 | /// ``` |
| 2743 | /// is rewritten to: |
| 2744 | /// ``` |
| 2745 | /// %0 = tensor.extract_slice ...[...] [%s0, %s1] [1, 1] |
| 2746 | /// : tensor<...> to tensor<?x?xf32> |
| 2747 | /// %r = vector.transfer_write %vec, %0[...] : vector<17x5xf32>, |
| 2748 | /// tensor<?x?xf32> |
| 2749 | /// ``` |
| 2750 | /// Note: It is important that the ExtractSliceOp %r resizes the result of the |
| 2751 | /// TransferWriteOp to the same size as the input of the TensorPadOp (or an |
| 2752 | /// even smaller size). Otherwise, %r's new (dynamic) dimensions would differ |
| 2753 | /// from %r's old dimensions. |
| 2754 | /// |
| 2755 | /// This rewrite is possible if: |
| 2756 | /// - Low padding is static 0. |
| 2757 | /// - `xferOp` has exactly one use, which is an ExtractSliceOp. This |
| 2758 | /// ExtractSliceOp trims the same amount of padding that was added |
| 2759 | /// beforehand. |
| 2760 | /// - Single, scalar padding value. |
| 2761 | struct PadOpVectorizationWithTransferWritePattern |
| 2762 | : public VectorizePadOpUserPattern<vector::TransferWriteOp> { |
| 2763 | using VectorizePadOpUserPattern< |
| 2764 | vector::TransferWriteOp>::VectorizePadOpUserPattern; |
| 2765 | |
| 2766 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
| 2767 | vector::TransferWriteOp xferOp) const override { |
| 2768 | // TODO: support 0-d corner case. |
| 2769 | if (xferOp.getTransferRank() == 0) |
| 2770 | return failure(); |
| 2771 | |
| 2772 | // Low padding must be static 0. |
| 2773 | if (!padOp.hasZeroLowPad()) |
| 2774 | return failure(); |
| 2775 | // Pad value must be a constant. |
| 2776 | auto padValue = padOp.getConstantPaddingValue(); |
| 2777 | if (!padValue) |
| 2778 | return failure(); |
| 2779 | // TransferWriteOp result must be directly consumed by an ExtractSliceOp. |
| 2780 | if (!xferOp->hasOneUse()) |
| 2781 | return failure(); |
| 2782 | auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(Val: *xferOp->user_begin()); |
| 2783 | if (!trimPadding) |
| 2784 | return failure(); |
| 2785 | // Only static zero offsets supported when trimming padding. |
| 2786 | if (!trimPadding.hasZeroOffset()) |
| 2787 | return failure(); |
| 2788 | // trimPadding must remove the amount of padding that was added earlier. |
| 2789 | if (!hasSameTensorSize(beforePadding: padOp.getSource(), afterTrimming: trimPadding)) |
| 2790 | return failure(); |
| 2791 | |
| 2792 | // Insert the new TransferWriteOp at position of the old TransferWriteOp. |
| 2793 | rewriter.setInsertionPoint(xferOp); |
| 2794 | |
| 2795 | SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false); |
| 2796 | auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| 2797 | op: xferOp, args: padOp.getSource().getType(), args: xferOp.getVector(), |
| 2798 | args: padOp.getSource(), args: xferOp.getIndices(), args: xferOp.getPermutationMapAttr(), |
| 2799 | args: xferOp.getMask(), args: rewriter.getBoolArrayAttr(values: inBounds)); |
| 2800 | rewriter.replaceOp(op: trimPadding, newValues: newXferOp->getResult(idx: 0)); |
| 2801 | |
| 2802 | return success(); |
| 2803 | } |
| 2804 | |
| 2805 | /// Check if `beforePadding` and `afterTrimming` have the same tensor size, |
| 2806 | /// i.e., same dimensions. |
| 2807 | /// |
| 2808 | /// Dimensions may be static, dynamic or mix of both. In case of dynamic |
| 2809 | /// dimensions, this function tries to infer the (static) tensor size by |
| 2810 | /// looking at the defining op and utilizing op-specific knowledge. |
| 2811 | /// |
| 2812 | /// This is a conservative analysis. In case equal tensor sizes cannot be |
| 2813 | /// proven statically, this analysis returns `false` even though the tensor |
| 2814 | /// sizes may turn out to be equal at runtime. |
| 2815 | bool (Value beforePadding, |
| 2816 | tensor::ExtractSliceOp afterTrimming) const { |
| 2817 | // If the input to tensor::PadOp is a CastOp, try with both CastOp |
| 2818 | // result and CastOp operand. |
| 2819 | if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>()) |
| 2820 | if (hasSameTensorSize(beforePadding: castOp.getSource(), afterTrimming)) |
| 2821 | return true; |
| 2822 | |
| 2823 | auto t1 = dyn_cast<RankedTensorType>(Val: beforePadding.getType()); |
| 2824 | auto t2 = dyn_cast<RankedTensorType>(Val: afterTrimming.getType()); |
| 2825 | // Only RankedTensorType supported. |
| 2826 | if (!t1 || !t2) |
| 2827 | return false; |
| 2828 | // Rank of both values must be the same. |
| 2829 | if (t1.getRank() != t2.getRank()) |
| 2830 | return false; |
| 2831 | |
| 2832 | // All static dimensions must be the same. Mixed cases (e.g., dimension |
| 2833 | // static in `t1` but dynamic in `t2`) are not supported. |
| 2834 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
| 2835 | if (t1.isDynamicDim(idx: i) != t2.isDynamicDim(idx: i)) |
| 2836 | return false; |
| 2837 | if (!t1.isDynamicDim(idx: i) && t1.getDimSize(idx: i) != t2.getDimSize(idx: i)) |
| 2838 | return false; |
| 2839 | } |
| 2840 | |
| 2841 | // Nothing more to check if all dimensions are static. |
| 2842 | if (t1.getNumDynamicDims() == 0) |
| 2843 | return true; |
| 2844 | |
| 2845 | // All dynamic sizes must be the same. The only supported case at the |
| 2846 | // moment is when `beforePadding` is an ExtractSliceOp (or a cast |
| 2847 | // thereof). |
| 2848 | |
| 2849 | // Apart from CastOp, only ExtractSliceOp is supported. |
| 2850 | auto beforeSlice = beforePadding.getDefiningOp<tensor::ExtractSliceOp>(); |
| 2851 | if (!beforeSlice) |
| 2852 | return false; |
| 2853 | |
| 2854 | assert(static_cast<size_t>(t1.getRank()) == |
| 2855 | beforeSlice.getMixedSizes().size()); |
| 2856 | assert(static_cast<size_t>(t2.getRank()) == |
| 2857 | afterTrimming.getMixedSizes().size()); |
| 2858 | |
| 2859 | for (unsigned i = 0; i < t1.getRank(); ++i) { |
| 2860 | // Skip static dimensions. |
| 2861 | if (!t1.isDynamicDim(idx: i)) |
| 2862 | continue; |
| 2863 | auto size1 = beforeSlice.getMixedSizes()[i]; |
| 2864 | auto size2 = afterTrimming.getMixedSizes()[i]; |
| 2865 | |
| 2866 | // Case 1: Same value or same constant int. |
| 2867 | if (isEqualConstantIntOrValue(ofr1: size1, ofr2: size2)) |
| 2868 | continue; |
| 2869 | |
| 2870 | // Other cases: Take a deeper look at defining ops of values. |
| 2871 | auto v1 = llvm::dyn_cast_if_present<Value>(Val&: size1); |
| 2872 | auto v2 = llvm::dyn_cast_if_present<Value>(Val&: size2); |
| 2873 | if (!v1 || !v2) |
| 2874 | return false; |
| 2875 | |
| 2876 | // Case 2: Both values are identical AffineMinOps. (Should not happen if |
| 2877 | // CSE is run.) |
| 2878 | auto minOp1 = v1.getDefiningOp<affine::AffineMinOp>(); |
| 2879 | auto minOp2 = v2.getDefiningOp<affine::AffineMinOp>(); |
| 2880 | if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() && |
| 2881 | minOp1.getOperands() == minOp2.getOperands()) |
| 2882 | continue; |
| 2883 | |
| 2884 | // Add additional cases as needed. |
| 2885 | } |
| 2886 | |
| 2887 | // All tests passed. |
| 2888 | return true; |
| 2889 | } |
| 2890 | }; |
| 2891 | |
| 2892 | /// Returns the effective Pad value for the input op, provided it's a scalar. |
| 2893 | /// |
| 2894 | /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If |
| 2895 | /// this Op performs padding, retrieve the padding value provided that it's |
| 2896 | /// a scalar and static/fixed for all the padded values. Returns an empty value |
| 2897 | /// otherwise. |
| 2898 | /// |
| 2899 | /// TODO: This is used twice (when checking vectorization pre-conditions and |
| 2900 | /// when vectorizing). Cache results instead of re-running. |
| 2901 | static Value getStaticPadVal(Operation *op) { |
| 2902 | if (!op) |
| 2903 | return {}; |
| 2904 | |
| 2905 | // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's |
| 2906 | // being broadcast, provided that it's a scalar. |
| 2907 | if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(Val: op)) { |
| 2908 | auto source = bcast.getSource(); |
| 2909 | if (llvm::dyn_cast<VectorType>(Val: source.getType())) |
| 2910 | return {}; |
| 2911 | |
| 2912 | return source; |
| 2913 | } |
| 2914 | |
| 2915 | // 2. linalg.fill - use the scalar input value that used to fill the output |
| 2916 | // tensor. |
| 2917 | if (auto fill = llvm::dyn_cast<linalg::FillOp>(Val: op)) { |
| 2918 | return fill.getInputs()[0]; |
| 2919 | } |
| 2920 | |
| 2921 | // 3. tensor.generateOp - can't guarantee the value is fixed without |
| 2922 | // analysing, bail out. |
| 2923 | if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(Val: op)) { |
| 2924 | return {}; |
| 2925 | } |
| 2926 | |
| 2927 | // 4. vector.transfer_write - inspect the input vector that's written from. If |
| 2928 | // if contains a single value that has been broadcast (e.g. via |
| 2929 | // vector.broadcast), extract it, fail otherwise. |
| 2930 | if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(Val: op)) |
| 2931 | return getStaticPadVal(op: xferWrite.getVector().getDefiningOp()); |
| 2932 | |
| 2933 | // 5. tensor.insert_slice - inspect the destination tensor. If it's larger |
| 2934 | // than the input tensor, then, provided it's constant, we'll extract the |
| 2935 | // value that was used to generate it (via e.g. linalg.fill), fail otherwise. |
| 2936 | // TODO: Clarify the semantics when the input tensor is larger than the |
| 2937 | // destination. |
| 2938 | if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(Val: op)) |
| 2939 | return getStaticPadVal(op: slice.getDest().getDefiningOp()); |
| 2940 | |
| 2941 | return {}; |
| 2942 | } |
| 2943 | |
| 2944 | static LogicalResult |
| 2945 | vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, |
| 2946 | ArrayRef<int64_t> inputVectorSizes, |
| 2947 | SmallVectorImpl<Value> &newResults) { |
| 2948 | // TODO: Introduce a parent class that will handle the insertion point update. |
| 2949 | OpBuilder::InsertionGuard g(rewriter); |
| 2950 | rewriter.setInsertionPoint(sliceOp); |
| 2951 | |
| 2952 | TypedValue<RankedTensorType> source = sliceOp.getSource(); |
| 2953 | auto sourceType = source.getType(); |
| 2954 | auto resultType = sliceOp.getResultType(); |
| 2955 | |
| 2956 | Value padValue = getStaticPadVal(op: sliceOp); |
| 2957 | |
| 2958 | if (!padValue) { |
| 2959 | auto elemType = sourceType.getElementType(); |
| 2960 | padValue = rewriter.create<arith::ConstantOp>( |
| 2961 | location: sliceOp.getLoc(), args&: elemType, args: rewriter.getZeroAttr(type: elemType)); |
| 2962 | } |
| 2963 | |
| 2964 | // 2. Get the vector shape |
| 2965 | SmallVector<int64_t> vecShape; |
| 2966 | size_t rankDiff = resultType.getRank() - sourceType.getRank(); |
| 2967 | for (int64_t i = 0, end = sourceType.getRank(); i < end; ++i) { |
| 2968 | if (!inputVectorSizes.empty()) { |
| 2969 | vecShape.push_back(Elt: inputVectorSizes[i]); |
| 2970 | } else if (!sourceType.isDynamicDim(idx: i)) { |
| 2971 | vecShape.push_back(Elt: sourceType.getDimSize(idx: i)); |
| 2972 | } else if (!resultType.isDynamicDim(idx: i)) { |
| 2973 | // Source shape is not statically known, but result shape is. |
| 2974 | // Vectorize with size of result shape. This may be larger than the |
| 2975 | // source size. |
| 2976 | // FIXME: Using rankDiff implies that the source tensor is inserted at |
| 2977 | // the end of the destination tensor. However, that's not required. |
| 2978 | vecShape.push_back(Elt: resultType.getDimSize(idx: rankDiff + i)); |
| 2979 | } else { |
| 2980 | // Neither source nor result dim of padOp is static. Cannot vectorize |
| 2981 | // the copy. |
| 2982 | return failure(); |
| 2983 | } |
| 2984 | } |
| 2985 | auto vecType = VectorType::get(shape: vecShape, elementType: sourceType.getElementType()); |
| 2986 | |
| 2987 | // 3. Generate TransferReadOp + TransferWriteOp |
| 2988 | auto loc = sliceOp.getLoc(); |
| 2989 | |
| 2990 | // Create read |
| 2991 | SmallVector<Value> readIndices( |
| 2992 | vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0)); |
| 2993 | Value read = mlir::vector::createReadOrMaskedRead( |
| 2994 | builder&: rewriter, loc, source, inputVectorSizes: vecType.getShape(), padValue, |
| 2995 | /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); |
| 2996 | |
| 2997 | // Create write |
| 2998 | auto writeIndices = |
| 2999 | getValueOrCreateConstantIndexOp(b&: rewriter, loc, valueOrAttrVec: sliceOp.getMixedOffsets()); |
| 3000 | Operation *write = |
| 3001 | createWriteOrMaskedWrite(builder&: rewriter, loc, vecToStore: read, dest: sliceOp.getDest(), |
| 3002 | writeIndices, useInBoundsInsteadOfMasking: inputVectorSizes.empty()); |
| 3003 | |
| 3004 | // 4. Finalize |
| 3005 | newResults.push_back(Elt: write->getResult(idx: 0)); |
| 3006 | |
| 3007 | return success(); |
| 3008 | } |
| 3009 | |
| 3010 | /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.: |
| 3011 | /// ``` |
| 3012 | /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32> |
| 3013 | /// %r = tensor.insert_slice %0 |
| 3014 | /// into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1] |
| 3015 | /// : tensor<17x5xf32> into tensor<?x?x17x5xf32> |
| 3016 | /// ``` |
| 3017 | /// is rewritten to: |
| 3018 | /// ``` |
| 3019 | /// %0 = vector.transfer_read %src[%c0, %c0], %padding |
| 3020 | /// : tensor<?x?xf32>, vector<17x5xf32> |
| 3021 | /// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0] |
| 3022 | /// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32> |
| 3023 | /// ``` |
| 3024 | /// |
| 3025 | /// This rewrite is possible if: |
| 3026 | /// - Low padding is static 0. |
| 3027 | /// - `padOp` result shape is static. |
| 3028 | /// - The entire padded tensor is inserted. |
| 3029 | /// (Implies that sizes of `insertOp` are all static.) |
| 3030 | /// - Only unit strides in `insertOp`. |
| 3031 | /// - Single, scalar padding value. |
| 3032 | /// - `padOp` result not used as destination. |
| 3033 | struct PadOpVectorizationWithInsertSlicePattern |
| 3034 | : public VectorizePadOpUserPattern<tensor::InsertSliceOp> { |
| 3035 | using VectorizePadOpUserPattern< |
| 3036 | tensor::InsertSliceOp>::VectorizePadOpUserPattern; |
| 3037 | |
| 3038 | LogicalResult rewriteUser(PatternRewriter &rewriter, tensor::PadOp padOp, |
| 3039 | tensor::InsertSliceOp insertOp) const override { |
| 3040 | // Low padding must be static 0. |
| 3041 | if (!padOp.hasZeroLowPad()) |
| 3042 | return failure(); |
| 3043 | // Only unit stride supported. |
| 3044 | if (!insertOp.hasUnitStride()) |
| 3045 | return failure(); |
| 3046 | // Pad value must be a constant. |
| 3047 | auto padValue = padOp.getConstantPaddingValue(); |
| 3048 | if (!padValue) |
| 3049 | return failure(); |
| 3050 | // Dynamic shapes not supported. |
| 3051 | if (!cast<ShapedType>(Val: padOp.getResult().getType()).hasStaticShape()) |
| 3052 | return failure(); |
| 3053 | // Pad result not used as destination. |
| 3054 | if (insertOp.getDest() == padOp.getResult()) |
| 3055 | return failure(); |
| 3056 | |
| 3057 | auto vecType = VectorType::get(shape: padOp.getType().getShape(), |
| 3058 | elementType: padOp.getType().getElementType()); |
| 3059 | unsigned vecRank = vecType.getRank(); |
| 3060 | unsigned tensorRank = insertOp.getType().getRank(); |
| 3061 | |
| 3062 | // Check if sizes match: Insert the entire tensor into most minor dims. |
| 3063 | // (No permutations allowed.) |
| 3064 | SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1); |
| 3065 | expectedSizes.append(in_start: vecType.getShape().begin(), in_end: vecType.getShape().end()); |
| 3066 | if (!llvm::all_of( |
| 3067 | Range: llvm::zip(t: insertOp.getMixedSizes(), u&: expectedSizes), P: [](auto it) { |
| 3068 | return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); |
| 3069 | })) |
| 3070 | return failure(); |
| 3071 | |
| 3072 | // Insert the TransferReadOp and TransferWriteOp at the position of the |
| 3073 | // InsertSliceOp. |
| 3074 | rewriter.setInsertionPoint(insertOp); |
| 3075 | |
| 3076 | // Generate TransferReadOp: Read entire source tensor and add high |
| 3077 | // padding. |
| 3078 | SmallVector<Value> readIndices( |
| 3079 | vecRank, rewriter.create<arith::ConstantIndexOp>(location: padOp.getLoc(), args: 0)); |
| 3080 | auto read = rewriter.create<vector::TransferReadOp>( |
| 3081 | location: padOp.getLoc(), args&: vecType, args: padOp.getSource(), args&: readIndices, args&: padValue); |
| 3082 | |
| 3083 | // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at |
| 3084 | // specified offsets. Write is fully in-bounds because a InsertSliceOp's |
| 3085 | // source must fit into the destination at the specified offsets. |
| 3086 | auto writeIndices = getValueOrCreateConstantIndexOp( |
| 3087 | b&: rewriter, loc: padOp.getLoc(), valueOrAttrVec: insertOp.getMixedOffsets()); |
| 3088 | SmallVector<bool> inBounds(vecRank, true); |
| 3089 | rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| 3090 | op: insertOp, args&: read, args: insertOp.getDest(), args&: writeIndices, |
| 3091 | args: ArrayRef<bool>{inBounds}); |
| 3092 | |
| 3093 | return success(); |
| 3094 | } |
| 3095 | }; |
| 3096 | |
| 3097 | void mlir::linalg::populatePadOpVectorizationPatterns( |
| 3098 | RewritePatternSet &patterns, PatternBenefit baseBenefit) { |
| 3099 | patterns.add<PadOpVectorizationWithTransferReadPattern, |
| 3100 | PadOpVectorizationWithTransferWritePattern, |
| 3101 | PadOpVectorizationWithInsertSlicePattern>( |
| 3102 | arg: patterns.getContext(), args: baseBenefit.getBenefit() + 1); |
| 3103 | } |
| 3104 | |
| 3105 | //----------------------------------------------------------------------------// |
| 3106 | // Forwarding patterns |
| 3107 | //----------------------------------------------------------------------------// |
| 3108 | |
| 3109 | /// Check whether there is any interleaved use of any `values` between |
| 3110 | /// `firstOp` and `secondOp`. Conservatively return `true` if any op or value |
| 3111 | /// is in a different block. |
| 3112 | static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, |
| 3113 | ValueRange values) { |
| 3114 | if (firstOp->getBlock() != secondOp->getBlock() || |
| 3115 | !firstOp->isBeforeInBlock(other: secondOp)) { |
| 3116 | LDBG("interleavedUses precondition failed, firstOp: " |
| 3117 | << *firstOp << ", second op: " << *secondOp << "\n" ); |
| 3118 | return true; |
| 3119 | } |
| 3120 | for (auto v : values) { |
| 3121 | for (auto &u : v.getUses()) { |
| 3122 | Operation *owner = u.getOwner(); |
| 3123 | if (owner == firstOp || owner == secondOp) |
| 3124 | continue; |
| 3125 | // TODO: this is too conservative, use dominance info in the future. |
| 3126 | if (owner->getBlock() == firstOp->getBlock() && |
| 3127 | (owner->isBeforeInBlock(other: firstOp) || secondOp->isBeforeInBlock(other: owner))) |
| 3128 | continue; |
| 3129 | LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp |
| 3130 | << ", second op: " << *secondOp << "\n" ); |
| 3131 | return true; |
| 3132 | } |
| 3133 | } |
| 3134 | return false; |
| 3135 | } |
| 3136 | |
| 3137 | /// Return the unique subview use of `v` if it is indeed unique, null |
| 3138 | /// otherwise. |
| 3139 | static memref::SubViewOp getSubViewUseIfUnique(Value v) { |
| 3140 | memref::SubViewOp subViewOp; |
| 3141 | for (auto &u : v.getUses()) { |
| 3142 | if (auto newSubViewOp = dyn_cast<memref::SubViewOp>(Val: u.getOwner())) { |
| 3143 | if (subViewOp) |
| 3144 | return memref::SubViewOp(); |
| 3145 | subViewOp = newSubViewOp; |
| 3146 | } |
| 3147 | } |
| 3148 | return subViewOp; |
| 3149 | } |
| 3150 | |
| 3151 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
| 3152 | /// when available. |
| 3153 | LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( |
| 3154 | vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { |
| 3155 | |
| 3156 | // TODO: support mask. |
| 3157 | if (xferOp.getMask()) |
| 3158 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "unsupported mask" ); |
| 3159 | |
| 3160 | // Transfer into `view`. |
| 3161 | Value viewOrAlloc = xferOp.getBase(); |
| 3162 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
| 3163 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
| 3164 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "source not a view or alloc" ); |
| 3165 | |
| 3166 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
| 3167 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(v: viewOrAlloc); |
| 3168 | if (!subViewOp) |
| 3169 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no subview found" ); |
| 3170 | Value subView = subViewOp.getResult(); |
| 3171 | |
| 3172 | // Find the copy into `subView` without interleaved uses. |
| 3173 | memref::CopyOp copyOp; |
| 3174 | for (auto &u : subView.getUses()) { |
| 3175 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(Val: u.getOwner())) { |
| 3176 | assert(isa<MemRefType>(newCopyOp.getTarget().getType())); |
| 3177 | if (newCopyOp.getTarget() != subView) |
| 3178 | continue; |
| 3179 | if (mayExistInterleavedUses(firstOp: newCopyOp, secondOp: xferOp, values: {viewOrAlloc, subView})) |
| 3180 | continue; |
| 3181 | copyOp = newCopyOp; |
| 3182 | break; |
| 3183 | } |
| 3184 | } |
| 3185 | if (!copyOp) |
| 3186 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no copy found" ); |
| 3187 | |
| 3188 | // Find the fill into `viewOrAlloc` without interleaved uses before the |
| 3189 | // copy. |
| 3190 | FillOp maybeFillOp; |
| 3191 | for (auto &u : viewOrAlloc.getUses()) { |
| 3192 | if (auto newFillOp = dyn_cast<FillOp>(Val: u.getOwner())) { |
| 3193 | assert(isa<MemRefType>(newFillOp.output().getType())); |
| 3194 | if (newFillOp.output() != viewOrAlloc) |
| 3195 | continue; |
| 3196 | if (mayExistInterleavedUses(firstOp: newFillOp, secondOp: copyOp, values: {viewOrAlloc, subView})) |
| 3197 | continue; |
| 3198 | maybeFillOp = newFillOp; |
| 3199 | break; |
| 3200 | } |
| 3201 | } |
| 3202 | // Ensure padding matches. |
| 3203 | if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) |
| 3204 | return rewriter.notifyMatchFailure(arg&: xferOp, |
| 3205 | msg: "padding value does not match fill" ); |
| 3206 | |
| 3207 | // `in` is the subview that memref.copy reads. Replace it. |
| 3208 | Value in = copyOp.getSource(); |
| 3209 | |
| 3210 | // memref.copy + linalg.fill can be used to create a padded local buffer. |
| 3211 | // The `masked` attribute is only valid on this padded buffer. |
| 3212 | // When forwarding to vector.transfer_read, the attribute must be reset |
| 3213 | // conservatively. |
| 3214 | auto vectorType = xferOp.getVectorType(); |
| 3215 | Value res = rewriter.create<vector::TransferReadOp>( |
| 3216 | location: xferOp.getLoc(), args&: vectorType, args&: in, args: xferOp.getIndices(), |
| 3217 | args: xferOp.getPermutationMapAttr(), args: xferOp.getPadding(), args: xferOp.getMask(), |
| 3218 | args: rewriter.getBoolArrayAttr( |
| 3219 | values: SmallVector<bool>(vectorType.getRank(), false))); |
| 3220 | |
| 3221 | if (maybeFillOp) |
| 3222 | rewriter.eraseOp(op: maybeFillOp); |
| 3223 | rewriter.eraseOp(op: copyOp); |
| 3224 | rewriter.replaceOp(op: xferOp, newValues: res); |
| 3225 | |
| 3226 | return success(); |
| 3227 | } |
| 3228 | |
| 3229 | /// TODO: use interfaces, side-effects and aliasing analysis as appropriate, |
| 3230 | /// when available. |
| 3231 | LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( |
| 3232 | vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { |
| 3233 | // TODO: support mask. |
| 3234 | if (xferOp.getMask()) |
| 3235 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "unsupported mask" ); |
| 3236 | |
| 3237 | // Transfer into `viewOrAlloc`. |
| 3238 | Value viewOrAlloc = xferOp.getBase(); |
| 3239 | if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() && |
| 3240 | !viewOrAlloc.getDefiningOp<memref::AllocOp>()) |
| 3241 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "source not a view or alloc" ); |
| 3242 | |
| 3243 | // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`. |
| 3244 | memref::SubViewOp subViewOp = getSubViewUseIfUnique(v: viewOrAlloc); |
| 3245 | if (!subViewOp) |
| 3246 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no subview found" ); |
| 3247 | Value subView = subViewOp.getResult(); |
| 3248 | |
| 3249 | // Find the copy from `subView` without interleaved uses. |
| 3250 | memref::CopyOp copyOp; |
| 3251 | for (auto &u : subViewOp.getResult().getUses()) { |
| 3252 | if (auto newCopyOp = dyn_cast<memref::CopyOp>(Val: u.getOwner())) { |
| 3253 | if (newCopyOp.getSource() != subView) |
| 3254 | continue; |
| 3255 | if (mayExistInterleavedUses(firstOp: xferOp, secondOp: newCopyOp, values: {viewOrAlloc, subView})) |
| 3256 | continue; |
| 3257 | copyOp = newCopyOp; |
| 3258 | break; |
| 3259 | } |
| 3260 | } |
| 3261 | if (!copyOp) |
| 3262 | return rewriter.notifyMatchFailure(arg&: xferOp, msg: "no copy found" ); |
| 3263 | |
| 3264 | // `out` is the subview copied into that we replace. |
| 3265 | assert(isa<MemRefType>(copyOp.getTarget().getType())); |
| 3266 | Value out = copyOp.getTarget(); |
| 3267 | |
| 3268 | // Forward vector.transfer into copy. |
| 3269 | // memref.copy + linalg.fill can be used to create a padded local buffer. |
| 3270 | // The `masked` attribute is only valid on this padded buffer. |
| 3271 | // When forwarding to vector.transfer_write, the attribute must be reset |
| 3272 | // conservatively. |
| 3273 | auto vector = xferOp.getVector(); |
| 3274 | rewriter.create<vector::TransferWriteOp>( |
| 3275 | location: xferOp.getLoc(), args&: vector, args&: out, args: xferOp.getIndices(), |
| 3276 | args: xferOp.getPermutationMapAttr(), args: xferOp.getMask(), |
| 3277 | args: rewriter.getBoolArrayAttr(values: SmallVector<bool>( |
| 3278 | dyn_cast<VectorType>(Val: vector.getType()).getRank(), false))); |
| 3279 | |
| 3280 | rewriter.eraseOp(op: copyOp); |
| 3281 | rewriter.eraseOp(op: xferOp); |
| 3282 | |
| 3283 | return success(); |
| 3284 | } |
| 3285 | |
| 3286 | //===----------------------------------------------------------------------===// |
| 3287 | // Convolution vectorization patterns |
| 3288 | //===----------------------------------------------------------------------===// |
| 3289 | |
| 3290 | template <int N> |
| 3291 | static void bindShapeDims(ShapedType shapedType) {} |
| 3292 | |
| 3293 | template <int N, typename IntTy, typename... IntTy2> |
| 3294 | static void bindShapeDims(ShapedType shapedType, IntTy &val, IntTy2 &...vals) { |
| 3295 | val = shapedType.getShape()[N]; |
| 3296 | bindShapeDims<N + 1, IntTy2 &...>(shapedType, vals...); |
| 3297 | } |
| 3298 | |
| 3299 | /// Bind a pack of int& to the leading dimensions of shapedType.getShape(). |
| 3300 | template <typename... IntTy> |
| 3301 | static void bindShapeDims(ShapedType shapedType, IntTy &...vals) { |
| 3302 | bindShapeDims<0>(shapedType, vals...); |
| 3303 | } |
| 3304 | |
| 3305 | namespace { |
| 3306 | /// Generate a vector implementation for either: |
| 3307 | /// ``` |
| 3308 | /// Op def: ( w, kw ) |
| 3309 | /// Iters: ({Par(), Red()}) |
| 3310 | /// Layout: {{w + kw}, {kw}, {w}} |
| 3311 | /// ``` |
| 3312 | /// kw is unrolled. |
| 3313 | /// |
| 3314 | /// or |
| 3315 | /// |
| 3316 | /// ``` |
| 3317 | /// Op def: ( n, w, c, kw, f ) |
| 3318 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
| 3319 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
| 3320 | /// ``` |
| 3321 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
| 3322 | /// |
| 3323 | /// or |
| 3324 | /// |
| 3325 | /// ``` |
| 3326 | /// Op def: ( n, c, w, f, kw ) |
| 3327 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
| 3328 | /// Layout: {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
| 3329 | /// ``` |
| 3330 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
| 3331 | /// |
| 3332 | /// or |
| 3333 | /// |
| 3334 | /// ``` |
| 3335 | /// Op def: ( n, w, c, kw ) |
| 3336 | /// Iters: ({Par(), Par(), Par(), Red()}) |
| 3337 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
| 3338 | /// ``` |
| 3339 | /// kw is unrolled, w is unrolled iff dilationW > 1. |
| 3340 | struct Conv1DGenerator |
| 3341 | : public StructuredGenerator<LinalgOp, utils::IteratorType> { |
| 3342 | Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp) |
| 3343 | : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) { |
| 3344 | |
| 3345 | lhsShaped = linalgOp.getDpsInputOperand(i: 0)->get(); |
| 3346 | rhsShaped = linalgOp.getDpsInputOperand(i: 1)->get(); |
| 3347 | resShaped = linalgOp.getDpsInitOperand(i: 0)->get(); |
| 3348 | lhsShapedType = dyn_cast<ShapedType>(Val: lhsShaped.getType()); |
| 3349 | rhsShapedType = dyn_cast<ShapedType>(Val: rhsShaped.getType()); |
| 3350 | resShapedType = dyn_cast<ShapedType>(Val: resShaped.getType()); |
| 3351 | |
| 3352 | Operation *reduceOp = matchLinalgReduction(outputOperand: linalgOp.getDpsInitOperand(i: 0)); |
| 3353 | redOp = reduceOp->getName().getIdentifier(); |
| 3354 | |
| 3355 | setConvOperationKind(reduceOp); |
| 3356 | |
| 3357 | auto maybeKind = getCombinerOpKind(combinerOp: reduceOp); |
| 3358 | reductionKind = maybeKind.value(); |
| 3359 | |
| 3360 | // The ConvolutionOpInterface gives us guarantees of existence for |
| 3361 | // strides/dilations. However, we do not need to rely on those, we can |
| 3362 | // simply use them if present, otherwise use the default and let the generic |
| 3363 | // conv. matcher in the ConvGenerator succeed or fail. |
| 3364 | auto strides = linalgOp->getAttrOfType<DenseIntElementsAttr>(name: "strides" ); |
| 3365 | auto dilations = linalgOp->getAttrOfType<DenseIntElementsAttr>(name: "dilations" ); |
| 3366 | strideW = strides ? *strides.getValues<uint64_t>().begin() : 1; |
| 3367 | dilationW = dilations ? *dilations.getValues<uint64_t>().begin() : 1; |
| 3368 | } |
| 3369 | |
| 3370 | /// Generate a vector implementation for: |
| 3371 | /// ``` |
| 3372 | /// Op def: ( w, kw ) |
| 3373 | /// Iters: ({Par(), Red()}) |
| 3374 | /// Layout: {{w + kw}, {kw}, {w}} |
| 3375 | /// ``` |
| 3376 | /// kw is always unrolled. |
| 3377 | /// |
| 3378 | /// or |
| 3379 | /// |
| 3380 | /// ``` |
| 3381 | /// Op def: ( n, w, c, kw, f ) |
| 3382 | /// Iters: ({Par(), Par(), Par(), Red(), Red()}) |
| 3383 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
| 3384 | /// ``` |
| 3385 | /// kw is always unrolled. |
| 3386 | /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is |
| 3387 | /// > 1. |
| 3388 | FailureOr<Operation *> conv(Conv1DOpOrder conv1DOpOrder) { |
| 3389 | int64_t nSize, wSize, cSize, kwSize, fSize; |
| 3390 | SmallVector<int64_t, 3> lhsShape, rhsShape, resShape; |
| 3391 | bool isSingleChanneled = (conv1DOpOrder == Conv1DOpOrder::W); |
| 3392 | switch (conv1DOpOrder) { |
| 3393 | case Conv1DOpOrder::W: |
| 3394 | // Initialize unused dimensions |
| 3395 | nSize = fSize = cSize = 0; |
| 3396 | // out{W} |
| 3397 | bindShapeDims(shapedType: resShapedType, vals&: wSize); |
| 3398 | // kernel{kw} |
| 3399 | bindShapeDims(shapedType: rhsShapedType, vals&: kwSize); |
| 3400 | lhsShape = {// iw = ow + kw - 1 |
| 3401 | // (i.e. 16 convolved with 3 -> 14) |
| 3402 | (wSize + kwSize - 1)}; |
| 3403 | rhsShape = {kwSize}; |
| 3404 | resShape = {wSize}; |
| 3405 | break; |
| 3406 | case Conv1DOpOrder::Nwc: |
| 3407 | // out{n, w, f} |
| 3408 | bindShapeDims(shapedType: resShapedType, vals&: nSize, vals&: wSize, vals&: fSize); |
| 3409 | switch (oper) { |
| 3410 | case ConvOperationKind::Conv: |
| 3411 | // kernel{kw, c, f} |
| 3412 | bindShapeDims(shapedType: rhsShapedType, vals&: kwSize, vals&: cSize); |
| 3413 | break; |
| 3414 | case ConvOperationKind::Pool: |
| 3415 | // kernel{kw} |
| 3416 | bindShapeDims(shapedType: rhsShapedType, vals&: kwSize); |
| 3417 | cSize = fSize; |
| 3418 | break; |
| 3419 | } |
| 3420 | lhsShape = {nSize, |
| 3421 | // iw = ow * sw + kw * dw - 1 |
| 3422 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
| 3423 | // Perform the proper inclusive -> exclusive -> inclusive. |
| 3424 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
| 3425 | 1, |
| 3426 | cSize}; |
| 3427 | switch (oper) { |
| 3428 | case ConvOperationKind::Conv: |
| 3429 | rhsShape = {kwSize, cSize, fSize}; |
| 3430 | break; |
| 3431 | case ConvOperationKind::Pool: |
| 3432 | rhsShape = {kwSize}; |
| 3433 | break; |
| 3434 | } |
| 3435 | resShape = {nSize, wSize, fSize}; |
| 3436 | break; |
| 3437 | case Conv1DOpOrder::Ncw: |
| 3438 | // out{n, f, w} |
| 3439 | bindShapeDims(shapedType: resShapedType, vals&: nSize, vals&: fSize, vals&: wSize); |
| 3440 | switch (oper) { |
| 3441 | case ConvOperationKind::Conv: |
| 3442 | // kernel{f, c, kw} |
| 3443 | bindShapeDims(shapedType: rhsShapedType, vals&: fSize, vals&: cSize, vals&: kwSize); |
| 3444 | break; |
| 3445 | case ConvOperationKind::Pool: |
| 3446 | // kernel{kw} |
| 3447 | bindShapeDims(shapedType: rhsShapedType, vals&: kwSize); |
| 3448 | cSize = fSize; |
| 3449 | break; |
| 3450 | } |
| 3451 | lhsShape = {nSize, cSize, |
| 3452 | // iw = ow * sw + kw * dw - 1 |
| 3453 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
| 3454 | // Perform the proper inclusive -> exclusive -> inclusive. |
| 3455 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - |
| 3456 | 1}; |
| 3457 | switch (oper) { |
| 3458 | case ConvOperationKind::Conv: |
| 3459 | rhsShape = {fSize, cSize, kwSize}; |
| 3460 | break; |
| 3461 | case ConvOperationKind::Pool: |
| 3462 | rhsShape = {kwSize}; |
| 3463 | break; |
| 3464 | } |
| 3465 | resShape = {nSize, fSize, wSize}; |
| 3466 | break; |
| 3467 | } |
| 3468 | |
| 3469 | vector::TransferWriteOp write; |
| 3470 | Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 3471 | |
| 3472 | // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. |
| 3473 | // When strideW == 1, we can batch the contiguous loads and avoid |
| 3474 | // unrolling |
| 3475 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
| 3476 | |
| 3477 | Type lhsEltType = lhsShapedType.getElementType(); |
| 3478 | Type rhsEltType = rhsShapedType.getElementType(); |
| 3479 | Type resEltType = resShapedType.getElementType(); |
| 3480 | auto lhsType = VectorType::get(shape: lhsShape, elementType: lhsEltType); |
| 3481 | auto rhsType = VectorType::get(shape: rhsShape, elementType: rhsEltType); |
| 3482 | auto resType = VectorType::get(shape: resShape, elementType: resEltType); |
| 3483 | // Zero padding with the corresponding dimensions for lhs, rhs and res. |
| 3484 | SmallVector<Value> lhsPadding(lhsShape.size(), zero); |
| 3485 | SmallVector<Value> rhsPadding(rhsShape.size(), zero); |
| 3486 | SmallVector<Value> resPadding(resShape.size(), zero); |
| 3487 | |
| 3488 | // Read the whole lhs, rhs and res in one shot (with zero padding). |
| 3489 | Value lhs = rewriter.create<vector::TransferReadOp>( |
| 3490 | location: loc, args&: lhsType, args&: lhsShaped, args&: lhsPadding, |
| 3491 | /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: lhsEltType)); |
| 3492 | // This is needed only for Conv. |
| 3493 | Value rhs = nullptr; |
| 3494 | if (oper == ConvOperationKind::Conv) |
| 3495 | rhs = rewriter.create<vector::TransferReadOp>( |
| 3496 | location: loc, args&: rhsType, args&: rhsShaped, args&: rhsPadding, |
| 3497 | /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: rhsEltType)); |
| 3498 | Value res = rewriter.create<vector::TransferReadOp>( |
| 3499 | location: loc, args&: resType, args&: resShaped, args&: resPadding, |
| 3500 | /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: resEltType)); |
| 3501 | |
| 3502 | // The base vectorization case for channeled convolution is input: |
| 3503 | // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern |
| 3504 | // vectorization case, we do pre transpose on input, weight, and output. |
| 3505 | switch (conv1DOpOrder) { |
| 3506 | case Conv1DOpOrder::W: |
| 3507 | case Conv1DOpOrder::Nwc: |
| 3508 | // Base case, so no transposes necessary. |
| 3509 | break; |
| 3510 | case Conv1DOpOrder::Ncw: { |
| 3511 | // To match base vectorization case, we pre-transpose current case. |
| 3512 | // ncw -> nwc |
| 3513 | static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1}; |
| 3514 | lhs = rewriter.create<vector::TransposeOp>(location: loc, args&: lhs, args: permLhs); |
| 3515 | // fcw -> wcf |
| 3516 | static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0}; |
| 3517 | |
| 3518 | // This is needed only for Conv. |
| 3519 | if (oper == ConvOperationKind::Conv) |
| 3520 | rhs = rewriter.create<vector::TransposeOp>(location: loc, args&: rhs, args: permRhs); |
| 3521 | // nfw -> nwf |
| 3522 | static constexpr std::array<int64_t, 3> permRes = {0, 2, 1}; |
| 3523 | res = rewriter.create<vector::TransposeOp>(location: loc, args&: res, args: permRes); |
| 3524 | break; |
| 3525 | } |
| 3526 | } |
| 3527 | |
| 3528 | //===------------------------------------------------------------------===// |
| 3529 | // Begin vector-only rewrite part |
| 3530 | //===------------------------------------------------------------------===// |
| 3531 | // Unroll along kw and read slices of lhs and rhs. |
| 3532 | SmallVector<Value> lhsVals, rhsVals, resVals; |
| 3533 | lhsVals = extractConvInputSlices(rewriter, loc, input: lhs, nSize, wSize, cSize, |
| 3534 | kwSize, strideW, dilationW, wSizeStep, |
| 3535 | isSingleChanneled); |
| 3536 | // Do not do for pooling. |
| 3537 | if (oper == ConvOperationKind::Conv) |
| 3538 | rhsVals = extractConvFilterSlices(rewriter, loc, filter: rhs, kwSize); |
| 3539 | resVals = extractConvResultSlices(rewriter, loc, res, nSize, wSize, fSize, |
| 3540 | wSizeStep, isSingleChanneled); |
| 3541 | |
| 3542 | auto linearIndex = [&](int64_t kw, int64_t w) { |
| 3543 | return kw * (wSize / wSizeStep) + w; |
| 3544 | }; |
| 3545 | |
| 3546 | // Compute contraction: O{n, w, f} += I{n, sw * w + dw * kw, c} * F{c, f} |
| 3547 | // or perform outerproduct for non-channeled convolution or perform simple |
| 3548 | // arith operation for pooling |
| 3549 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3550 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3551 | switch (oper) { |
| 3552 | case ConvOperationKind::Conv: |
| 3553 | if (isSingleChanneled) { |
| 3554 | resVals[w] = conv1dSliceAsOuterProduct(rewriter, loc, |
| 3555 | lhs: lhsVals[linearIndex(kw, w)], |
| 3556 | rhs: rhsVals[kw], res: resVals[w]); |
| 3557 | } else { |
| 3558 | resVals[w] = conv1dSliceAsContraction(rewriter, loc, |
| 3559 | lhs: lhsVals[linearIndex(kw, w)], |
| 3560 | rhs: rhsVals[kw], res: resVals[w]); |
| 3561 | } |
| 3562 | break; |
| 3563 | case ConvOperationKind::Pool: |
| 3564 | resVals[w] = pool1dSlice(rewriter, loc, lhs: lhsVals[linearIndex(kw, w)], |
| 3565 | res: resVals[w]); |
| 3566 | break; |
| 3567 | } |
| 3568 | } |
| 3569 | } |
| 3570 | |
| 3571 | res = insertConvResultSlices(rewriter, loc, res, wSize, wSizeStep, resVals, |
| 3572 | isSingleChanneled); |
| 3573 | //===------------------------------------------------------------------===// |
| 3574 | // End vector-only rewrite part |
| 3575 | //===------------------------------------------------------------------===// |
| 3576 | |
| 3577 | // The base vectorization case for channeled convolution is output: |
| 3578 | // {n,w,f} To reuse the result from base pattern vectorization case, we |
| 3579 | // post transpose the base case result. |
| 3580 | switch (conv1DOpOrder) { |
| 3581 | case Conv1DOpOrder::W: |
| 3582 | case Conv1DOpOrder::Nwc: |
| 3583 | // Base case, so no transposes necessary. |
| 3584 | break; |
| 3585 | case Conv1DOpOrder::Ncw: { |
| 3586 | // nwf -> nfw |
| 3587 | static constexpr std::array<int64_t, 3> perm = {0, 2, 1}; |
| 3588 | res = rewriter.create<vector::TransposeOp>(location: loc, args&: res, args: perm); |
| 3589 | break; |
| 3590 | } |
| 3591 | } |
| 3592 | |
| 3593 | return rewriter |
| 3594 | .create<vector::TransferWriteOp>(location: loc, args&: res, args&: resShaped, args&: resPadding) |
| 3595 | .getOperation(); |
| 3596 | } |
| 3597 | |
| 3598 | // Take a value and widen to have the same element type as `ty`. |
| 3599 | Value promote(RewriterBase &rewriter, Location loc, Value val, Type ty) { |
| 3600 | const Type srcElementType = getElementTypeOrSelf(type: val.getType()); |
| 3601 | const Type dstElementType = getElementTypeOrSelf(type: ty); |
| 3602 | assert(isa<IntegerType>(dstElementType) || isa<FloatType>(dstElementType)); |
| 3603 | if (srcElementType == dstElementType) |
| 3604 | return val; |
| 3605 | |
| 3606 | const int64_t srcWidth = srcElementType.getIntOrFloatBitWidth(); |
| 3607 | const int64_t dstWidth = dstElementType.getIntOrFloatBitWidth(); |
| 3608 | const Type dstType = |
| 3609 | cast<ShapedType>(Val: val.getType()).cloneWith(shape: std::nullopt, elementType: dstElementType); |
| 3610 | |
| 3611 | if (isa<IntegerType>(Val: srcElementType) && isa<FloatType>(Val: dstElementType)) { |
| 3612 | return rewriter.create<arith::SIToFPOp>(location: loc, args: dstType, args&: val); |
| 3613 | } |
| 3614 | |
| 3615 | if (isa<FloatType>(Val: srcElementType) && isa<FloatType>(Val: dstElementType) && |
| 3616 | srcWidth < dstWidth) |
| 3617 | return rewriter.create<arith::ExtFOp>(location: loc, args: dstType, args&: val); |
| 3618 | |
| 3619 | if (isa<IntegerType>(Val: srcElementType) && isa<IntegerType>(Val: dstElementType) && |
| 3620 | srcWidth < dstWidth) |
| 3621 | return rewriter.create<arith::ExtSIOp>(location: loc, args: dstType, args&: val); |
| 3622 | |
| 3623 | assert(false && "unhandled promotion case" ); |
| 3624 | return nullptr; |
| 3625 | } |
| 3626 | |
| 3627 | // Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f} |
| 3628 | Value conv1dSliceAsContraction(RewriterBase &rewriter, Location loc, |
| 3629 | Value lhs, Value rhs, Value res) { |
| 3630 | vector::IteratorType par = vector::IteratorType::parallel; |
| 3631 | vector::IteratorType red = vector::IteratorType::reduction; |
| 3632 | AffineExpr n, w, f, c; |
| 3633 | bindDims(ctx, exprs&: n, exprs&: w, exprs&: f, exprs&: c); |
| 3634 | lhs = promote(rewriter, loc, val: lhs, ty: res.getType()); |
| 3635 | rhs = promote(rewriter, loc, val: rhs, ty: res.getType()); |
| 3636 | auto contrationOp = rewriter.create<vector::ContractionOp>( |
| 3637 | location: loc, args&: lhs, args&: rhs, args&: res, |
| 3638 | /*indexingMaps=*/args: MapList{{n, w, c}, {c, f}, {n, w, f}}, |
| 3639 | /*iteratorTypes=*/args: ArrayRef<vector::IteratorType>{par, par, par, red}); |
| 3640 | contrationOp.setKind(reductionKind); |
| 3641 | return contrationOp; |
| 3642 | } |
| 3643 | |
| 3644 | // Create an outerproduct: lhs{w} * rhs{1} -> res{w} for single channel |
| 3645 | // convolution. |
| 3646 | Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, |
| 3647 | Value lhs, Value rhs, Value res) { |
| 3648 | return rewriter.create<vector::OuterProductOp>( |
| 3649 | location: loc, args: res.getType(), args&: lhs, args&: rhs, args&: res, args: vector::CombiningKind::ADD); |
| 3650 | } |
| 3651 | |
| 3652 | // Create a reduction: lhs{n, w, c} -> res{n, w, c} |
| 3653 | Value pool1dSlice(RewriterBase &rewriter, Location loc, Value lhs, |
| 3654 | Value res) { |
| 3655 | if (isPoolExt) |
| 3656 | lhs = rewriter.create(loc, opName: poolExtOp, operands: lhs, types: res.getType())->getResult(idx: 0); |
| 3657 | return rewriter |
| 3658 | .create(loc, opName: redOp, operands: ArrayRef<Value>{lhs, res}, types: res.getType()) |
| 3659 | ->getResult(idx: 0); |
| 3660 | } |
| 3661 | |
| 3662 | /// Generate a vector implementation for: |
| 3663 | /// ``` |
| 3664 | /// Op def: ( n, w, c, kw) |
| 3665 | /// Iters: ({Par(), Par(), Par(), Red()}) |
| 3666 | /// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
| 3667 | /// ``` |
| 3668 | /// kw is always unrolled. |
| 3669 | /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is |
| 3670 | /// > 1. |
| 3671 | FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize, |
| 3672 | bool channelDimScalableFlag, |
| 3673 | bool flatten) { |
| 3674 | bool scalableChDim = false; |
| 3675 | bool useMasking = false; |
| 3676 | int64_t nSize, wSize, cSize, kwSize; |
| 3677 | // kernel{kw, c} |
| 3678 | bindShapeDims(shapedType: rhsShapedType, vals&: kwSize, vals&: cSize); |
| 3679 | if (ShapedType::isDynamic(dValue: cSize)) { |
| 3680 | assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0" ); |
| 3681 | cSize = channelDimVecSize; |
| 3682 | // Scalable vectors are only used when both conditions are met: |
| 3683 | // 1. channel dim is dynamic |
| 3684 | // 2. channelDimScalableFlag is set |
| 3685 | scalableChDim = channelDimScalableFlag; |
| 3686 | useMasking = true; |
| 3687 | } |
| 3688 | |
| 3689 | assert(!(useMasking && flatten) && |
| 3690 | "Unsupported flattened conv with dynamic shapes" ); |
| 3691 | |
| 3692 | // out{n, w, c} |
| 3693 | bindShapeDims(shapedType: resShapedType, vals&: nSize, vals&: wSize); |
| 3694 | |
| 3695 | vector::TransferWriteOp write; |
| 3696 | Value zero = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 0); |
| 3697 | |
| 3698 | // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. |
| 3699 | // When strideW == 1, we can batch the contiguous loads and avoid |
| 3700 | // unrolling |
| 3701 | int64_t wSizeStep = strideW == 1 ? wSize : 1; |
| 3702 | |
| 3703 | Type lhsEltType = lhsShapedType.getElementType(); |
| 3704 | Type rhsEltType = rhsShapedType.getElementType(); |
| 3705 | Type resEltType = resShapedType.getElementType(); |
| 3706 | VectorType lhsType = VectorType::get( |
| 3707 | shape: {nSize, |
| 3708 | // iw = ow * sw + kw * dw - 1 |
| 3709 | // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14) |
| 3710 | ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1, |
| 3711 | cSize}, |
| 3712 | elementType: lhsEltType, /*scalableDims=*/{false, false, scalableChDim}); |
| 3713 | VectorType rhsType = |
| 3714 | VectorType::get(shape: {kwSize, cSize}, elementType: rhsEltType, |
| 3715 | /*scalableDims=*/{false, scalableChDim}); |
| 3716 | VectorType resType = |
| 3717 | VectorType::get(shape: {nSize, wSize, cSize}, elementType: resEltType, |
| 3718 | /*scalableDims=*/{false, false, scalableChDim}); |
| 3719 | |
| 3720 | // Masks the input xfer Op along the channel dim, iff the corresponding |
| 3721 | // scalable flag is set. |
| 3722 | auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape, |
| 3723 | ArrayRef<bool> scalableDims, |
| 3724 | Operation *opToMask) { |
| 3725 | if (!useMasking) |
| 3726 | return opToMask; |
| 3727 | auto maskType = |
| 3728 | VectorType::get(shape: maskShape, elementType: rewriter.getI1Type(), scalableDims); |
| 3729 | |
| 3730 | SmallVector<bool> inBounds(maskShape.size(), true); |
| 3731 | auto xferOp = cast<VectorTransferOpInterface>(Val: opToMask); |
| 3732 | xferOp->setAttr(name: xferOp.getInBoundsAttrName(), |
| 3733 | value: rewriter.getBoolArrayAttr(values: inBounds)); |
| 3734 | |
| 3735 | SmallVector<OpFoldResult> mixedDims = vector::getMixedSizesXfer( |
| 3736 | hasTensorSemantics: cast<LinalgOp>(Val: op).hasPureTensorSemantics(), xfer: opToMask, rewriter); |
| 3737 | |
| 3738 | Value maskOp = |
| 3739 | rewriter.create<vector::CreateMaskOp>(location: loc, args&: maskType, args&: mixedDims); |
| 3740 | |
| 3741 | return mlir::vector::maskOperation(builder&: rewriter, maskableOp: opToMask, mask: maskOp); |
| 3742 | }; |
| 3743 | |
| 3744 | // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, |
| 3745 | // 0]. |
| 3746 | Value lhs = rewriter.create<vector::TransferReadOp>( |
| 3747 | location: loc, args&: lhsType, args&: lhsShaped, args: ValueRange{zero, zero, zero}, |
| 3748 | /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: lhsEltType)); |
| 3749 | auto maybeMaskedLhs = maybeMaskXferOp( |
| 3750 | lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); |
| 3751 | |
| 3752 | // Read rhs slice of size {kw, c} @ [0, 0]. |
| 3753 | Value rhs = rewriter.create<vector::TransferReadOp>( |
| 3754 | location: loc, args&: rhsType, args&: rhsShaped, args: ValueRange{zero, zero}, |
| 3755 | /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: rhsEltType)); |
| 3756 | auto maybeMaskedRhs = maybeMaskXferOp( |
| 3757 | rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); |
| 3758 | |
| 3759 | // Read res slice of size {n, w, c} @ [0, 0, 0]. |
| 3760 | Value res = rewriter.create<vector::TransferReadOp>( |
| 3761 | location: loc, args&: resType, args&: resShaped, args: ValueRange{zero, zero, zero}, |
| 3762 | /*padding=*/args: arith::getZeroConstant(builder&: rewriter, loc, type: resEltType)); |
| 3763 | auto maybeMaskedRes = maybeMaskXferOp( |
| 3764 | resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); |
| 3765 | |
| 3766 | //===------------------------------------------------------------------===// |
| 3767 | // Begin vector-only rewrite part |
| 3768 | //===------------------------------------------------------------------===// |
| 3769 | // Unroll along kw and read slices of lhs and rhs. |
| 3770 | SmallVector<Value> lhsVals, rhsVals, resVals; |
| 3771 | SmallVector<int64_t> inOutSliceSizes = {nSize, wSizeStep, cSize}; |
| 3772 | SmallVector<int64_t> inOutStrides = {1, 1, 1}; |
| 3773 | |
| 3774 | // Extract lhs slice of size {n, wSizeStep, c} |
| 3775 | // @ [0, sw * w + dw * kw, 0]. |
| 3776 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3777 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3778 | lhsVals.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>( |
| 3779 | location: loc, args: maybeMaskedLhs->getResult(idx: 0), |
| 3780 | /*offsets=*/args: ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, |
| 3781 | args&: inOutSliceSizes, args&: inOutStrides)); |
| 3782 | } |
| 3783 | } |
| 3784 | // Extract rhs slice of size {c} @ [kw]. |
| 3785 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3786 | rhsVals.push_back(Elt: rewriter.create<vector::ExtractOp>( |
| 3787 | location: loc, args: maybeMaskedRhs->getResult(idx: 0), |
| 3788 | /*offsets=*/args: ArrayRef<int64_t>{kw})); |
| 3789 | } |
| 3790 | // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. |
| 3791 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3792 | resVals.push_back(Elt: rewriter.create<vector::ExtractStridedSliceOp>( |
| 3793 | location: loc, args: maybeMaskedRes->getResult(idx: 0), |
| 3794 | /*offsets=*/args: ArrayRef<int64_t>{0, w, 0}, args&: inOutSliceSizes, |
| 3795 | args&: inOutStrides)); |
| 3796 | } |
| 3797 | |
| 3798 | auto linearIndex = [&](int64_t kw, int64_t w) { |
| 3799 | return kw * (wSize / wSizeStep) + w; |
| 3800 | }; |
| 3801 | |
| 3802 | // Note - the scalable flags are ignored as flattening combined with |
| 3803 | // scalable vectorization is not supported. |
| 3804 | SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize}; |
| 3805 | auto lhsTypeAfterFlattening = |
| 3806 | VectorType::get(shape: inOutFlattenSliceSizes, elementType: lhsEltType); |
| 3807 | auto resTypeAfterFlattening = |
| 3808 | VectorType::get(shape: inOutFlattenSliceSizes, elementType: resEltType); |
| 3809 | |
| 3810 | // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} |
| 3811 | for (int64_t kw = 0; kw < kwSize; ++kw) { |
| 3812 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3813 | Value lhsVal = lhsVals[linearIndex(kw, w)]; |
| 3814 | Value resVal = resVals[w]; |
| 3815 | if (flatten) { |
| 3816 | // Flatten the input and output vectors (collapse the channel |
| 3817 | // dimension) |
| 3818 | lhsVal = rewriter.create<vector::ShapeCastOp>( |
| 3819 | location: loc, args&: lhsTypeAfterFlattening, args&: lhsVals[linearIndex(kw, w)]); |
| 3820 | resVal = rewriter.create<vector::ShapeCastOp>( |
| 3821 | location: loc, args&: resTypeAfterFlattening, args&: resVals[w]); |
| 3822 | } |
| 3823 | resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhs: lhsVal, |
| 3824 | rhs: rhsVals[kw], res: resVal, flatten); |
| 3825 | if (flatten) { |
| 3826 | // Un-flatten the output vector (restore the channel dimension) |
| 3827 | resVals[w] = rewriter.create<vector::ShapeCastOp>( |
| 3828 | location: loc, args: VectorType::get(shape: inOutSliceSizes, elementType: resEltType), args&: resVals[w]); |
| 3829 | } |
| 3830 | } |
| 3831 | } |
| 3832 | |
| 3833 | // Its possible we failed to create the Fma. |
| 3834 | if (!llvm::all_of(Range&: resVals, P: [](Value v) { return v; })) { |
| 3835 | // Manually revert (in reverse order) to avoid leaving a bad IR state. |
| 3836 | for (auto &collection : |
| 3837 | {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}}) |
| 3838 | for (Value v : collection) |
| 3839 | rewriter.eraseOp(op: v.getDefiningOp()); |
| 3840 | return rewriter.notifyMatchFailure(arg&: op, msg: "failed to create FMA" ); |
| 3841 | } |
| 3842 | |
| 3843 | // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. |
| 3844 | // This does not depend on kw. |
| 3845 | for (int64_t w = 0; w < wSize; w += wSizeStep) { |
| 3846 | maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>( |
| 3847 | location: loc, args&: resVals[w], args: maybeMaskedRes->getResult(idx: 0), |
| 3848 | /*offsets=*/args: ArrayRef<int64_t>{0, w, 0}, |
| 3849 | /*strides=*/args: ArrayRef<int64_t>{1, 1, 1}); |
| 3850 | } |
| 3851 | //===------------------------------------------------------------------===// |
| 3852 | // End vector-only rewrite part |
| 3853 | //===------------------------------------------------------------------===// |
| 3854 | |
| 3855 | // Write back res slice of size {n, w, c} @ [0, 0, 0]. |
| 3856 | Operation *resOut = rewriter.create<vector::TransferWriteOp>( |
| 3857 | location: loc, args: maybeMaskedRes->getResult(idx: 0), args&: resShaped, |
| 3858 | args: ValueRange{zero, zero, zero}); |
| 3859 | return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(), |
| 3860 | resOut); |
| 3861 | } |
| 3862 | |
| 3863 | /// Lower: |
| 3864 | /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false) |
| 3865 | /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true) |
| 3866 | /// to MulAcc. |
| 3867 | Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc, |
| 3868 | Value lhs, Value rhs, Value res, |
| 3869 | bool flatten) { |
| 3870 | auto rhsTy = cast<ShapedType>(Val: rhs.getType()); |
| 3871 | auto resTy = cast<ShapedType>(Val: res.getType()); |
| 3872 | |
| 3873 | // TODO(suderman): Change this to use a vector.ima intrinsic. |
| 3874 | lhs = promote(rewriter, loc, val: lhs, ty: resTy); |
| 3875 | |
| 3876 | if (flatten) { |
| 3877 | // NOTE: This following logic won't work for scalable vectors. For this |
| 3878 | // reason, "flattening" is not supported when shapes are dynamic (this |
| 3879 | // should be captured by one of the pre-conditions). |
| 3880 | |
| 3881 | // There are two options for handling the filter: |
| 3882 | // * shape_cast(broadcast(filter)) |
| 3883 | // * broadcast(shuffle(filter)) |
| 3884 | // Opt for the option without shape_cast to simplify the codegen. |
| 3885 | auto rhsSize = cast<VectorType>(Val: rhs.getType()).getShape()[0]; |
| 3886 | auto resSize = cast<VectorType>(Val: res.getType()).getShape()[1]; |
| 3887 | |
| 3888 | SmallVector<int64_t, 16> indices; |
| 3889 | for (int i = 0; i < resSize / rhsSize; ++i) { |
| 3890 | for (int j = 0; j < rhsSize; ++j) |
| 3891 | indices.push_back(Elt: j); |
| 3892 | } |
| 3893 | |
| 3894 | rhs = rewriter.create<vector::ShuffleOp>(location: loc, args&: rhs, args&: rhs, args&: indices); |
| 3895 | } |
| 3896 | // Broadcast the filter to match the output vector |
| 3897 | rhs = rewriter.create<vector::BroadcastOp>( |
| 3898 | location: loc, args: resTy.clone(elementType: rhsTy.getElementType()), args&: rhs); |
| 3899 | |
| 3900 | rhs = promote(rewriter, loc, val: rhs, ty: resTy); |
| 3901 | |
| 3902 | if (!lhs || !rhs) |
| 3903 | return nullptr; |
| 3904 | |
| 3905 | if (isa<FloatType>(Val: resTy.getElementType())) |
| 3906 | return rewriter.create<vector::FMAOp>(location: loc, args&: lhs, args&: rhs, args&: res); |
| 3907 | |
| 3908 | auto mul = rewriter.create<arith::MulIOp>(location: loc, args&: lhs, args&: rhs); |
| 3909 | return rewriter.create<arith::AddIOp>(location: loc, args&: mul, args&: res); |
| 3910 | } |
| 3911 | |
| 3912 | /// Entry point for non-channeled convolution: |
| 3913 | /// {{w + kw}, {kw}, {w}} |
| 3914 | FailureOr<Operation *> generateNonChanneledConv() { |
| 3915 | AffineExpr w, kw; |
| 3916 | bindDims(ctx, exprs&: w, exprs&: kw); |
| 3917 | if (!iters(its: {Par(), Red()})) |
| 3918 | return rewriter.notifyMatchFailure(arg&: op, |
| 3919 | msg: "failed to match conv::W 1-par 1-red" ); |
| 3920 | |
| 3921 | // No transposition needed. |
| 3922 | if (layout(l: {/*lhsIndex*/ {w + kw}, |
| 3923 | /*rhsIndex*/ {kw}, |
| 3924 | /*resIndex*/ {w}})) |
| 3925 | return conv(conv1DOpOrder: Conv1DOpOrder::W); |
| 3926 | |
| 3927 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a conv::W layout" ); |
| 3928 | } |
| 3929 | |
| 3930 | /// Entry point that transposes into the common form: |
| 3931 | /// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}} |
| 3932 | FailureOr<Operation *> generateNwcConv() { |
| 3933 | AffineExpr n, w, f, kw, c; |
| 3934 | bindDims(ctx, exprs&: n, exprs&: w, exprs&: f, exprs&: kw, exprs&: c); |
| 3935 | if (!iters(its: {Par(), Par(), Par(), Red(), Red()})) |
| 3936 | return rewriter.notifyMatchFailure( |
| 3937 | arg&: op, msg: "failed to match conv::Nwc 3-par 2-red" ); |
| 3938 | |
| 3939 | // No transposition needed. |
| 3940 | if (layout(l: {/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
| 3941 | /*rhsIndex*/ {kw, c, f}, |
| 3942 | /*resIndex*/ {n, w, f}})) |
| 3943 | return conv(conv1DOpOrder: Conv1DOpOrder::Nwc); |
| 3944 | |
| 3945 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a conv::Nwc layout" ); |
| 3946 | } |
| 3947 | |
| 3948 | /// Entry point that transposes into the common form: |
| 3949 | /// {{n, c, strideW * w + dilationW * kw}, {f, c, kw}, {n, f, w}} |
| 3950 | FailureOr<Operation *> generateNcwConv() { |
| 3951 | AffineExpr n, w, f, kw, c; |
| 3952 | bindDims(ctx, exprs&: n, exprs&: f, exprs&: w, exprs&: c, exprs&: kw); |
| 3953 | if (!iters(its: {Par(), Par(), Par(), Red(), Red()})) |
| 3954 | return rewriter.notifyMatchFailure( |
| 3955 | arg&: op, msg: "failed to match conv::Ncw 3-par 2-red" ); |
| 3956 | |
| 3957 | if (layout(l: {/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, |
| 3958 | /*rhsIndex*/ {f, c, kw}, |
| 3959 | /*resIndex*/ {n, f, w}})) |
| 3960 | return conv(conv1DOpOrder: Conv1DOpOrder::Ncw); |
| 3961 | |
| 3962 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a conv::Ncw layout" ); |
| 3963 | } |
| 3964 | |
| 3965 | /// Entry point that transposes into the common form: |
| 3966 | /// {{n, strideW * w + dilationW * kw, c}, {kw}, {n, w, c}} for pooling |
| 3967 | FailureOr<Operation *> generateNwcPooling() { |
| 3968 | AffineExpr n, w, c, kw; |
| 3969 | bindDims(ctx, exprs&: n, exprs&: w, exprs&: c, exprs&: kw); |
| 3970 | if (!iters(its: {Par(), Par(), Par(), Red()})) |
| 3971 | return rewriter.notifyMatchFailure(arg&: op, |
| 3972 | msg: "failed to match pooling 3-par 1-red" ); |
| 3973 | |
| 3974 | // No transposition needed. |
| 3975 | if (layout(l: {/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
| 3976 | /*rhsIndex*/ {kw}, |
| 3977 | /*resIndex*/ {n, w, c}})) |
| 3978 | return conv(conv1DOpOrder: Conv1DOpOrder::Nwc); |
| 3979 | |
| 3980 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a pooling::Nwc layout" ); |
| 3981 | } |
| 3982 | |
| 3983 | /// Entry point that transposes into the common form: |
| 3984 | /// {{n, c, strideW * w + dilationW * kw}, {kw}, {n, c, w}} for pooling |
| 3985 | FailureOr<Operation *> generateNcwPooling() { |
| 3986 | AffineExpr n, w, c, kw; |
| 3987 | bindDims(ctx, exprs&: n, exprs&: c, exprs&: w, exprs&: kw); |
| 3988 | if (!iters(its: {Par(), Par(), Par(), Red()})) |
| 3989 | return rewriter.notifyMatchFailure(arg&: op, |
| 3990 | msg: "failed to match pooling 3-par 1-red" ); |
| 3991 | |
| 3992 | if (layout(l: {/*lhsIndex*/ {n, c, strideW * w + dilationW * kw}, |
| 3993 | /*rhsIndex*/ {kw}, |
| 3994 | /*resIndex*/ {n, c, w}})) |
| 3995 | return conv(conv1DOpOrder: Conv1DOpOrder::Ncw); |
| 3996 | |
| 3997 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a pooling::Ncw layout" ); |
| 3998 | } |
| 3999 | |
| 4000 | /// Entry point that transposes into the common form: |
| 4001 | /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}} |
| 4002 | FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0, |
| 4003 | bool vecChDimScalableFlag = false, |
| 4004 | bool flatten = false) { |
| 4005 | AffineExpr n, w, c, kw; |
| 4006 | bindDims(ctx, exprs&: n, exprs&: w, exprs&: c, exprs&: kw); |
| 4007 | if (!iters(its: {Par(), Par(), Par(), Red()})) |
| 4008 | return rewriter.notifyMatchFailure( |
| 4009 | arg&: op, msg: "failed to match depthwise::Nwc conv 3-par 1-red" ); |
| 4010 | |
| 4011 | // No transposition needed. |
| 4012 | if (layout(l: {/*lhsIndex*/ {n, strideW * w + dilationW * kw, c}, |
| 4013 | /*rhsIndex*/ {kw, c}, |
| 4014 | /*resIndex*/ {n, w, c}})) |
| 4015 | return depthwiseConv(channelDimVecSize: vecChDimSize, channelDimScalableFlag: vecChDimScalableFlag, flatten); |
| 4016 | |
| 4017 | return rewriter.notifyMatchFailure(arg&: op, msg: "not a depthwise::Nwc layout" ); |
| 4018 | } |
| 4019 | |
| 4020 | private: |
| 4021 | ConvOperationKind oper = ConvOperationKind::Conv; |
| 4022 | StringAttr redOp; |
| 4023 | StringAttr poolExtOp; |
| 4024 | bool isPoolExt = false; |
| 4025 | int strideW, dilationW; |
| 4026 | Value lhsShaped, rhsShaped, resShaped; |
| 4027 | ShapedType lhsShapedType, rhsShapedType, resShapedType; |
| 4028 | vector::CombiningKind reductionKind; |
| 4029 | |
| 4030 | // Sets oper, poolExtOp and isPoolExt for valid conv/pooling ops. |
| 4031 | void setConvOperationKind(Operation *reduceOp) { |
| 4032 | int numBlockArguments = |
| 4033 | llvm::count_if(Range: reduceOp->getOperands(), P: llvm::IsaPred<BlockArgument>); |
| 4034 | if (numBlockArguments == 1) { |
| 4035 | // Will be convolution if feeder is a MulOp. |
| 4036 | // A strength reduced version of MulOp for i1 type is AndOp which is also |
| 4037 | // supported. Otherwise, it can be pooling. This strength reduction logic |
| 4038 | // is in `buildBinaryFn` helper in the Linalg dialect. |
| 4039 | auto feedValIt = llvm::find_if_not(Range: reduceOp->getOperands(), |
| 4040 | P: llvm::IsaPred<BlockArgument>); |
| 4041 | Operation *feedOp = (*feedValIt).getDefiningOp(); |
| 4042 | if (isCastOfBlockArgument(op: feedOp)) { |
| 4043 | oper = ConvOperationKind::Pool; |
| 4044 | isPoolExt = true; |
| 4045 | poolExtOp = feedOp->getName().getIdentifier(); |
| 4046 | return; |
| 4047 | } |
| 4048 | oper = ConvOperationKind::Conv; |
| 4049 | return; |
| 4050 | } |
| 4051 | // numBlockArugments == 2 and this is a pooling op. |
| 4052 | oper = ConvOperationKind::Pool; |
| 4053 | isPoolExt = false; |
| 4054 | } |
| 4055 | }; |
| 4056 | } // namespace |
| 4057 | |
| 4058 | /// Helper function to vectorize a LinalgOp with convolution semantics. |
| 4059 | // TODO: extend the generic vectorization to support windows and drop this. |
| 4060 | static FailureOr<Operation *> vectorizeConvolution( |
| 4061 | RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes, |
| 4062 | ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) { |
| 4063 | Conv1DGenerator conv1dGen(rewriter, op); |
| 4064 | auto res = conv1dGen.generateNonChanneledConv(); |
| 4065 | if (succeeded(Result: res)) |
| 4066 | return res; |
| 4067 | res = conv1dGen.generateNwcConv(); |
| 4068 | if (succeeded(Result: res)) |
| 4069 | return res; |
| 4070 | res = conv1dGen.generateNcwConv(); |
| 4071 | if (succeeded(Result: res)) |
| 4072 | return res; |
| 4073 | res = conv1dGen.generateNwcPooling(); |
| 4074 | if (succeeded(Result: res)) |
| 4075 | return res; |
| 4076 | res = conv1dGen.generateNcwPooling(); |
| 4077 | if (succeeded(Result: res)) |
| 4078 | return res; |
| 4079 | |
| 4080 | // Only depthwise 1D NWC convs are left - these can be vectorized using masks |
| 4081 | // and scalable vectors. Note that ATM the only dim that can be dynamic (i.e. |
| 4082 | // masked/scalable) is the channel dim (i.e. the trailing dim). |
| 4083 | uint64_t vecChDimSize = ShapedType::kDynamic; |
| 4084 | bool vecChDimScalableFlag = false; |
| 4085 | if (!inputVecSizes.empty()) { |
| 4086 | // Only use the input vector size corresponding to the channel dim. Other |
| 4087 | // vector dims will be inferred from the Ops. |
| 4088 | assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) || |
| 4089 | isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) && |
| 4090 | "Not a 1D depthwise conv!" ); |
| 4091 | size_t chDimIdx = |
| 4092 | TypeSwitch<Operation *, size_t>(op) |
| 4093 | .Case<linalg::DepthwiseConv1DNwcWcOp>(caseFn: [](auto conv) { return 2; }) |
| 4094 | .Case<linalg::DepthwiseConv1DNcwCwOp>(caseFn: [](auto conv) { return 1; }); |
| 4095 | |
| 4096 | vecChDimSize = inputVecSizes[chDimIdx]; |
| 4097 | vecChDimScalableFlag = inputScalableVecDims[chDimIdx]; |
| 4098 | } |
| 4099 | return conv1dGen.generateDilatedConv(vecChDimSize, vecChDimScalableFlag, |
| 4100 | flatten: flatten1DDepthwiseConv); |
| 4101 | } |
| 4102 | |
| 4103 | struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> { |
| 4104 | using OpInterfaceRewritePattern::OpInterfaceRewritePattern; |
| 4105 | |
| 4106 | LogicalResult matchAndRewrite(LinalgOp op, |
| 4107 | PatternRewriter &rewriter) const override { |
| 4108 | FailureOr<Operation *> resultOrFail = vectorizeConvolution(rewriter, op); |
| 4109 | if (failed(Result: resultOrFail)) |
| 4110 | return failure(); |
| 4111 | Operation *newOp = *resultOrFail; |
| 4112 | if (newOp->getNumResults() == 0) { |
| 4113 | rewriter.eraseOp(op: op.getOperation()); |
| 4114 | return success(); |
| 4115 | } |
| 4116 | assert(newOp->getNumResults() == 1 && "expected single result" ); |
| 4117 | rewriter.replaceOp(op: op.getOperation(), newValues: newOp->getResult(idx: 0)); |
| 4118 | return success(); |
| 4119 | } |
| 4120 | }; |
| 4121 | |
| 4122 | void mlir::linalg::populateConvolutionVectorizationPatterns( |
| 4123 | RewritePatternSet &patterns, PatternBenefit benefit) { |
| 4124 | patterns.add<VectorizeConvolution>(arg: patterns.getContext(), args&: benefit); |
| 4125 | } |
| 4126 | |