1 | //===- TestMultiBuffering.cpp - Test multi buffering ----------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
10 | #include "mlir/Dialect/MemRef/Transforms/Passes.h" |
11 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
12 | |
13 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | |
16 | using namespace mlir; |
17 | |
18 | namespace { |
19 | struct TestMultiBufferingPass |
20 | : public PassWrapper<TestMultiBufferingPass, OperationPass<>> { |
21 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiBufferingPass) |
22 | |
23 | TestMultiBufferingPass() = default; |
24 | TestMultiBufferingPass(const TestMultiBufferingPass &pass) |
25 | : PassWrapper(pass) {} |
26 | void getDependentDialects(DialectRegistry ®istry) const override { |
27 | registry.insert<affine::AffineDialect>(); |
28 | } |
29 | StringRef getArgument() const final { return "test-multi-buffering"; } |
30 | StringRef getDescription() const final { |
31 | return "Test multi buffering transformation"; |
32 | } |
33 | void runOnOperation() override; |
34 | Option<unsigned> multiplier{ |
35 | *this, "multiplier", |
36 | llvm::cl::desc( |
37 | "Decide how many versions of the buffer should be created,"), |
38 | llvm::cl::init(Val: 2)}; |
39 | }; |
40 | |
41 | void TestMultiBufferingPass::runOnOperation() { |
42 | SmallVector<memref::AllocOp> allocs; |
43 | getOperation()->walk( |
44 | [&allocs](memref::AllocOp alloc) { allocs.push_back(alloc); }); |
45 | for (memref::AllocOp alloc : allocs) |
46 | (void)multiBuffer(alloc, multiplier); |
47 | } |
48 | } // namespace |
49 | |
50 | namespace mlir { |
51 | namespace test { |
52 | void registerTestMultiBuffering() { |
53 | PassRegistration<TestMultiBufferingPass>(); |
54 | } |
55 | } // namespace test |
56 | } // namespace mlir |
57 |