//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.interleave' operation.
//
//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
16#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/PatternMatch.h"
19
20#define DEBUG_TYPE "vector-interleave-lowering"
21
22using namespace mlir;
23using namespace mlir::vector;
24
25namespace {
26
27/// A one-shot unrolling of vector.interleave to the `targetRank`.
28///
29/// Example:
30///
31/// ```mlir
32/// vector.interleave %a, %b : vector<1x2x3x4xi64>
33/// ```
34/// Would be unrolled to:
35/// ```mlir
36/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
37/// %0 = vector.extract %a[0, 0, 0] ─┐
38/// : vector<4xi64> from vector<1x2x3x4xi64> |
39/// %1 = vector.extract %b[0, 0, 0] |
40/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
41/// %2 = vector.interleave %0, %1 : vector<4xi64> | all leading positions
42/// %3 = vector.insert %2, %result [0, 0, 0] |
43/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
44/// ```
45///
46/// Note: If any leading dimension before the `targetRank` is scalable the
47/// unrolling will stop before the scalable dimension.
48class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
49public:
50 UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
51 PatternBenefit benefit = 1)
52 : OpRewritePattern(context, benefit), targetRank(targetRank){};
53
54 LogicalResult matchAndRewrite(vector::InterleaveOp op,
55 PatternRewriter &rewriter) const override {
56 VectorType resultType = op.getResultVectorType();
57 auto unrollIterator = vector::createUnrollIterator(vType: resultType, targetRank);
58 if (!unrollIterator)
59 return failure();
60
61 auto loc = op.getLoc();
62 Value result = rewriter.create<arith::ConstantOp>(
63 loc, resultType, rewriter.getZeroAttr(resultType));
64 for (auto position : *unrollIterator) {
65 Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), position);
66 Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), position);
67 Value interleave =
68 rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
69 result = rewriter.create<InsertOp>(loc, interleave, result, position);
70 }
71
72 rewriter.replaceOp(op, result);
73 return success();
74 }
75
76private:
77 int64_t targetRank = 1;
78};
79
80} // namespace
81
82void mlir::vector::populateVectorInterleaveLoweringPatterns(
83 RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
84 patterns.add<UnrollInterleaveOp>(arg&: targetRank, args: patterns.getContext(), args&: benefit);
85}
86

source code of mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp