//===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MemRef/IR/MemRef.h"
10#include "mlir/IR/BuiltinTypes.h"
11#include "mlir/Pass/Pass.h"
12
13using namespace mlir;
14
15namespace {
16struct TestMemRefStrideCalculation
17 : public PassWrapper<TestMemRefStrideCalculation,
18 InterfacePass<SymbolOpInterface>> {
19 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefStrideCalculation)
20
21 StringRef getArgument() const final {
22 return "test-memref-stride-calculation";
23 }
24 StringRef getDescription() const final {
25 return "Test operation constant folding";
26 }
27 void runOnOperation() override;
28};
29} // namespace
30
31/// Traverse AllocOp and compute strides of each MemRefType independently.
32void TestMemRefStrideCalculation::runOnOperation() {
33 llvm::outs() << "Testing: " << getOperation().getName() << "\n";
34 getOperation().walk([&](memref::AllocOp allocOp) {
35 auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
36 int64_t offset;
37 SmallVector<int64_t, 4> strides;
38 if (failed(getStridesAndOffset(memrefType, strides, offset))) {
39 llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
40 << "strided form\n";
41 return;
42 }
43 llvm::outs() << "MemRefType offset: ";
44 if (ShapedType::isDynamic(offset))
45 llvm::outs() << "?";
46 else
47 llvm::outs() << offset;
48 llvm::outs() << " strides: ";
49 llvm::interleaveComma(c: strides, os&: llvm::outs(), each_fn: [&](int64_t v) {
50 if (ShapedType::isDynamic(v))
51 llvm::outs() << "?";
52 else
53 llvm::outs() << v;
54 });
55 llvm::outs() << "\n";
56 });
57 llvm::outs().flush();
58}
59
60namespace mlir {
61namespace test {
62void registerTestMemRefStrideCalculation() {
63 PassRegistration<TestMemRefStrideCalculation>();
64}
65} // namespace test
66} // namespace mlir
67

source code of mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp