//===- LoopAnalysis.cpp - Misc loop analysis routines --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous loop analysis routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "llvm/Support/MathExtras.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Debug.h"
#include <numeric>
#include <optional>
#include <type_traits>

#define DEBUG_TYPE "affine-loop-analysis"

using namespace mlir;
using namespace mlir::affine;

namespace {

/// A directed graph to model relationships between MLIR Operations.
class DirectedOpGraph {
public:
  /// Add a node to the graph.
  void addNode(Operation *op) {
    assert(!hasNode(op) && "node already added");
    nodes.emplace_back(op);
    edges[op] = {};
  }

  /// Add an edge from `src` to `dest`.
  void addEdge(Operation *src, Operation *dest) {
    // This is a multi-graph.
    assert(hasNode(src) && "src node does not exist in graph");
    assert(hasNode(dest) && "dest node does not exist in graph");
    edges[src].push_back(getNode(dest));
  }

  /// Returns true if there is a (directed) cycle in the graph.
  bool hasCycle() { return dfs(/*cycleCheck=*/true); }

  void printEdges() {
    for (auto &en : edges) {
      llvm::dbgs() << *en.first << " (" << en.first << ")"
                   << " has " << en.second.size() << " edges:\n";
      for (auto *node : en.second) {
        llvm::dbgs() << '\t' << *node->op << '\n';
      }
    }
  }

private:
  /// A node of a directed graph between MLIR Operations to model various
  /// relationships. This is meant to be used internally.
  struct DGNode {
    DGNode(Operation *op) : op(op) {}
    Operation *op;

    // Start and finish visit numbers are standard in DFS to implement things
    // like finding strongly connected components. These numbers are modified
    // during analyses on the graph and so seemingly const API methods will be
    // non-const.

    /// Start visit number.
    int vn = -1;

    /// Finish visit number.
    int fn = -1;
  };

  /// Get internal node corresponding to `op`.
  DGNode *getNode(Operation *op) {
    auto *value = llvm::find_if(
        nodes, [&](const DGNode &node) { return node.op == op; });
    assert(value != nodes.end() && "node doesn't exist in graph");
    return &*value;
  }

  /// Returns true if `key` is in the graph.
  bool hasNode(Operation *key) const {
    return llvm::find_if(nodes, [&](const DGNode &node) {
             return node.op == key;
           }) != nodes.end();
  }

  /// Perform a depth-first traversal of the graph setting visited and finished
  /// numbers. If `cycleCheck` is set, detects cycles and returns true as soon
  /// as the first cycle is detected, and false if there are no cycles. If
  /// `cycleCheck` is not set, completes the DFS and the `return` value doesn't
  /// have a meaning.
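  ///
  /// Visit-number convention used below: after initialization, `vn == 0` means
  /// a node is unvisited; `vn > 0` with `fn == -1` means its visit has started
  /// but not finished (it is on the current DFS stack); `fn > 0` means the
  /// visit has finished. Reaching an unfinished node again signals a cycle.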
  bool dfs(bool cycleCheck = false) {
    for (DGNode &node : nodes) {
      node.vn = 0;
      node.fn = -1;
    }

    unsigned time = 0;
    for (DGNode &node : nodes) {
      if (node.vn == 0) {
        bool ret = dfsNode(node, cycleCheck, time);
        // Check if a cycle was already found.
        if (cycleCheck && ret)
          return true;
      } else if (cycleCheck && node.fn == -1) {
        // We have encountered a node whose visit has started but it's not
        // finished. So we have a cycle.
        return true;
      }
    }
    return false;
  }

  /// Perform depth-first traversal starting at `node`. Return true
  /// as soon as a cycle is found if `cycleCheck` was set. Update `time`.
  bool dfsNode(DGNode &node, bool cycleCheck, unsigned &time) const {
    auto nodeEdges = edges.find(node.op);
    assert(nodeEdges != edges.end() && "missing node in graph");
    node.vn = ++time;

    for (auto &neighbour : nodeEdges->second) {
      if (neighbour->vn == 0) {
        bool ret = dfsNode(*neighbour, cycleCheck, time);
        if (cycleCheck && ret)
          return true;
      } else if (cycleCheck && neighbour->fn == -1) {
        // We have encountered a node whose visit has started but it's not
        // finished. So we have a cycle.
        return true;
      }
    }

    // Update finish time.
    node.fn = ++time;

    return false;
  }

  // The list of nodes. The storage is owned by this class.
  SmallVector<DGNode> nodes;

  // Edges as an adjacency list.
  DenseMap<Operation *, SmallVector<DGNode *>> edges;
};

} // namespace

/// Returns, via `tripCountMap` and `tripCountOperands`, the trip count of the
/// loop as an affine map and its operands if it is expressible as an affine
/// expression, and a null map otherwise. The trip count expression is
/// simplified before being returned. This method only utilizes map composition
/// to construct lower and upper bounds before computing the trip count
/// expressions.
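///
/// As an illustrative sketch (the IR below is assumed, not taken from a test):
/// for a loop like
///   affine.for %i = 0 to %N step 2 { ... }
/// the bounds are not both constant, so the general path composes the bound
/// maps and would produce a trip count map along the lines of
///   ()[s0] -> (s0 ceildiv 2)
/// (or its dim-based equivalent), with %N as the sole trip count operand.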
void mlir::affine::getTripCountMapAndOperands(
    AffineForOp forOp, AffineMap *tripCountMap,
    SmallVectorImpl<Value> *tripCountOperands) {
  MLIRContext *context = forOp.getContext();
  int64_t step = forOp.getStepAsInt();
  int64_t loopSpan;
  if (forOp.hasConstantBounds()) {
    int64_t lb = forOp.getConstantLowerBound();
    int64_t ub = forOp.getConstantUpperBound();
    loopSpan = ub - lb;
    if (loopSpan < 0)
      loopSpan = 0;
    *tripCountMap = AffineMap::getConstantMap(
        llvm::divideCeilSigned(loopSpan, step), context);
    tripCountOperands->clear();
    return;
  }
  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  if (lbMap.getNumResults() != 1) {
    *tripCountMap = AffineMap();
    return;
  }

  // Difference of each upper bound expression from the single lower bound
  // expression (divided by the step) provides the expressions for the trip
  // count map.
  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());

  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
                                         lbMap.getResult(0));
  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
                                   lbSplatExpr, context);
  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());

  AffineValueMap tripCountValueMap;
  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
    tripCountValueMap.setResult(i,
                                tripCountValueMap.getResult(i).ceilDiv(step));

  *tripCountMap = tripCountValueMap.getAffineMap();
  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
                            tripCountValueMap.getOperands().end());
}

/// Returns the trip count of the loop if it's a constant, std::nullopt
/// otherwise. This method uses affine expression analysis (in turn using
/// getTripCountMapAndOperands) and is able to determine constant trip counts
/// in non-trivial cases.
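///
/// As an illustrative sketch (assumed IR): a loop with a multi-result upper
/// bound map such as
///   affine.for %i = 0 to min affine_map<() -> (128, 192)>() { ... }
/// yields a trip count map with the constant results (128, 192), and this
/// function returns their minimum, 128. If any result were non-constant, it
/// would return std::nullopt.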
std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
  SmallVector<Value, 4> operands;
  AffineMap map;
  getTripCountMapAndOperands(forOp, &map, &operands);

  if (!map)
    return std::nullopt;

  // Take the min if all trip counts are constant.
  std::optional<uint64_t> tripCount;
  for (auto resultExpr : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
      if (tripCount.has_value())
        tripCount =
            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
      else
        tripCount = constExpr.getValue();
    } else {
      return std::nullopt;
    }
  }
  return tripCount;
}

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through
/// getTripCountMapAndOperands), and this method is thus able to determine
/// non-trivial divisors.
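///
/// As an illustrative sketch: if the composed trip count map were
/// ()[s0] -> (s0 * 4), the largest known divisor would be 4 even though the
/// trip count itself isn't a constant; for a constant trip count of 12, it
/// would be 12, and for a provably zero-trip loop, the maximum uint64_t value.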
uint64_t mlir::affine::getLargestDivisorOfTripCount(AffineForOp forOp) {
  SmallVector<Value, 4> operands;
  AffineMap map;
  getTripCountMapAndOperands(forOp, &map, &operands);

  if (!map)
    return 1;

  // The largest divisor of the trip count is the GCD of the individual largest
  // divisors.
  assert(map.getNumResults() >= 1 && "expected one or more results");
  std::optional<uint64_t> gcd;
  for (auto resultExpr : map.getResults()) {
    uint64_t thisGcd;
    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
      uint64_t tripCount = constExpr.getValue();
      // 0-iteration loops (greatest divisor is 2^64 - 1).
      if (tripCount == 0)
        thisGcd = std::numeric_limits<uint64_t>::max();
      else
        // The greatest divisor is the trip count.
        thisGcd = tripCount;
    } else {
      // Trip count is not a known constant; use its largest known divisor.
      thisGcd = resultExpr.getLargestKnownDivisor();
    }
    if (gcd.has_value())
      gcd = std::gcd(*gcd, thisGcd);
    else
      gcd = thisGcd;
  }
  assert(gcd.has_value() && "value expected per above logic");
  return *gcd;
}

/// Given an affine.for `iv` and an access `index` of type index, returns
/// `true` if `index` is independent of `iv` and false otherwise.
///
/// Prerequisite: `iv` and `index` are of the proper type.
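///
/// As an illustrative sketch (assumed IR): with
///   affine.for %i ... {
///     %idx = affine.apply affine_map<(d0) -> (d0 + 1)>(%j)
///     ... %A[%idx] ...
///   }
/// composing the identity map with %idx's defining affine.apply yields a map
/// that is a function of %j only, so isAccessIndexInvariant(%i, %idx) would
/// return true.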
static bool isAccessIndexInvariant(Value iv, Value index) {
  assert(isAffineForInductionVar(iv) && "iv must be an affine.for iv");
  assert(isa<IndexType>(index.getType()) && "index must be of 'index' type");
  auto map = AffineMap::getMultiDimIdentityMap(/*numDims=*/1, iv.getContext());
  SmallVector<Value> operands = {index};
  AffineValueMap avm(map, operands);
  avm.composeSimplifyAndCanonicalize();
  return !avm.isFunctionOf(0, iv);
}

// Pre-requisite: Loop bounds should be in canonical form.
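// As an illustrative sketch (assumed IR): an access such as
//   affine.load %A[%j]
// nested inside `affine.for %i` is invariant with respect to %i, since %i does
// not appear among the operands of the access's composed affine map.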
template <typename LoadOrStoreOp>
bool mlir::affine::isInvariantAccess(LoadOrStoreOp memOp, AffineForOp forOp) {
  AffineValueMap avm(memOp.getAffineMap(), memOp.getMapOperands());
  avm.composeSimplifyAndCanonicalize();
  return !llvm::is_contained(avm.getOperands(), forOp.getInductionVar());
}

// Explicitly instantiate the template so that the compiler knows we need them.
template bool mlir::affine::isInvariantAccess(AffineReadOpInterface,
                                              AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineWriteOpInterface,
                                              AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineLoadOp, AffineForOp);
template bool mlir::affine::isInvariantAccess(AffineStoreOp, AffineForOp);

DenseSet<Value> mlir::affine::getInvariantAccesses(Value iv,
                                                   ArrayRef<Value> indices) {
  DenseSet<Value> res;
  for (Value index : indices) {
    if (isAccessIndexInvariant(iv, index))
      res.insert(index);
  }
  return res;
}

// TODO: check access stride.
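// As an illustrative sketch of the contract below (assumed IR): for an access
// %A[%i, %j] inside a loop over %j, only the last memref dimension varies with
// %j, so the function returns true and sets *memRefDim to 0, indicating that
// the varying dimension is the innermost (fastest-varying) one. If two or more
// result expressions varied with the loop IV, it would return false.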
template <typename LoadOrStoreOp>
bool mlir::affine::isContiguousAccess(Value iv, LoadOrStoreOp memoryOp,
                                      int *memRefDim) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "Must be called on either an affine read or write op");
  assert(memRefDim && "memRefDim == nullptr");
  auto memRefType = memoryOp.getMemRefType();

  if (!memRefType.getLayout().isIdentity())
    return memoryOp.emitError("NYI: non-trivial layout map"), false;

  int uniqueVaryingIndexAlongIv = -1;
  auto accessMap = memoryOp.getAffineMap();
  SmallVector<Value, 4> mapOperands(memoryOp.getMapOperands());
  unsigned numDims = accessMap.getNumDims();
  for (unsigned i = 0, e = memRefType.getRank(); i < e; ++i) {
    // Gather map operands used in result expr 'i' in 'exprOperands'.
    SmallVector<Value, 4> exprOperands;
    auto resultExpr = accessMap.getResult(i);
    resultExpr.walk([&](AffineExpr expr) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
        exprOperands.push_back(mapOperands[dimExpr.getPosition()]);
      else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
        exprOperands.push_back(mapOperands[numDims + symExpr.getPosition()]);
    });
    // Check access invariance of each operand in 'exprOperands'.
    for (Value exprOperand : exprOperands) {
      if (!isAccessIndexInvariant(iv, exprOperand)) {
        if (uniqueVaryingIndexAlongIv != -1) {
          // 2+ varying indices -> do not vectorize along iv.
          return false;
        }
        uniqueVaryingIndexAlongIv = i;
      }
    }
  }

  if (uniqueVaryingIndexAlongIv == -1)
    *memRefDim = -1;
  else
    *memRefDim = memRefType.getRank() - (uniqueVaryingIndexAlongIv + 1);
  return true;
}

template bool mlir::affine::isContiguousAccess(Value iv,
                                               AffineReadOpInterface loadOp,
                                               int *memRefDim);
template bool mlir::affine::isContiguousAccess(Value iv,
                                               AffineWriteOpInterface loadOp,
                                               int *memRefDim);

template <typename LoadOrStoreOp>
static bool isVectorElement(LoadOrStoreOp memoryOp) {
  auto memRefType = memoryOp.getMemRefType();
  return isa<VectorType>(memRefType.getElementType());
}

using VectorizableOpFun = std::function<bool(AffineForOp, Operation &)>;

static bool
isVectorizableLoopBodyWithOpCond(AffineForOp loop,
                                 const VectorizableOpFun &isVectorizableOp,
                                 NestedPattern &vectorTransferMatcher) {
  auto *forOp = loop.getOperation();

  // No vectorization across conditionals for now.
  auto conditionals = matcher::If();
  SmallVector<NestedMatch, 8> conditionalsMatched;
  conditionals.match(forOp, &conditionalsMatched);
  if (!conditionalsMatched.empty()) {
    return false;
  }

  // No vectorization for ops with operand or result types that are not
  // vectorizable.
  auto types = matcher::Op([](Operation &op) -> bool {
    if (llvm::any_of(op.getOperandTypes(), [](Type type) {
          if (MemRefType t = dyn_cast<MemRefType>(type))
            return !VectorType::isValidElementType(t.getElementType());
          return !VectorType::isValidElementType(type);
        }))
      return true;
    return !llvm::all_of(op.getResultTypes(), VectorType::isValidElementType);
  });
  SmallVector<NestedMatch, 8> opsMatched;
  types.match(forOp, &opsMatched);
  if (!opsMatched.empty()) {
    return false;
  }

  // No vectorization across unknown regions.
  auto regions = matcher::Op([](Operation &op) -> bool {
    return op.getNumRegions() != 0 && !isa<AffineIfOp, AffineForOp>(op);
  });
  SmallVector<NestedMatch, 8> regionsMatched;
  regions.match(forOp, &regionsMatched);
  if (!regionsMatched.empty()) {
    return false;
  }

  SmallVector<NestedMatch, 8> vectorTransfersMatched;
  vectorTransferMatcher.match(forOp, &vectorTransfersMatched);
  if (!vectorTransfersMatched.empty()) {
    return false;
  }

  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
  SmallVector<NestedMatch, 8> loadAndStoresMatched;
  loadAndStores.match(forOp, &loadAndStoresMatched);
  for (auto ls : loadAndStoresMatched) {
    auto *op = ls.getMatchedOperation();
    auto load = dyn_cast<AffineLoadOp>(op);
    auto store = dyn_cast<AffineStoreOp>(op);
    // Only scalar types are considered vectorizable; all loads/stores must be
    // vectorizable for a loop to qualify as vectorizable.
    // TODO: ponder whether we want to be more general here.
    bool vector = load ? isVectorElement(load) : isVectorElement(store);
    if (vector) {
      return false;
    }
    if (isVectorizableOp && !isVectorizableOp(loop, *op)) {
      return false;
    }
  }
  return true;
}

bool mlir::affine::isVectorizableLoopBody(
    AffineForOp loop, int *memRefDim, NestedPattern &vectorTransferMatcher) {
  *memRefDim = -1;
  VectorizableOpFun fun([memRefDim](AffineForOp loop, Operation &op) {
    auto load = dyn_cast<AffineLoadOp>(op);
    auto store = dyn_cast<AffineStoreOp>(op);
    int thisOpMemRefDim = -1;
    bool isContiguous =
        load ? isContiguousAccess(loop.getInductionVar(),
                                  cast<AffineReadOpInterface>(*load),
                                  &thisOpMemRefDim)
             : isContiguousAccess(loop.getInductionVar(),
                                  cast<AffineWriteOpInterface>(*store),
                                  &thisOpMemRefDim);
    if (thisOpMemRefDim != -1) {
      // If memory accesses vary across different dimensions then the loop is
      // not vectorizable.
      if (*memRefDim != -1 && *memRefDim != thisOpMemRefDim)
        return false;
      *memRefDim = thisOpMemRefDim;
    }
    return isContiguous;
  });
  return isVectorizableLoopBodyWithOpCond(loop, fun, vectorTransferMatcher);
}

bool mlir::affine::isVectorizableLoopBody(
    AffineForOp loop, NestedPattern &vectorTransferMatcher) {
  return isVectorizableLoopBodyWithOpCond(loop, nullptr, vectorTransferMatcher);
}

/// Checks whether SSA dominance would be violated if a for op's body
/// operations are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO: extend this to check for memory-based dependence violation when we
// have the support.
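// As an illustrative sketch: for a body { %a = ...; %b = f(%a); ... }, the
// shift vector {1, 1, ...} is accepted since the def %a and its user %b share
// the same shift, whereas {0, 1, ...} is rejected because the user would be
// shifted by a different amount than its def.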
bool mlir::affine::isOpwiseShiftValid(AffineForOp forOp,
                                      ArrayRef<uint64_t> shifts) {
  auto *forBody = forOp.getBody();
  assert(shifts.size() == forBody->getOperations().size());

  // Work backwards over the body of the block so that the shift of a use's
  // ancestor operation in the block gets recorded before it's looked up.
  DenseMap<Operation *, uint64_t> forBodyShift;
  for (const auto &it :
       llvm::enumerate(llvm::reverse(forBody->getOperations()))) {
    auto &op = it.value();

    // Get the index of the current operation; note that we are iterating in
    // reverse, so we need to fix it up.
    size_t index = shifts.size() - it.index() - 1;

    // Remember the shift of this operation.
    uint64_t shift = shifts[index];
    forBodyShift.try_emplace(&op, shift);

    // Validate the results of this operation if it were to be shifted.
    for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
      Value result = op.getResult(i);
      for (auto *user : result.getUsers()) {
        // If an ancestor operation doesn't lie in the block of forOp,
        // there is no shift to check.
        if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
          assert(forBodyShift.count(ancOp) > 0 && "ancestor expected in map");
          if (shift != forBodyShift[ancOp])
            return false;
        }
      }
    }
  }
  return true;
}

bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) {
  assert(!loops.empty() && "no original loops provided");

  // We first find out all dependences we intend to check.
  SmallVector<Operation *, 8> loadAndStoreOps;
  loops[0]->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op))
      loadAndStoreOps.push_back(op);
  });

  unsigned numOps = loadAndStoreOps.size();
  unsigned numLoops = loops.size();
  for (unsigned d = 1; d <= numLoops + 1; ++d) {
    for (unsigned i = 0; i < numOps; ++i) {
      Operation *srcOp = loadAndStoreOps[i];
      MemRefAccess srcAccess(srcOp);
      for (unsigned j = 0; j < numOps; ++j) {
        Operation *dstOp = loadAndStoreOps[j];
        MemRefAccess dstAccess(dstOp);

        SmallVector<DependenceComponent, 2> depComps;
        DependenceResult result = checkMemrefAccessDependence(
            srcAccess, dstAccess, d, /*dependenceConstraints=*/nullptr,
            &depComps);

        // Skip if there is no dependence in this case.
        if (!hasDependence(result))
          continue;

        // Check whether there is any negative direction vector in the
        // dependence components found above, which means that dependence is
        // violated by the default hyper-rect tiling method.
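        // In other words, a component whose upper bound is negative means the
        // dependence points backwards along that loop; hyper-rectangular
        // tiling would then reorder the dependent iteration before its source,
        // so such a dependence makes tiling invalid.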
        LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated "
                                   "for dependence at depth: "
                                << Twine(d) << " between:\n";);
        LLVM_DEBUG(srcAccess.opInst->dump());
        LLVM_DEBUG(dstAccess.opInst->dump());
        for (const DependenceComponent &depComp : depComps) {
          if (depComp.lb.has_value() && depComp.ub.has_value() &&
              *depComp.lb < *depComp.ub && *depComp.ub < 0) {
            LLVM_DEBUG(llvm::dbgs()
                       << "Dependence component lb = " << Twine(*depComp.lb)
                       << " ub = " << Twine(*depComp.ub)
                       << " is negative at depth: " << Twine(d)
                       << " and thus violates the legality rule.\n");
            return false;
          }
        }
      }
    }
  }

  return true;
}

bool mlir::affine::hasCyclicDependence(AffineForOp root) {
  // Collect all the memory accesses in the source nest.
  DirectedOpGraph graph;
  SmallVector<MemRefAccess> accesses;
  root->walk([&](Operation *op) {
    if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
      accesses.emplace_back(op);
      graph.addNode(op);
    }
  });

  // Construct the dependence graph for all the collected accesses.
  unsigned rootDepth = getNestingDepth(root);
  for (const auto &accA : accesses) {
    for (const auto &accB : accesses) {
      if (accA.memref != accB.memref)
        continue;
      // Perform the dependence check at all depths: the common surrounding
      // loops plus the innermost body.
      unsigned numCommonLoops =
          getNumCommonSurroundingLoops(*accA.opInst, *accB.opInst);
      for (unsigned d = rootDepth + 1; d <= numCommonLoops + 1; ++d) {
        if (!noDependence(checkMemrefAccessDependence(accA, accB, d)))
          graph.addEdge(accA.opInst, accB.opInst);
      }
    }
  }
  return graph.hasCycle();
}