1 | //= TestAffineLoopParametricTiling.cpp -- Parametric Affine loop tiling pass =// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a test pass that exercises parametric tiling of
// perfectly nested `affine.for` loops.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Affine/LoopUtils.h" |
16 | #include "mlir/Dialect/Affine/Passes.h" |
17 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::affine; |
21 | |
22 | #define DEBUG_TYPE "test-affine-parametric-tile" |
23 | |
24 | namespace { |
25 | struct TestAffineLoopParametricTiling |
26 | : public PassWrapper<TestAffineLoopParametricTiling, |
27 | OperationPass<func::FuncOp>> { |
28 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAffineLoopParametricTiling) |
29 | |
30 | StringRef getArgument() const final { return "test-affine-parametric-tile" ; } |
31 | StringRef getDescription() const final { |
32 | return "Tile affine loops using SSA values as tile sizes" ; |
33 | } |
34 | void runOnOperation() override; |
35 | }; |
36 | } // namespace |
37 | |
38 | /// Checks if the function enclosing the loop nest has any arguments passed to |
39 | /// it, which can be used as tiling parameters. Assumes that atleast 'n' |
40 | /// arguments are passed, where 'n' is the number of loops in the loop nest. |
41 | static LogicalResult checkIfTilingParametersExist(ArrayRef<AffineForOp> band) { |
42 | assert(!band.empty() && "no loops in input band" ); |
43 | AffineForOp topLoop = band[0]; |
44 | |
45 | if (func::FuncOp funcOp = dyn_cast<func::FuncOp>(topLoop->getParentOp())) |
46 | if (funcOp.getNumArguments() < band.size()) |
47 | return topLoop->emitError( |
48 | "too few tile sizes provided in the argument list of the function " |
49 | "which contains the current band" ); |
50 | return success(); |
51 | } |
52 | |
53 | /// Captures tiling parameters, which are expected to be passed as arguments |
54 | /// to the function enclosing the loop nest. Also checks if the required |
55 | /// parameters are of index type. This approach is temporary for testing |
56 | /// purposes. |
57 | static LogicalResult |
58 | getTilingParameters(ArrayRef<AffineForOp> band, |
59 | SmallVectorImpl<Value> &tilingParameters) { |
60 | AffineForOp topLoop = band[0]; |
61 | Region *funcOpRegion = topLoop->getParentRegion(); |
62 | unsigned nestDepth = band.size(); |
63 | |
64 | for (BlockArgument blockArgument : |
65 | funcOpRegion->getArguments().take_front(nestDepth)) { |
66 | if (blockArgument.getArgNumber() < nestDepth) { |
67 | if (!blockArgument.getType().isIndex()) |
68 | return topLoop->emitError( |
69 | "expected tiling parameters to be of index type" ); |
70 | tilingParameters.push_back(blockArgument); |
71 | } |
72 | } |
73 | return success(); |
74 | } |
75 | |
76 | void TestAffineLoopParametricTiling::runOnOperation() { |
77 | // Bands of loops to tile. |
78 | std::vector<SmallVector<AffineForOp, 6>> bands; |
79 | getTileableBands(getOperation(), &bands); |
80 | |
81 | // Tile each band. |
82 | for (MutableArrayRef<AffineForOp> band : bands) { |
83 | // Capture the tiling parameters from the arguments to the function |
84 | // enclosing this loop nest. |
85 | SmallVector<AffineForOp, 6> tiledNest; |
86 | SmallVector<Value, 6> tilingParameters; |
87 | // Check if tiling parameters are present. |
88 | if (checkIfTilingParametersExist(band).failed()) |
89 | return; |
90 | |
91 | // Get function arguments as tiling parameters. |
92 | if (getTilingParameters(band, tilingParameters).failed()) |
93 | return; |
94 | |
95 | (void)tilePerfectlyNestedParametric(band, tilingParameters, &tiledNest); |
96 | } |
97 | } |
98 | |
99 | namespace mlir { |
100 | namespace test { |
101 | void registerTestAffineLoopParametricTilingPass() { |
102 | PassRegistration<TestAffineLoopParametricTiling>(); |
103 | } |
104 | } // namespace test |
105 | } // namespace mlir |
106 | |