1 | //===- TestAliasAnalysis.cpp - Test alias analysis results ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains test passes for constructing and testing alias analysis |
10 | // results. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "TestAliasAnalysis.h" |
15 | #include "mlir/Analysis/AliasAnalysis.h" |
16 | #include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h" |
17 | #include "mlir/Interfaces/FunctionInterfaces.h" |
18 | #include "mlir/Pass/Pass.h" |
19 | |
20 | using namespace mlir; |
21 | |
22 | /// Print a value that is used as an operand of an alias query. |
23 | static void printAliasOperand(Operation *op) { |
24 | llvm::errs() << op->getAttrOfType<StringAttr>("test.ptr" ).getValue(); |
25 | } |
26 | static void printAliasOperand(Value value) { |
27 | if (BlockArgument arg = dyn_cast<BlockArgument>(Val&: value)) { |
28 | Region *region = arg.getParentRegion(); |
29 | unsigned parentBlockNumber = |
30 | std::distance(first: region->begin(), last: arg.getOwner()->getIterator()); |
31 | llvm::errs() << region->getParentOp() |
32 | ->getAttrOfType<StringAttr>("test.ptr" ) |
33 | .getValue() |
34 | << ".region" << region->getRegionNumber(); |
35 | if (parentBlockNumber != 0) |
36 | llvm::errs() << ".block" << parentBlockNumber; |
37 | llvm::errs() << "#" << arg.getArgNumber(); |
38 | return; |
39 | } |
40 | OpResult result = cast<OpResult>(Val&: value); |
41 | printAliasOperand(op: result.getOwner()); |
42 | llvm::errs() << "#" << result.getResultNumber(); |
43 | } |
44 | |
45 | namespace mlir { |
46 | namespace test { |
47 | void printAliasResult(AliasResult result, Value lhs, Value rhs) { |
48 | printAliasOperand(value: lhs); |
49 | llvm::errs() << " <-> " ; |
50 | printAliasOperand(value: rhs); |
51 | llvm::errs() << ": " << result << "\n" ; |
52 | } |
53 | |
54 | /// Print the result of an alias query. |
55 | void printModRefResult(ModRefResult result, Operation *op, Value location) { |
56 | printAliasOperand(op); |
57 | llvm::errs() << " -> " ; |
58 | printAliasOperand(value: location); |
59 | llvm::errs() << ": " << result << "\n" ; |
60 | } |
61 | |
62 | void TestAliasAnalysisBase::runAliasAnalysisOnOperation( |
63 | Operation *op, AliasAnalysis &aliasAnalysis) { |
64 | llvm::errs() << "Testing : " << *op->getInherentAttr(name: "sym_name" ) << "\n" ; |
65 | |
66 | // Collect all of the values to check for aliasing behavior. |
67 | SmallVector<Value, 32> valsToCheck; |
68 | op->walk(callback: [&](Operation *op) { |
69 | if (!op->getDiscardableAttr(name: "test.ptr" )) |
70 | return; |
71 | valsToCheck.append(in_start: op->result_begin(), in_end: op->result_end()); |
72 | for (Region ®ion : op->getRegions()) |
73 | for (Block &block : region) |
74 | valsToCheck.append(in_start: block.args_begin(), in_end: block.args_end()); |
75 | }); |
76 | |
77 | // Check for aliasing behavior between each of the values. |
78 | for (auto it = valsToCheck.begin(), e = valsToCheck.end(); it != e; ++it) |
79 | for (auto *innerIt = valsToCheck.begin(); innerIt != it; ++innerIt) |
80 | printAliasResult(result: aliasAnalysis.alias(lhs: *innerIt, rhs: *it), lhs: *innerIt, rhs: *it); |
81 | } |
82 | |
83 | void TestAliasAnalysisModRefBase::runAliasAnalysisOnOperation( |
84 | Operation *op, AliasAnalysis &aliasAnalysis) { |
85 | llvm::errs() << "Testing : " << *op->getInherentAttr(name: "sym_name" ) << "\n" ; |
86 | |
87 | // Collect all of the values to check for aliasing behavior. |
88 | SmallVector<Value, 32> valsToCheck; |
89 | op->walk(callback: [&](Operation *op) { |
90 | if (!op->getDiscardableAttr(name: "test.ptr" )) |
91 | return; |
92 | valsToCheck.append(in_start: op->result_begin(), in_end: op->result_end()); |
93 | for (Region ®ion : op->getRegions()) |
94 | for (Block &block : region) |
95 | valsToCheck.append(in_start: block.args_begin(), in_end: block.args_end()); |
96 | }); |
97 | |
98 | // Check for aliasing behavior between each of the values. |
99 | for (auto &it : valsToCheck) { |
100 | op->walk(callback: [&](Operation *op) { |
101 | if (!op->getDiscardableAttr(name: "test.ptr" )) |
102 | return; |
103 | printModRefResult(result: aliasAnalysis.getModRef(op, location: it), op, location: it); |
104 | }); |
105 | } |
106 | } |
107 | |
108 | } // namespace test |
109 | } // namespace mlir |
110 | |
111 | //===----------------------------------------------------------------------===// |
112 | // Testing AliasResult |
113 | //===----------------------------------------------------------------------===// |
114 | |
115 | namespace { |
116 | struct TestAliasAnalysisPass |
117 | : public test::TestAliasAnalysisBase, |
118 | PassWrapper<TestAliasAnalysisPass, OperationPass<>> { |
119 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisPass) |
120 | |
121 | StringRef getArgument() const final { return "test-alias-analysis" ; } |
122 | StringRef getDescription() const final { |
123 | return "Test alias analysis results." ; |
124 | } |
125 | void runOnOperation() override { |
126 | AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>(); |
127 | runAliasAnalysisOnOperation(op: getOperation(), aliasAnalysis); |
128 | } |
129 | }; |
130 | } // namespace |
131 | |
132 | //===----------------------------------------------------------------------===// |
133 | // Testing ModRefResult |
134 | //===----------------------------------------------------------------------===// |
135 | |
136 | namespace { |
137 | struct TestAliasAnalysisModRefPass |
138 | : public test::TestAliasAnalysisModRefBase, |
139 | PassWrapper<TestAliasAnalysisModRefPass, OperationPass<>> { |
140 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisModRefPass) |
141 | |
142 | StringRef getArgument() const final { return "test-alias-analysis-modref" ; } |
143 | StringRef getDescription() const final { |
144 | return "Test alias analysis ModRef results." ; |
145 | } |
146 | void runOnOperation() override { |
147 | AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>(); |
148 | runAliasAnalysisOnOperation(op: getOperation(), aliasAnalysis); |
149 | } |
150 | }; |
151 | } // namespace |
152 | |
153 | //===----------------------------------------------------------------------===// |
// Testing LocalAliasAnalysis extensions
155 | //===----------------------------------------------------------------------===// |
156 | |
157 | /// Check if value is function argument. |
158 | static bool isFuncArg(Value val) { |
159 | auto blockArg = dyn_cast<BlockArgument>(Val&: val); |
160 | if (!blockArg) |
161 | return false; |
162 | |
163 | return mlir::isa_and_nonnull<FunctionOpInterface>( |
164 | Val: blockArg.getOwner()->getParentOp()); |
165 | } |
166 | |
167 | /// Check if value has "restrict" attribute. Value must be a function argument. |
168 | static bool isRestrict(Value val) { |
169 | auto blockArg = cast<BlockArgument>(Val&: val); |
170 | auto func = |
171 | mlir::cast<FunctionOpInterface>(blockArg.getOwner()->getParentOp()); |
172 | return !!func.getArgAttr(blockArg.getArgNumber(), |
173 | "local_alias_analysis.restrict" ); |
174 | } |
175 | |
176 | namespace { |
177 | /// LocalAliasAnalysis extended to support "restrict" attreibute. |
178 | class LocalAliasAnalysisRestrict : public LocalAliasAnalysis { |
179 | protected: |
180 | AliasResult aliasImpl(Value lhs, Value rhs) override { |
181 | if (lhs == rhs) |
182 | return AliasResult::MustAlias; |
183 | |
184 | // Assume no aliasing if both values are function arguments and any of them |
185 | // have restrict attr. |
186 | if (isFuncArg(val: lhs) && isFuncArg(val: rhs)) |
187 | if (isRestrict(val: lhs) || isRestrict(val: rhs)) |
188 | return AliasResult::NoAlias; |
189 | |
190 | return LocalAliasAnalysis::aliasImpl(lhs, rhs); |
191 | } |
192 | }; |
193 | |
194 | /// This pass tests adding additional analysis impls to the AliasAnalysis. |
195 | struct TestAliasAnalysisExtendingPass |
196 | : public test::TestAliasAnalysisBase, |
197 | PassWrapper<TestAliasAnalysisExtendingPass, OperationPass<>> { |
198 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasAnalysisExtendingPass) |
199 | |
200 | StringRef getArgument() const final { |
201 | return "test-alias-analysis-extending" ; |
202 | } |
203 | StringRef getDescription() const final { |
204 | return "Test alias analysis extending." ; |
205 | } |
206 | void runOnOperation() override { |
207 | AliasAnalysis aliasAnalysis(getOperation()); |
208 | aliasAnalysis.addAnalysisImplementation(analysis: LocalAliasAnalysisRestrict()); |
209 | runAliasAnalysisOnOperation(op: getOperation(), aliasAnalysis); |
210 | } |
211 | }; |
212 | } // namespace |
213 | |
214 | //===----------------------------------------------------------------------===// |
215 | // Pass Registration |
216 | //===----------------------------------------------------------------------===// |
217 | |
218 | namespace mlir { |
219 | namespace test { |
220 | void registerTestAliasAnalysisPass() { |
221 | PassRegistration<TestAliasAnalysisExtendingPass>(); |
222 | PassRegistration<TestAliasAnalysisModRefPass>(); |
223 | PassRegistration<TestAliasAnalysisPass>(); |
224 | } |
225 | } // namespace test |
226 | } // namespace mlir |
227 | |