| 1 | //===- VectorizerTestPass.cpp - VectorizerTestPass Pass Impl --------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements a simple testing pass for vectorization functionality. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Analysis/SliceAnalysis.h" |
| 14 | #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" |
| 15 | #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" |
| 16 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 17 | #include "mlir/Dialect/Affine/LoopUtils.h" |
| 18 | #include "mlir/Dialect/Affine/Utils.h" |
| 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 20 | #include "mlir/Dialect/Utils/IndexingUtils.h" |
| 21 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| 22 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" |
| 23 | #include "mlir/IR/Builders.h" |
| 24 | #include "mlir/IR/BuiltinTypes.h" |
| 25 | #include "mlir/IR/Diagnostics.h" |
| 26 | #include "mlir/Pass/Pass.h" |
| 27 | #include "mlir/Transforms/Passes.h" |
| 28 | |
| 29 | #include "llvm/ADT/STLExtras.h" |
| 30 | #include "llvm/Support/CommandLine.h" |
| 31 | #include "llvm/Support/Debug.h" |
| 32 | |
| 33 | #define DEBUG_TYPE "affine-super-vectorizer-test" |
| 34 | |
| 35 | using namespace mlir; |
| 36 | using namespace mlir::affine; |
| 37 | |
| 38 | static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options" ); |
| 39 | |
| 40 | namespace { |
| 41 | struct VectorizerTestPass |
| 42 | : public PassWrapper<VectorizerTestPass, OperationPass<func::FuncOp>> { |
| 43 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(VectorizerTestPass) |
| 44 | |
| 45 | static constexpr auto kTestAffineMapOpName = "test_affine_map" ; |
| 46 | static constexpr auto kTestAffineMapAttrName = "affine_map" ; |
| 47 | void getDependentDialects(DialectRegistry ®istry) const override { |
| 48 | registry.insert<vector::VectorDialect>(); |
| 49 | } |
| 50 | StringRef getArgument() const final { return "affine-super-vectorizer-test" ; } |
| 51 | StringRef getDescription() const final { |
| 52 | return "Tests vectorizer standalone functionality." ; |
| 53 | } |
| 54 | |
| 55 | VectorizerTestPass() = default; |
| 56 | VectorizerTestPass(const VectorizerTestPass &pass) : PassWrapper(pass){}; |
| 57 | |
| 58 | ListOption<int> clTestVectorShapeRatio{ |
| 59 | *this, "vector-shape-ratio" , |
| 60 | llvm::cl::desc("Specify the HW vector size for vectorization" )}; |
| 61 | Option<bool> clTestForwardSlicingAnalysis{ |
| 62 | *this, "forward-slicing" , |
| 63 | llvm::cl::desc( |
| 64 | "Enable testing forward static slicing and topological sort " |
| 65 | "functionalities" )}; |
| 66 | Option<bool> clTestBackwardSlicingAnalysis{ |
| 67 | *this, "backward-slicing" , |
| 68 | llvm::cl::desc("Enable testing backward static slicing and " |
| 69 | "topological sort functionalities" )}; |
| 70 | Option<bool> clTestSlicingAnalysis{ |
| 71 | *this, "slicing" , |
| 72 | llvm::cl::desc("Enable testing static slicing and topological sort " |
| 73 | "functionalities" )}; |
| 74 | Option<bool> clTestComposeMaps{ |
| 75 | *this, "compose-maps" , |
| 76 | llvm::cl::desc("Enable testing the composition of AffineMap where each " |
| 77 | "AffineMap in the composition is specified as the " |
| 78 | "affine_map attribute " |
| 79 | "in a constant op." )}; |
| 80 | Option<bool> clTestVecAffineLoopNest{ |
| 81 | *this, "vectorize-affine-loop-nest" , |
| 82 | llvm::cl::desc( |
| 83 | "Enable testing for the 'vectorizeAffineLoopNest' utility by " |
| 84 | "vectorizing the outermost loops found" )}; |
| 85 | |
| 86 | void runOnOperation() override; |
| 87 | void testVectorShapeRatio(llvm::raw_ostream &outs); |
| 88 | void testForwardSlicing(llvm::raw_ostream &outs); |
| 89 | void testBackwardSlicing(llvm::raw_ostream &outs); |
| 90 | void testSlicing(llvm::raw_ostream &outs); |
| 91 | void testComposeMaps(llvm::raw_ostream &outs); |
| 92 | |
| 93 | /// Test for 'vectorizeAffineLoopNest' utility. |
| 94 | void testVecAffineLoopNest(llvm::raw_ostream &outs); |
| 95 | }; |
| 96 | |
| 97 | } // namespace |
| 98 | |
| 99 | void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) { |
| 100 | auto f = getOperation(); |
| 101 | using affine::matcher::Op; |
| 102 | SmallVector<int64_t, 8> shape(clTestVectorShapeRatio.begin(), |
| 103 | clTestVectorShapeRatio.end()); |
| 104 | auto subVectorType = VectorType::get(shape, Float32Type::get(f.getContext())); |
| 105 | // Only filter operations that operate on a strict super-vector and have one |
| 106 | // return. This makes testing easier. |
| 107 | auto filter = [&](Operation &op) { |
| 108 | assert(subVectorType.getElementType().isF32() && |
| 109 | "Only f32 supported for now" ); |
| 110 | if (!mlir::matcher::operatesOnSuperVectorsOf(op, subVectorType: subVectorType)) { |
| 111 | return false; |
| 112 | } |
| 113 | if (op.getNumResults() != 1) { |
| 114 | return false; |
| 115 | } |
| 116 | return true; |
| 117 | }; |
| 118 | auto pat = Op(filter); |
| 119 | SmallVector<NestedMatch, 8> matches; |
| 120 | pat.match(op: f, matches: &matches); |
| 121 | for (auto m : matches) { |
| 122 | auto *opInst = m.getMatchedOperation(); |
| 123 | // This is a unit test that only checks and prints shape ratio. |
| 124 | // As a consequence we write only Ops with a single return type for the |
| 125 | // purpose of this test. If we need to test more intricate behavior in the |
| 126 | // future we can always extend. |
| 127 | auto superVectorType = cast<VectorType>(opInst->getResult(idx: 0).getType()); |
| 128 | auto ratio = |
| 129 | computeShapeRatio(superVectorType.getShape(), subVectorType.getShape()); |
| 130 | if (!ratio) { |
| 131 | opInst->emitRemark(message: "NOT MATCHED" ); |
| 132 | } else { |
| 133 | outs << "\nmatched: " << *opInst << " with shape ratio: " ; |
| 134 | llvm::interleaveComma(c: MutableArrayRef<int64_t>(*ratio), os&: outs); |
| 135 | } |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | static NestedPattern patternTestSlicingOps() { |
| 140 | using affine::matcher::Op; |
| 141 | // Match all operations with the kTestSlicingOpName name. |
| 142 | auto filter = [](Operation &op) { |
| 143 | // Just use a custom op name for this test, it makes life easier. |
| 144 | return op.getName().getStringRef() == "slicing-test-op" ; |
| 145 | }; |
| 146 | return Op(filter); |
| 147 | } |
| 148 | |
| 149 | void VectorizerTestPass::testBackwardSlicing(llvm::raw_ostream &outs) { |
| 150 | auto f = getOperation(); |
| 151 | outs << "\n" << f.getName(); |
| 152 | |
| 153 | SmallVector<NestedMatch, 8> matches; |
| 154 | patternTestSlicingOps().match(op: f, matches: &matches); |
| 155 | for (auto m : matches) { |
| 156 | SetVector<Operation *> backwardSlice; |
| 157 | LogicalResult result = |
| 158 | getBackwardSlice(op: m.getMatchedOperation(), backwardSlice: &backwardSlice); |
| 159 | assert(result.succeeded() && "expected a backward slice" ); |
| 160 | (void)result; |
| 161 | outs << "\nmatched: " << *m.getMatchedOperation() |
| 162 | << " backward static slice: " ; |
| 163 | for (auto *op : backwardSlice) |
| 164 | outs << "\n" << *op; |
| 165 | } |
| 166 | } |
| 167 | |
| 168 | void VectorizerTestPass::testForwardSlicing(llvm::raw_ostream &outs) { |
| 169 | auto f = getOperation(); |
| 170 | outs << "\n" << f.getName(); |
| 171 | |
| 172 | SmallVector<NestedMatch, 8> matches; |
| 173 | patternTestSlicingOps().match(op: f, matches: &matches); |
| 174 | for (auto m : matches) { |
| 175 | SetVector<Operation *> forwardSlice; |
| 176 | getForwardSlice(op: m.getMatchedOperation(), forwardSlice: &forwardSlice); |
| 177 | outs << "\nmatched: " << *m.getMatchedOperation() |
| 178 | << " forward static slice: " ; |
| 179 | for (auto *op : forwardSlice) |
| 180 | outs << "\n" << *op; |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | void VectorizerTestPass::testSlicing(llvm::raw_ostream &outs) { |
| 185 | auto f = getOperation(); |
| 186 | outs << "\n" << f.getName(); |
| 187 | |
| 188 | SmallVector<NestedMatch, 8> matches; |
| 189 | patternTestSlicingOps().match(op: f, matches: &matches); |
| 190 | for (auto m : matches) { |
| 191 | SetVector<Operation *> staticSlice = getSlice(op: m.getMatchedOperation()); |
| 192 | outs << "\nmatched: " << *m.getMatchedOperation() << " static slice: " ; |
| 193 | for (auto *op : staticSlice) |
| 194 | outs << "\n" << *op; |
| 195 | } |
| 196 | } |
| 197 | |
| 198 | static bool customOpWithAffineMapAttribute(Operation &op) { |
| 199 | return op.getName().getStringRef() == |
| 200 | VectorizerTestPass::kTestAffineMapOpName; |
| 201 | } |
| 202 | |
| 203 | void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) { |
| 204 | auto f = getOperation(); |
| 205 | |
| 206 | using affine::matcher::Op; |
| 207 | auto pattern = Op(filter: customOpWithAffineMapAttribute); |
| 208 | SmallVector<NestedMatch, 8> matches; |
| 209 | pattern.match(op: f, matches: &matches); |
| 210 | SmallVector<AffineMap, 4> maps; |
| 211 | maps.reserve(N: matches.size()); |
| 212 | for (auto m : llvm::reverse(C&: matches)) { |
| 213 | auto *opInst = m.getMatchedOperation(); |
| 214 | auto map = |
| 215 | cast<AffineMapAttr>(opInst->getDiscardableAttr( |
| 216 | name: VectorizerTestPass::kTestAffineMapAttrName)) |
| 217 | .getValue(); |
| 218 | maps.push_back(Elt: map); |
| 219 | } |
| 220 | if (maps.empty()) |
| 221 | // Nothing to compose |
| 222 | return; |
| 223 | AffineMap res; |
| 224 | for (auto m : maps) { |
| 225 | res = res ? res.compose(map: m) : m; |
| 226 | } |
| 227 | simplifyAffineMap(map: res).print(os&: outs << "\nComposed map: " ); |
| 228 | } |
| 229 | |
| 230 | /// Test for 'vectorizeAffineLoopNest' utility. |
| 231 | void VectorizerTestPass::testVecAffineLoopNest(llvm::raw_ostream &outs) { |
| 232 | std::vector<SmallVector<AffineForOp, 2>> loops; |
| 233 | gatherLoops(getOperation(), loops); |
| 234 | |
| 235 | // Expected only one loop nest. |
| 236 | if (loops.empty() || loops[0].size() != 1) |
| 237 | return; |
| 238 | |
| 239 | // We vectorize the outermost loop found with VF=4. |
| 240 | AffineForOp outermostLoop = loops[0][0]; |
| 241 | VectorizationStrategy strategy; |
| 242 | strategy.vectorSizes.push_back(Elt: 4 /*vectorization factor*/); |
| 243 | strategy.loopToVectorDim[outermostLoop] = 0; |
| 244 | |
| 245 | SmallVector<LoopReduction, 2> reductions; |
| 246 | if (!isLoopParallel(outermostLoop, &reductions)) { |
| 247 | outs << "Outermost loop cannot be parallel\n" ; |
| 248 | return; |
| 249 | } |
| 250 | std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize; |
| 251 | loopsToVectorize.push_back({outermostLoop}); |
| 252 | (void)vectorizeAffineLoopNest(loops&: loopsToVectorize, strategy); |
| 253 | } |
| 254 | |
| 255 | void VectorizerTestPass::runOnOperation() { |
| 256 | // Only support single block functions at this point. |
| 257 | func::FuncOp f = getOperation(); |
| 258 | if (!llvm::hasSingleElement(f)) |
| 259 | return; |
| 260 | |
| 261 | std::string str; |
| 262 | llvm::raw_string_ostream outs(str); |
| 263 | |
| 264 | { // Tests that expect a NestedPatternContext to be allocated externally. |
| 265 | NestedPatternContext mlContext; |
| 266 | |
| 267 | if (!clTestVectorShapeRatio.empty()) |
| 268 | testVectorShapeRatio(outs); |
| 269 | |
| 270 | if (clTestForwardSlicingAnalysis) |
| 271 | testForwardSlicing(outs); |
| 272 | |
| 273 | if (clTestBackwardSlicingAnalysis) |
| 274 | testBackwardSlicing(outs); |
| 275 | |
| 276 | if (clTestSlicingAnalysis) |
| 277 | testSlicing(outs); |
| 278 | |
| 279 | if (clTestComposeMaps) |
| 280 | testComposeMaps(outs); |
| 281 | } |
| 282 | |
| 283 | if (clTestVecAffineLoopNest) |
| 284 | testVecAffineLoopNest(outs); |
| 285 | |
| 286 | if (!outs.str().empty()) { |
| 287 | emitRemark(UnknownLoc::get(&getContext()), outs.str()); |
| 288 | } |
| 289 | } |
| 290 | |
| 291 | namespace mlir { |
| 292 | void registerVectorizerTestPass() { PassRegistration<VectorizerTestPass>(); } |
| 293 | } // namespace mlir |
| 294 | |