1 | //===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements target-independent rewrites and utilities to lower the |
10 | // 'vector.interleave' operation. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
15 | #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
16 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | |
20 | #define DEBUG_TYPE "vector-interleave-lowering" |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::vector; |
24 | |
25 | namespace { |
26 | |
27 | /// A one-shot unrolling of vector.interleave to the `targetRank`. |
28 | /// |
29 | /// Example: |
30 | /// |
31 | /// ```mlir |
32 | /// vector.interleave %a, %b : vector<1x2x3x4xi64> |
33 | /// ``` |
34 | /// Would be unrolled to: |
35 | /// ```mlir |
36 | /// %result = arith.constant dense<0> : vector<1x2x3x8xi64> |
37 | /// %0 = vector.extract %a[0, 0, 0] ─┐ |
38 | /// : vector<4xi64> from vector<1x2x3x4xi64> | |
39 | /// %1 = vector.extract %b[0, 0, 0] | |
40 | /// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for |
41 | /// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions |
42 | /// %3 = vector.insert %2, %result [0, 0, 0] | |
43 | /// : vector<8xi64> into vector<1x2x3x8xi64> ┘ |
44 | /// ``` |
45 | /// |
46 | /// Note: If any leading dimension before the `targetRank` is scalable the |
47 | /// unrolling will stop before the scalable dimension. |
48 | class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> { |
49 | public: |
50 | UnrollInterleaveOp(int64_t targetRank, MLIRContext *context, |
51 | PatternBenefit benefit = 1) |
52 | : OpRewritePattern(context, benefit), targetRank(targetRank){}; |
53 | |
54 | LogicalResult matchAndRewrite(vector::InterleaveOp op, |
55 | PatternRewriter &rewriter) const override { |
56 | VectorType resultType = op.getResultVectorType(); |
57 | auto unrollIterator = vector::createUnrollIterator(vType: resultType, targetRank); |
58 | if (!unrollIterator) |
59 | return failure(); |
60 | |
61 | auto loc = op.getLoc(); |
62 | Value result = rewriter.create<arith::ConstantOp>( |
63 | loc, resultType, rewriter.getZeroAttr(resultType)); |
64 | for (auto position : *unrollIterator) { |
65 | Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position); |
66 | Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position); |
67 | Value interleave = |
68 | rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs); |
69 | result = rewriter.create<InsertOp>(loc, interleave, result, position); |
70 | } |
71 | |
72 | rewriter.replaceOp(op, result); |
73 | return success(); |
74 | } |
75 | |
76 | private: |
77 | int64_t targetRank = 1; |
78 | }; |
79 | |
80 | } // namespace |
81 | |
82 | void mlir::vector::populateVectorInterleaveLoweringPatterns( |
83 | RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) { |
84 | patterns.add<UnrollInterleaveOp>(arg&: targetRank, args: patterns.getContext(), args&: benefit); |
85 | } |
86 | |