1 | //===- VectorToSCF.h - Convert vector to SCF dialect ------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ |
10 | #define MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ |
11 | |
#include "mlir/IR/PatternMatch.h"

#include <memory>
13 | |
14 | namespace mlir { |
15 | class MLIRContext; |
16 | class Pass; |
17 | class RewritePatternSet; |
18 | |
19 | #define GEN_PASS_DECL_CONVERTVECTORTOSCF |
20 | #include "mlir/Conversion/Passes.h.inc" |
21 | |
/// When lowering an N-d vector transfer op to an (N-1)-d vector transfer op,
/// a temporary buffer is created through which the individual (N-1)-d vectors
/// are staged. This pattern can be applied multiple times, until the transfer
/// op reaches the target rank (1-d by default).
26 | /// This is consistent with the lack of an LLVM instruction to dynamically |
27 | /// index into an aggregate (see the Vector dialect lowering to LLVM deep dive). |
28 | /// |
/// An instruction such as:
/// ```
///    vector.transfer_write %vec, %A[%a, %b, %c] :
///      vector<9x17x15xf32>, memref<?x?x?xf32>
/// ```
/// lowers to pseudo-IR resembling (unpacking one dimension):
/// ```
///    %0 = alloca() : memref<vector<9x17x15xf32>>
///    store %vec, %0[] : memref<vector<9x17x15xf32>>
///    %1 = vector.type_cast %0 :
///      memref<vector<9x17x15xf32>> to memref<9xvector<17x15xf32>>
///    affine.for %I = 0 to 9 {
///      %dim = dim %A, 0 : memref<?x?x?xf32>
///      %add = affine.apply %I + %a
///      %cmp = arith.cmpi "slt", %add, %dim : index
///      scf.if %cmp {
///        %vec_2d = load %1[%I] : memref<9xvector<17x15xf32>>
///        vector.transfer_write %vec_2d, %A[%add, %b, %c] :
///          vector<17x15xf32>, memref<?x?x?xf32>
///      }
///    }
/// ```
49 | /// |
50 | /// When applying the pattern a second time, the existing alloca() operation |
51 | /// is reused and only a second vector.type_cast is added. |
struct VectorTransferToSCFOptions {
  // Plain configuration fields. Their declaration order is kept stable so
  // positional aggregate initialization keeps working. Each setter below
  // returns `*this` so calls can be chained, e.g.
  // `VectorTransferToSCFOptions().setTargetRank(2).enableFullUnroll()`.

  /// Lowest rank that vector transfers are lowered down to.
  unsigned targetRank = 1;
  /// Whether vector transfers that operate on tensors may be lowered as well
  /// (an uncommon alternative).
  bool lowerTensors = false;
  /// Whether the transfer-to-SCF lowering fully unrolls instead of iterating
  /// with a loop.
  bool unroll = false;

  /// Sets `targetRank` to `rank`.
  VectorTransferToSCFOptions &setTargetRank(unsigned rank) {
    this->targetRank = rank;
    return *this;
  }

  /// Sets `lowerTensors` to `enable`; called without an argument, it turns
  /// tensor lowering on.
  VectorTransferToSCFOptions &enableLowerTensors(bool enable = true) {
    this->lowerTensors = enable;
    return *this;
  }

  /// Sets `unroll` to `enable`; called without an argument, it turns full
  /// unrolling on.
  VectorTransferToSCFOptions &enableFullUnroll(bool enable = true) {
    this->unroll = enable;
    return *this;
  }
};
73 | |
74 | /// Collect a set of patterns to convert from the Vector dialect to SCF + func. |
75 | void populateVectorToSCFConversionPatterns( |
76 | RewritePatternSet &patterns, |
77 | const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions()); |
78 | |
79 | /// Create a pass to convert a subset of vector ops to SCF. |
80 | std::unique_ptr<Pass> createConvertVectorToSCFPass( |
81 | const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions()); |
82 | |
83 | } // namespace mlir |
84 | |
85 | #endif // MLIR_CONVERSION_VECTORTOSCF_VECTORTOSCF_H_ |
86 | |