1//===- TestAliasAnalysis.cpp - Test alias analysis results ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains test passes for constructing and testing alias analysis
10// results.
11//
12//===----------------------------------------------------------------------===//
13
14#include "TestAliasAnalysis.h"
15#include "mlir/Analysis/AliasAnalysis.h"
16#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
17#include "mlir/Interfaces/FunctionInterfaces.h"
18#include "mlir/Pass/Pass.h"
19
20using namespace mlir;
21
22/// Print a value that is used as an operand of an alias query.
23static void printAliasOperand(Operation *op) {
24 llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr").getValue();
25}
26static void printAliasOperand(Value value) {
27 if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value)) {
28 Region *region = arg.getParentRegion();
29 unsigned parentBlockNumber =
30 std::distance(first: region->begin(), last: arg.getOwner()->getIterator());
31 llvm::errs() << region->getParentOp()
32 ->getAttrOfType<StringAttr>("test.ptr")
33 .getValue()
34 << ".region" << region->getRegionNumber();
35 if (parentBlockNumber != 0)
36 llvm::errs() << ".block" << parentBlockNumber;
37 llvm::errs() << "#" << arg.getArgNumber();
38 return;
39 }
40 OpResult result = cast<OpResult>(Val&: value);
41 printAliasOperand(op: result.getOwner());
42 llvm::errs() << "#" << result.getResultNumber();
43}
44
45namespace mlir {
46namespace test {
47void printAliasResult(AliasResult result, Value lhs, Value rhs) {
48 printAliasOperand(value: lhs);
49 llvm::errs() << " <-> ";
50 printAliasOperand(value: rhs);
51 llvm::errs() << ": " << result << "\n";
52}
53
54/// Print the result of an alias query.
55void printModRefResult(ModRefResult result, Operation *op, Value location) {
56 printAliasOperand(op);
57 llvm::errs() << " -> ";
58 printAliasOperand(value: location);
59 llvm::errs() << ": " << result << "\n";
60}
61
62void TestAliasAnalysisBase::runAliasAnalysisOnOperation(
63 Operation *op, AliasAnalysis &aliasAnalysis) {
64 llvm::errs() << "Testing : " << *op->getInherentAttr(name: "sym_name") << "\n";
65
66 // Collect all of the values to check for aliasing behavior.
67 SmallVector<Value, 32> valsToCheck;
68 op->walk(callback: [&](Operation *op) {
69 if (!op->getDiscardableAttr(name: "test.ptr"))
70 return;
71 valsToCheck.append(in_start: op->result_begin(), in_end: op->result_end());
72 for (Region &region : op->getRegions())
73 for (Block &block : region)
74 valsToCheck.append(in_start: block.args_begin(), in_end: block.args_end());
75 });
76
77 // Check for aliasing behavior between each of the values.
78 for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it)
79 for (auto *innerIt = valsToCheck.begin(); innerIt != it; ++innerIt)
80 printAliasResult(result: aliasAnalysis.alias(lhs: *innerIt, rhs: *it), lhs: *innerIt, rhs: *it);
81}
82
83void TestAliasAnalysisModRefBase::runAliasAnalysisOnOperation(
84 Operation *op, AliasAnalysis &aliasAnalysis) {
85 llvm::errs() << "Testing : " << *op->getInherentAttr(name: "sym_name") << "\n";
86
87 // Collect all of the values to check for aliasing behavior.
88 SmallVector<Value, 32> valsToCheck;
89 op->walk(callback: [&](Operation *op) {
90 if (!op->getDiscardableAttr(name: "test.ptr"))
91 return;
92 valsToCheck.append(in_start: op->result_begin(), in_end: op->result_end());
93 for (Region &region : op->getRegions())
94 for (Block &block : region)
95 valsToCheck.append(in_start: block.args_begin(), in_end: block.args_end());
96 });
97
98 // Check for aliasing behavior between each of the values.
99 for (auto &it : valsToCheck) {
100 op->walk(callback: [&](Operation *op) {
101 if (!op->getDiscardableAttr(name: "test.ptr"))
102 return;
103 printModRefResult(result: aliasAnalysis.getModRef(op, location: it), op, location: it);
104 });
105 }
106}
107
108} // namespace test
109} // namespace mlir
110
111//===----------------------------------------------------------------------===//
112// Testing AliasResult
113//===----------------------------------------------------------------------===//
114
115namespace {
116struct TestAliasAnalysisPass
117 : public test::TestAliasAnalysisBase,
118 PassWrapper<TestAliasAnalysisPass, OperationPass<>> {
119 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisPass)
120
121 StringRef getArgument() const final { return "test-alias-analysis"; }
122 StringRef getDescription() const final {
123 return "Test alias analysis results.";
124 }
125 void runOnOperation() override {
126 AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
127 runAliasAnalysisOnOperation(op: getOperation(), aliasAnalysis);
128 }
129};
130} // namespace
131
132//===----------------------------------------------------------------------===//
133// Testing ModRefResult
134//===----------------------------------------------------------------------===//
135
136namespace {
137struct TestAliasAnalysisModRefPass
138 : public test::TestAliasAnalysisModRefBase,
139 PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> {
140 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisModRefPass)
141
142 StringRef getArgument() const final { return "test-alias-analysis-modref"; }
143 StringRef getDescription() const final {
144 return "Test alias analysis ModRef results.";
145 }
146 void runOnOperation() override {
147 AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
148 runAliasAnalysisOnOperation(op: getOperation(), aliasAnalysis);
149 }
150};
151} // namespace
152
153//===----------------------------------------------------------------------===//
154// Testing LocalAliasAnalysis extending
155//===----------------------------------------------------------------------===//
156
157/// Check if value is function argument.
158static bool isFuncArg(Value val) {
159 auto blockArg = dyn_cast<BlockArgument>(Val&: val);
160 if (!blockArg)
161 return false;
162
163 return mlir::isa_and_nonnull<FunctionOpInterface>(
164 Val: blockArg.getOwner()->getParentOp());
165}
166
167/// Check if value has "restrict" attribute. Value must be a function argument.
168static bool isRestrict(Value val) {
169 auto blockArg = cast<BlockArgument>(Val&: val);
170 auto func =
171 mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp());
172 return !!func.getArgAttr(blockArg.getArgNumber(),
173 "local_alias_analysis.restrict");
174}
175
176namespace {
177/// LocalAliasAnalysis extended to support "restrict" attreibute.
178class LocalAliasAnalysisRestrict : public LocalAliasAnalysis {
179protected:
180 AliasResult aliasImpl(Value lhs, Value rhs) override {
181 if (lhs == rhs)
182 return AliasResult::MustAlias;
183
184 // Assume no aliasing if both values are function arguments and any of them
185 // have restrict attr.
186 if (isFuncArg(val: lhs) && isFuncArg(val: rhs))
187 if (isRestrict(val: lhs) || isRestrict(val: rhs))
188 return AliasResult::NoAlias;
189
190 return LocalAliasAnalysis::aliasImpl(lhs, rhs);
191 }
192};
193
194/// This pass tests adding additional analysis impls to the AliasAnalysis.
195struct TestAliasAnalysisExtendingPass
196 : public test::TestAliasAnalysisBase,
197 PassWrapper<TestAliasAnalysisExtendingPass, OperationPass<>> {
198 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisExtendingPass)
199
200 StringRef getArgument() const final {
201 return "test-alias-analysis-extending";
202 }
203 StringRef getDescription() const final {
204 return "Test alias analysis extending.";
205 }
206 void runOnOperation() override {
207 AliasAnalysis aliasAnalysis(getOperation());
208 aliasAnalysis.addAnalysisImplementation(analysis: LocalAliasAnalysisRestrict());
209 runAliasAnalysisOnOperation(op: getOperation(), aliasAnalysis);
210 }
211};
212} // namespace
213
214//===----------------------------------------------------------------------===//
215// Pass Registration
216//===----------------------------------------------------------------------===//
217
218namespace mlir {
219namespace test {
220void registerTestAliasAnalysisPass() {
221 PassRegistration<TestAliasAnalysisExtendingPass>();
222 PassRegistration<TestAliasAnalysisModRefPass>();
223 PassRegistration<TestAliasAnalysisPass>();
224}
225} // namespace test
226} // namespace mlir
227

source code of mlir/test/lib/Analysis/TestAliasAnalysis.cpp