1 | //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
10 | #include "mlir/IR/BuiltinTypes.h" |
11 | #include "mlir/Pass/Pass.h" |
12 | |
13 | using namespace mlir; |
14 | |
15 | namespace { |
16 | struct TestMemRefStrideCalculation |
17 | : public PassWrapper<TestMemRefStrideCalculation, |
18 | InterfacePass<SymbolOpInterface>> { |
19 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefStrideCalculation) |
20 | |
21 | StringRef getArgument() const final { |
22 | return "test-memref-stride-calculation" ; |
23 | } |
24 | StringRef getDescription() const final { |
25 | return "Test operation constant folding" ; |
26 | } |
27 | void runOnOperation() override; |
28 | }; |
29 | } // namespace |
30 | |
31 | /// Traverse AllocOp and compute strides of each MemRefType independently. |
32 | void TestMemRefStrideCalculation::runOnOperation() { |
33 | llvm::outs() << "Testing: " << getOperation().getName() << "\n" ; |
34 | getOperation().walk([&](memref::AllocOp allocOp) { |
35 | auto memrefType = cast<MemRefType>(allocOp.getResult().getType()); |
36 | int64_t offset; |
37 | SmallVector<int64_t, 4> strides; |
38 | if (failed(getStridesAndOffset(memrefType, strides, offset))) { |
39 | llvm::outs() << "MemRefType " << memrefType << " cannot be converted to " |
40 | << "strided form\n" ; |
41 | return; |
42 | } |
43 | llvm::outs() << "MemRefType offset: " ; |
44 | if (ShapedType::isDynamic(offset)) |
45 | llvm::outs() << "?" ; |
46 | else |
47 | llvm::outs() << offset; |
48 | llvm::outs() << " strides: " ; |
49 | llvm::interleaveComma(c: strides, os&: llvm::outs(), each_fn: [&](int64_t v) { |
50 | if (ShapedType::isDynamic(v)) |
51 | llvm::outs() << "?" ; |
52 | else |
53 | llvm::outs() << v; |
54 | }); |
55 | llvm::outs() << "\n" ; |
56 | }); |
57 | llvm::outs().flush(); |
58 | } |
59 | |
60 | namespace mlir { |
61 | namespace test { |
62 | void registerTestMemRefStrideCalculation() { |
63 | PassRegistration<TestMemRefStrideCalculation>(); |
64 | } |
65 | } // namespace test |
66 | } // namespace mlir |
67 | |