1//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements miscellaneous analysis routines for non-loop IR
10// structures.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Affine/Analysis/Utils.h"
15
16#include "mlir/Analysis/Presburger/PresburgerRelation.h"
17#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
18#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
21#include "mlir/Dialect/Utils/StaticValueUtils.h"
22#include "mlir/IR/IntegerSet.h"
23#include "llvm/ADT/SetVector.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/raw_ostream.h"
26#include <optional>
27
28#define DEBUG_TYPE "analysis-utils"
29
30using namespace mlir;
31using namespace affine;
32using namespace presburger;
33
34using llvm::SmallDenseMap;
35
36using Node = MemRefDependenceGraph::Node;
37
// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk(callback: [&](Operation *op) {
    if (auto forOp = dyn_cast<AffineForOp>(Val: op)) {
      // Record every affine.for in the nest.
      forOps.push_back(Elt: forOp);
    } else if (isa<AffineReadOpInterface>(Val: op)) {
      // Affine loads.
      loadOpInsts.push_back(Elt: op);
    } else if (isa<AffineWriteOpInterface>(Val: op)) {
      // Affine stores.
      storeOpInsts.push_back(Elt: op);
    } else {
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(Val: op);
      if (!memInterface) {
        if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
          // This op itself is memory-effect free.
          return;
        // No effect interface and no recursive-effects trait: unknown op.
        // Check operands. Eg. ops like the `call` op are handled here.
        for (Value v : op->getOperands()) {
          if (!isa<MemRefType>(Val: v.getType()))
            continue;
          // Conservatively, we assume the memref is read and written to.
          memrefLoads.push_back(Elt: op);
          memrefStores.push_back(Elt: op);
        }
      } else {
        // Non-affine loads and stores: classify via declared memory effects.
        if (hasEffect<MemoryEffects::Read>(op))
          memrefLoads.push_back(Elt: op);
        if (hasEffect<MemoryEffects::Write>(op))
          memrefStores.push_back(Elt: op);
        if (hasEffect<MemoryEffects::Free>(op))
          memrefFrees.push_back(Elt: op);
      }
    }
  });
}
75
// Returns the number of load ops in this node's `loads` list that access
// 'memref'.
// NOTE(review): unlike getStoreOpCount, this only scans `loads` and does not
// consider `memrefLoads` — confirm this asymmetry is intended.
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    // Common case: affine reads.
    if (auto affineLoad = dyn_cast<AffineReadOpInterface>(Val: loadOp)) {
      if (memref == affineLoad.getMemRef())
        ++loadOpCount;
    } else if (hasEffect<MemoryEffects::Read>(op: loadOp, value: memref)) {
      // Fallback: count non-affine ops via their declared read effect.
      ++loadOpCount;
    }
  }
  return loadOpCount;
}
89
90// Returns the store op count for 'memref'.
91unsigned Node::getStoreOpCount(Value memref) const {
92 unsigned storeOpCount = 0;
93 for (auto *storeOp : llvm::concat<Operation *const>(Ranges: stores, Ranges: memrefStores)) {
94 // Common case: affine writes.
95 if (auto affineStore = dyn_cast<AffineWriteOpInterface>(Val: storeOp)) {
96 if (memref == affineStore.getMemRef())
97 ++storeOpCount;
98 } else if (hasEffect<MemoryEffects::Write>(op: const_cast<Operation *>(storeOp),
99 value: memref)) {
100 ++storeOpCount;
101 }
102 }
103 return storeOpCount;
104}
105
// Returns true if this node has a store (affine or non-affine) to 'memref'.
unsigned Node::hasStore(Value memref) const {
  return llvm::any_of(
      Range: llvm::concat<Operation *const>(Ranges: stores, Ranges: memrefStores),
      P: [&](Operation *storeOp) {
        // Affine writes record their memref directly.
        if (auto affineStore = dyn_cast<AffineWriteOpInterface>(Val: storeOp)) {
          if (memref == affineStore.getMemRef())
            return true;
        } else if (hasEffect<MemoryEffects::Write>(op: storeOp, value: memref)) {
          // Non-affine store: check declared write effect on 'memref'.
          return true;
        }
        return false;
      });
}
120
121unsigned Node::hasFree(Value memref) const {
122 return llvm::any_of(Range: memrefFrees, P: [&](Operation *freeOp) {
123 return hasEffect<MemoryEffects::Free>(op: freeOp, value: memref);
124 });
125}
126
127// Returns all store ops in 'storeOps' which access 'memref'.
128void Node::getStoreOpsForMemref(Value memref,
129 SmallVectorImpl<Operation *> *storeOps) const {
130 for (Operation *storeOp : stores) {
131 if (memref == cast<AffineWriteOpInterface>(Val: storeOp).getMemRef())
132 storeOps->push_back(Elt: storeOp);
133 }
134}
135
136// Returns all load ops in 'loadOps' which access 'memref'.
137void Node::getLoadOpsForMemref(Value memref,
138 SmallVectorImpl<Operation *> *loadOps) const {
139 for (Operation *loadOp : loads) {
140 if (memref == cast<AffineReadOpInterface>(Val: loadOp).getMemRef())
141 loadOps->push_back(Elt: loadOp);
142 }
143}
144
145// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
146// has at least one load and store operation.
147void Node::getLoadAndStoreMemrefSet(
148 DenseSet<Value> *loadAndStoreMemrefSet) const {
149 llvm::SmallDenseSet<Value, 2> loadMemrefs;
150 for (Operation *loadOp : loads) {
151 loadMemrefs.insert(V: cast<AffineReadOpInterface>(Val: loadOp).getMemRef());
152 }
153 for (Operation *storeOp : stores) {
154 auto memref = cast<AffineWriteOpInterface>(Val: storeOp).getMemRef();
155 if (loadMemrefs.count(V: memref) > 0)
156 loadAndStoreMemrefSet->insert(V: memref);
157 }
158}
159
160/// Returns the values that this op has a memref effect of type `EffectTys` on,
161/// not considering recursive effects.
162template <typename... EffectTys>
163static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
164 auto memOp = dyn_cast<MemoryEffectOpInterface>(Val: op);
165 if (!memOp) {
166 if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
167 // No effects.
168 return;
169 // Memref operands have to be considered as being affected.
170 for (Value operand : op->getOperands()) {
171 if (isa<MemRefType>(Val: operand.getType()))
172 values.push_back(Elt: operand);
173 }
174 return;
175 }
176 SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
177 memOp.getEffects(effects);
178 for (auto &effect : effects) {
179 Value effectVal = effect.getValue();
180 if (isa<EffectTys...>(effect.getEffect()) && effectVal &&
181 isa<MemRefType>(Val: effectVal.getType()))
182 values.push_back(Elt: effectVal);
183 };
184}
185
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
/// Returns nullptr if a non-affine access with a non-memref effected value is
/// found, since the interaction cannot be analyzed.
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
  auto &nodes = mdg.nodes;
  // Create graph node 'id' to represent top-level 'forOp' and record
  // all loads and store accesses it contains.
  LoopNestStateCollector collector;
  collector.collect(opToWalk: nodeOp);
  unsigned newNodeId = mdg.nextNodeId++;
  Node &node = nodes.insert(KV: {newNodeId, Node(newNodeId, nodeOp)}).first->second;
  // Affine reads.
  for (Operation *op : collector.loadOpInsts) {
    node.loads.push_back(Elt: op);
    auto memref = cast<AffineReadOpInterface>(Val: op).getMemRef();
    memrefAccesses[memref].insert(X: node.id);
  }
  // Affine writes.
  for (Operation *op : collector.storeOpInsts) {
    node.stores.push_back(Elt: op);
    auto memref = cast<AffineWriteOpInterface>(Val: op).getMemRef();
    memrefAccesses[memref].insert(X: node.id);
  }
  // Non-affine reads: record via effected memref values.
  for (Operation *op : collector.memrefLoads) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Read>(op, values&: effectedValues);
    if (llvm::any_of(Range: ((ValueRange)effectedValues).getTypes(),
                     P: [](Type type) { return !isa<MemRefType>(Val: type); }))
      // We do not know the interaction here.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(X: node.id);
    node.memrefLoads.push_back(Elt: op);
  }
  // Non-affine writes.
  for (Operation *op : collector.memrefStores) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Write>(op, values&: effectedValues);
    if (llvm::any_of(Range: (ValueRange(effectedValues)).getTypes(),
                     P: [](Type type) { return !isa<MemRefType>(Val: type); }))
      // Unanalyzable interaction on a non-memref value.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(X: node.id);
    node.memrefStores.push_back(Elt: op);
  }
  // Deallocations.
  for (Operation *op : collector.memrefFrees) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Free>(op, values&: effectedValues);
    if (llvm::any_of(Range: (ValueRange(effectedValues)).getTypes(),
                     P: [](Type type) { return !isa<MemRefType>(Val: type); }))
      // Unanalyzable interaction on a non-memref value.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(X: node.id);
    node.memrefFrees.push_back(Elt: op);
  }

  return &node;
}
242
// Initializes the dependence graph from the ops in `block`: creates one node
// per top-level op with memory accesses or used results, then adds SSA
// def-use edges and memref dependence edges. Returns false if an op is found
// whose behavior cannot be analyzed.
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Create graph nodes.
  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(Val&: op)) {
      // One node per top-level loop nest.
      Node *node = addNodeToMDG(nodeOp: &op, mdg&: *this, memrefAccesses);
      if (!node)
        return false;
      forToNodeMap[&op] = node->id;
    } else if (isa<AffineReadOpInterface>(Val: op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(Elt: &op);
      auto memref = cast<AffineReadOpInterface>(Val&: op).getMemRef();
      memrefAccesses[memref].insert(X: node.id);
      nodes.insert(KV: {node.id, node});
    } else if (isa<AffineWriteOpInterface>(Val: op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(Elt: &op);
      auto memref = cast<AffineWriteOpInterface>(Val&: op).getMemRef();
      memrefAccesses[memref].insert(X: node.id);
      nodes.insert(KV: {node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node *node = addNodeToMDG(nodeOp: &op, mdg&: *this, memrefAccesses);
      if (!node)
        return false;
    } else if (!isMemoryEffectFree(op: &op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(Val: op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, edges
      // to/from these ops get looked at.
      Node *node = addNodeToMDG(nodeOp: &op, mdg&: *this, memrefAccesses);
      if (!node)
        return false;
    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(Val: op)) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what its regions mean; for e.g., it may
      // not be an imperative op.
      LLVM_DEBUG(llvm::dbgs()
                 << "MDG init failed; unknown region-holding op found!\n");
      return false;
    }
    // We aren't creating nodes for memory-effect free ops either with no
    // regions (unless it has results being used) or those with branch op
    // interface.
  }

  LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(op&: *user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(op&: *user, loops: &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(Range&: loops, P: [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        // The user is not inside any loop nest of this block; no edge needed.
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(srcId: node.id, dstId: userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    Value srcMemRef = memrefAndList.first;
    // Add edges between all dependent pairs among the node IDs on this memref.
    // A pair is dependent if at least one side stores to or frees the memref;
    // read-read pairs get no edge.
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      Node *srcNode = getNode(id: srcId);
      bool srcHasStoreOrFree =
          srcNode->hasStore(memref: srcMemRef) || srcNode->hasFree(memref: srcMemRef);
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        Node *dstNode = getNode(id: dstId);
        bool dstHasStoreOrFree =
            dstNode->hasStore(memref: srcMemRef) || dstNode->hasFree(memref: srcMemRef);
        if (srcHasStoreOrFree || dstHasStoreOrFree)
          addEdge(srcId, dstId, value: srcMemRef);
      }
    }
  }
  return true;
}
354
355// Returns the graph node for 'id'.
356const Node *MemRefDependenceGraph::getNode(unsigned id) const {
357 auto it = nodes.find(Val: id);
358 assert(it != nodes.end());
359 return &it->second;
360}
361
362// Returns the graph node for 'forOp'.
363const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
364 for (auto &idAndNode : nodes)
365 if (idAndNode.second.op == forOp)
366 return &idAndNode.second;
367 return nullptr;
368}
369
370// Adds a node with 'op' to the graph and returns its unique identifier.
371unsigned MemRefDependenceGraph::addNode(Operation *op) {
372 Node node(nextNodeId++, op);
373 nodes.insert(KV: {node.id, node});
374 return node.id;
375}
376
377// Remove node 'id' (and its associated edges) from graph.
378void MemRefDependenceGraph::removeNode(unsigned id) {
379 // Remove each edge in 'inEdges[id]'.
380 if (inEdges.count(Val: id) > 0) {
381 SmallVector<Edge, 2> oldInEdges = inEdges[id];
382 for (auto &inEdge : oldInEdges) {
383 removeEdge(srcId: inEdge.id, dstId: id, value: inEdge.value);
384 }
385 }
386 // Remove each edge in 'outEdges[id]'.
387 if (outEdges.contains(Val: id)) {
388 SmallVector<Edge, 2> oldOutEdges = outEdges[id];
389 for (auto &outEdge : oldOutEdges) {
390 removeEdge(srcId: id, dstId: outEdge.id, value: outEdge.value);
391 }
392 }
393 // Erase remaining node state.
394 inEdges.erase(Val: id);
395 outEdges.erase(Val: id);
396 nodes.erase(Val: id);
397}
398
399// Returns true if node 'id' writes to any memref which escapes (or is an
400// argument to) the block. Returns false otherwise.
401bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
402 const Node *node = getNode(id);
403 for (auto *storeOpInst : node->stores) {
404 auto memref = cast<AffineWriteOpInterface>(Val: storeOpInst).getMemRef();
405 auto *op = memref.getDefiningOp();
406 // Return true if 'memref' is a block argument.
407 if (!op)
408 return true;
409 // Return true if any use of 'memref' does not deference it in an affine
410 // way.
411 for (auto *user : memref.getUsers())
412 if (!isa<AffineMapAccessInterface>(Val: *user))
413 return true;
414 }
415 return false;
416}
417
418// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
419// is for 'value' if non-null, or for any value otherwise. Returns false
420// otherwise.
421bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
422 Value value) const {
423 if (!outEdges.contains(Val: srcId) || !inEdges.contains(Val: dstId)) {
424 return false;
425 }
426 bool hasOutEdge = llvm::any_of(Range: outEdges.lookup(Val: srcId), P: [=](const Edge &edge) {
427 return edge.id == dstId && (!value || edge.value == value);
428 });
429 bool hasInEdge = llvm::any_of(Range: inEdges.lookup(Val: dstId), P: [=](const Edge &edge) {
430 return edge.id == srcId && (!value || edge.value == value);
431 });
432 return hasOutEdge && hasInEdge;
433}
434
435// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
436void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
437 Value value) {
438 if (!hasEdge(srcId, dstId, value)) {
439 outEdges[srcId].push_back(Elt: {.id: dstId, .value: value});
440 inEdges[dstId].push_back(Elt: {.id: srcId, .value: value});
441 if (isa<MemRefType>(Val: value.getType()))
442 memrefEdgeCount[value]++;
443 }
444}
445
446// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
447void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
448 Value value) {
449 assert(inEdges.count(dstId) > 0);
450 assert(outEdges.count(srcId) > 0);
451 if (isa<MemRefType>(Val: value.getType())) {
452 assert(memrefEdgeCount.count(value) > 0);
453 memrefEdgeCount[value]--;
454 }
455 // Remove 'srcId' from 'inEdges[dstId]'.
456 for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
457 if ((*it).id == srcId && (*it).value == value) {
458 inEdges[dstId].erase(CI: it);
459 break;
460 }
461 }
462 // Remove 'dstId' from 'outEdges[srcId]'.
463 for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
464 if ((*it).id == dstId && (*it).value == value) {
465 outEdges[srcId].erase(CI: it);
466 break;
467 }
468 }
469}
470
// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connected are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
                                              unsigned dstId) const {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back(Elt: {srcId, 0});
  Operation *dstOp = getNode(id: dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    // NOTE: this is a reference into the worklist; the edge index is
    // incremented in place below, so entries must not be copied here.
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (!outEdges.contains(Val: idAndIndex.first) ||
        idAndIndex.second == outEdges.lookup(Val: idAndIndex.first).size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    const Edge edge = outEdges.lookup(Val: idAndIndex.first)[idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(other: getNode(id: edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back(Elt: {edge.id, 0});
  }
  return false;
}
506
507// Returns the input edge count for node 'id' and 'memref' from src nodes
508// which access 'memref' with a store operation.
509unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
510 Value memref) const {
511 unsigned inEdgeCount = 0;
512 for (const Edge &inEdge : inEdges.lookup(Val: id)) {
513 if (inEdge.value == memref) {
514 const Node *srcNode = getNode(id: inEdge.id);
515 // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
516 if (srcNode->getStoreOpCount(memref) > 0)
517 ++inEdgeCount;
518 }
519 }
520 return inEdgeCount;
521}
522
523// Returns the output edge count for node 'id' and 'memref' (if non-null),
524// otherwise returns the total output edge count from node 'id'.
525unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
526 Value memref) const {
527 unsigned outEdgeCount = 0;
528 for (const auto &outEdge : outEdges.lookup(Val: id))
529 if (!memref || outEdge.value == memref)
530 ++outEdgeCount;
531 return outEdgeCount;
532}
533
534/// Return all nodes which define SSA values used in node 'id'.
535void MemRefDependenceGraph::gatherDefiningNodes(
536 unsigned id, DenseSet<unsigned> &definingNodes) const {
537 for (const Edge &edge : inEdges.lookup(Val: id))
538 // By definition of edge, if the edge value is a non-memref value,
539 // then the dependence is between a graph node which defines an SSA value
540 // and another graph node which uses the SSA value.
541 if (!isa<MemRefType>(Val: edge.value.getType()))
542 definingNodes.insert(V: edge.id);
543}
544
// Computes and returns an insertion point operation, before which the
// the fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) const {
  // If 'srcId' has no consumers, fusing immediately before 'dstId' is safe.
  if (!outEdges.contains(Val: srcId))
    return getNode(id: dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(id: dstId, definingNodes);
  if (llvm::any_of(Range&: definingNodes,
                   P: [&](unsigned id) { return hasDependencePath(srcId, dstId: id); })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: a defining op with a user in the dst "
                  "loop has dependence from the src loop\n");
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges.lookup(Val: srcId))
    if (outEdge.id != dstId)
      srcDepInsts.insert(Ptr: getNode(id: outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges.lookup(Val: dstId))
    if (inEdge.id != srcId)
      dstDepInsts.insert(Ptr: getNode(id: inEdge.id)->op);

  Operation *srcNodeInst = getNode(id: srcId)->op;
  Operation *dstNodeInst = getNode(id: dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPost' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the
  //    operation insertion point (or return null pointer if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(x: Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(Ptr: op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(Ptr: op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(Elt: op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodInst' insertion point.
  return dstNodeInst;
}
619
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
//    private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  // Iterate over a copy since addEdge mutates the adjacency lists.
  if (inEdges.count(Val: srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (!privateMemRefs.contains(V: inEdge.value))
        addEdge(srcId: inEdge.id, dstId, value: inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(Val: srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, dstId: outEdge.id, value: outEdge.value);
      else if (removeSrcId) {
        addEdge(srcId: dstId, dstId: outEdge.id, value: outEdge.value);
        removeEdge(srcId, dstId: outEdge.id, value: outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(Val: dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(V: inEdge.value) > 0)
        removeEdge(srcId: inEdge.id, dstId, value: inEdge.value);
  }
}
661
662// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
663// of sibling node 'sibId' into node 'dstId'.
664void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
665 // For each edge in 'inEdges[sibId]':
666 // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
667 // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
668 if (inEdges.count(Val: sibId) > 0) {
669 SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
670 for (auto &inEdge : oldInEdges) {
671 addEdge(srcId: inEdge.id, dstId, value: inEdge.value);
672 removeEdge(srcId: inEdge.id, dstId: sibId, value: inEdge.value);
673 }
674 }
675
676 // For each edge in 'outEdges[sibId]' to node 'id'
677 // *) Add new edge from 'dstId' to 'outEdge.id'.
678 // *) Remove edge from 'sibId' to 'outEdge.id'.
679 if (outEdges.count(Val: sibId) > 0) {
680 SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
681 for (auto &outEdge : oldOutEdges) {
682 addEdge(srcId: dstId, dstId: outEdge.id, value: outEdge.value);
683 removeEdge(srcId: sibId, dstId: outEdge.id, value: outEdge.value);
684 }
685 }
686}
687
688// Adds ops in 'loads' and 'stores' to node at 'id'.
689void MemRefDependenceGraph::addToNode(unsigned id, ArrayRef<Operation *> loads,
690 ArrayRef<Operation *> stores,
691 ArrayRef<Operation *> memrefLoads,
692 ArrayRef<Operation *> memrefStores,
693 ArrayRef<Operation *> memrefFrees) {
694 Node *node = getNode(id);
695 llvm::append_range(C&: node->loads, R&: loads);
696 llvm::append_range(C&: node->stores, R&: stores);
697 llvm::append_range(C&: node->memrefLoads, R&: memrefLoads);
698 llvm::append_range(C&: node->memrefStores, R&: memrefStores);
699 llvm::append_range(C&: node->memrefFrees, R&: memrefFrees);
700}
701
// Clears the affine load and store lists of node 'id'. The non-affine access
// lists (memrefLoads/memrefStores/memrefFrees) are left untouched.
void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}
707
708// Calls 'callback' for each input edge incident to node 'id' which carries a
709// memref dependence.
710void MemRefDependenceGraph::forEachMemRefInputEdge(
711 unsigned id, const std::function<void(Edge)> &callback) {
712 if (inEdges.count(Val: id) > 0)
713 forEachMemRefEdge(edges: inEdges[id], callback);
714}
715
716// Calls 'callback' for each output edge from node 'id' which carries a
717// memref dependence.
718void MemRefDependenceGraph::forEachMemRefOutputEdge(
719 unsigned id, const std::function<void(Edge)> &callback) {
720 if (outEdges.count(Val: id) > 0)
721 forEachMemRefEdge(edges: outEdges[id], callback);
722}
723
724// Calls 'callback' for each edge in 'edges' which carries a memref
725// dependence.
726void MemRefDependenceGraph::forEachMemRefEdge(
727 ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
728 for (const auto &edge : edges) {
729 // Skip if 'edge' is not a memref dependence edge.
730 if (!isa<MemRefType>(Val: edge.value.getType()))
731 continue;
732 assert(nodes.count(edge.id) > 0);
733 // Skip if 'edge.id' is not a loop nest.
734 if (!isa<AffineForOp>(Val: getNode(id: edge.id)->op))
735 continue;
736 // Visit current input edge 'edge'.
737 callback(edge);
738 }
739}
740
741void MemRefDependenceGraph::print(raw_ostream &os) const {
742 os << "\nMemRefDependenceGraph\n";
743 os << "\nNodes:\n";
744 for (const auto &idAndNode : nodes) {
745 os << "Node: " << idAndNode.first << "\n";
746 auto it = inEdges.find(Val: idAndNode.first);
747 if (it != inEdges.end()) {
748 for (const auto &e : it->second)
749 os << " InEdge: " << e.id << " " << e.value << "\n";
750 }
751 it = outEdges.find(Val: idAndNode.first);
752 if (it != outEdges.end()) {
753 for (const auto &e : it->second)
754 os << " OutEdge: " << e.id << " " << e.value << "\n";
755 }
756 }
757}
758
759void mlir::affine::getAffineForIVs(Operation &op,
760 SmallVectorImpl<AffineForOp> *loops) {
761 auto *currOp = op.getParentOp();
762 AffineForOp currAffineForOp;
763 // Traverse up the hierarchy collecting all 'affine.for' operation while
764 // skipping over 'affine.if' operations.
765 while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
766 if (auto currAffineForOp = dyn_cast<AffineForOp>(Val: currOp))
767 loops->push_back(Elt: currAffineForOp);
768 currOp = currOp->getParentOp();
769 }
770 std::reverse(first: loops->begin(), last: loops->end());
771}
772
773void mlir::affine::getEnclosingAffineOps(Operation &op,
774 SmallVectorImpl<Operation *> *ops) {
775 ops->clear();
776 Operation *currOp = op.getParentOp();
777
778 // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
779 // affine.parallel operations.
780 while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
781 if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(Val: currOp))
782 ops->push_back(Elt: currOp);
783 currOp = currOp->getParentOp();
784 }
785 std::reverse(first: ops->begin(), last: ops->end());
786}
787
788// Populates 'cst' with FlatAffineValueConstraints which represent original
789// domain of the loop bounds that define 'ivs'.
790LogicalResult ComputationSliceState::getSourceAsConstraints(
791 FlatAffineValueConstraints &cst) const {
792 assert(!ivs.empty() && "Cannot have a slice without its IVs");
793 cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
794 /*numLocals=*/0, ivs);
795 for (Value iv : ivs) {
796 AffineForOp loop = getForInductionVarOwner(val: iv);
797 assert(loop && "Expected affine for");
798 if (failed(Result: cst.addAffineForOpDomain(forOp: loop)))
799 return failure();
800 }
801 return success();
802}
803
// Populates 'cst' with FlatAffineValueConstraints which represent slice bounds.
// The source IVs become dimensions; the destination IVs and symbols (the lb
// operands) become symbols.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(in_start: lbOperands[0].begin(), in_end: lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the destination
  // of fusion and equality constraints for symbols which are constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(ofr: value))
        cst->addBound(type: BoundType::EQ, val: value, value: cOp.value());
    } else if (auto loop = getForInductionVarOwner(val: value)) {
      // The symbol position holds a destination loop IV: add its domain.
      if (failed(Result: cst->addAffineForOpDomain(forOp: loop)))
        return failure();
    }
  }

  // Add slices bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'
  LogicalResult ret = cst->addSliceBounds(values: ivs, lbMaps: lbs, ubMaps: ubs, operands: lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}
840
841// Clears state bounds and operand state.
842void ComputationSliceState::clearBounds() {
843 lbs.clear();
844 ubs.clear();
845 lbOperands.clear();
846 ubOperands.clear();
847}
848
849void ComputationSliceState::dump() const {
850 llvm::errs() << "\tIVs:\n";
851 for (Value iv : ivs)
852 llvm::errs() << "\t\t" << iv << "\n";
853
854 llvm::errs() << "\tLBs:\n";
855 for (auto en : llvm::enumerate(First: lbs)) {
856 llvm::errs() << "\t\t" << en.value() << "\n";
857 llvm::errs() << "\t\tOperands:\n";
858 for (Value lbOp : lbOperands[en.index()])
859 llvm::errs() << "\t\t\t" << lbOp << "\n";
860 }
861
862 llvm::errs() << "\tUBs:\n";
863 for (auto en : llvm::enumerate(First: ubs)) {
864 llvm::errs() << "\t\t" << en.value() << "\n";
865 llvm::errs() << "\t\tOperands:\n";
866 for (Value ubOp : ubOperands[en.index()])
867 llvm::errs() << "\t\t\t" << ubOp << "\n";
868 }
869}
870
/// Fast check to determine if the computation slice is maximal. Returns true if
/// each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if both the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  // Inspect each sliced dimension independently; every one of them must pass
  // all checks below for the slice to be declared maximal.
  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension
    // (ub == lb + 1, i.e., a single-iteration band pinned to the lb expr).
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(idx: 0) + 1 != ubMap.getResult(idx: 0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(Val: lbMap.getResult(idx: 0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(Val: lbMap.getResult(idx: 0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds. The lb dim position indexes into the lb
    // operands to locate the destination IV this slice dim is pinned to.
    AffineForOp dstLoop =
        getForInductionVarOwner(val: lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(val: ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(idx: 0);
    AffineExpr dstLbResult = dstLbMap.getResult(idx: 0);
    AffineExpr srcUbResult = srcUbMap.getResult(idx: 0);
    AffineExpr dstUbResult = dstUbMap.getResult(idx: 0);
    // All four bounds must be constants so they can be compared syntactically.
    if (!isa<AffineConstantExpr>(Val: srcLbResult) ||
        !isa<AffineConstantExpr>(Val: srcUbResult) ||
        !isa<AffineConstantExpr>(Val: dstLbResult) ||
        !isa<AffineConstantExpr>(Val: dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal. The steps must also match for the two
    // iteration spaces to coincide.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}
940
/// Returns true if it is deterministically verified that the original iteration
/// space of the slice is contained within the new iteration space that is
/// created after fusing 'this' slice into its destination. Returns false if the
/// slice domain provably contains points outside the source domain, and
/// std::nullopt if validity cannot be determined.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  // Only a conclusive 'true' short-circuits; 'false'/nullopt fall through to
  // the set-difference based analysis below.
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest using which slice is computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(Result: getSourceAsConstraints(cst&: srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(Result: getAsConstraints(cst: &sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return std::nullopt;
  }

  // Projecting out every dimension other than the 'ivs' to express slice's
  // domain completely in terms of source's IVs.
  sliceConstraints.projectOut(pos: ivs.size(),
                              num: sliceConstraints.getNumVars() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // The slice is valid iff slice \ source is empty, i.e., the slice domain is
  // contained in the source domain.
  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(set: srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}
1010
/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced: one dimension per
  // slice IV, intersected with each source loop's domain.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(val: iv);
    assert(loop && "Expected affine for");
    if (failed(Result: srcConstraints.addAffineForOpDomain(forOp: loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(val: lbOp))
      consumerIVs.push_back(Elt: lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Elt: Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  // Express the slice domain through the lb/ub maps over the dst operands.
  if (failed(Result: sliceConstraints.addDomainFromSliceMaps(lbMaps: lbs, ubMaps: ubs, operands: lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets: the slice is maximal iff src \ slice has no integer points.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(set: sliceSet);
  return diffSet.isIntegerEmpty();
}
1063
1064unsigned MemRefRegion::getRank() const {
1065 return cast<MemRefType>(Val: memref.getType()).getRank();
1066}
1067
/// Returns the number of elements in a constant-size bounding box of this
/// region (the product of per-dimension constant extents), or std::nullopt if
/// no constant bound exists for some dimension. If non-null, 'shape' receives
/// the per-dimension extents and 'lbs' the corresponding lower-bound maps.
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
  auto memRefType = cast<MemRefType>(Val: memref.getType());
  MLIRContext *context = memref.getContext();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(N: rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivials means to eliminate.
  FlatLinearValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    // Every dimension is clamped to [0, dimSize - 1] when the size is static.
    cstWithShapeBounds.addBound(type: BoundType::LB, pos: r, value: 0);
    int64_t dimSize = memRefType.getDimSize(idx: r);
    if (ShapedType::isDynamic(dValue: dimSize))
      continue;
    cstWithShapeBounds.addBound(type: BoundType::UB, pos: r, value: dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  for (unsigned d = 0; d < rank; d++) {
    AffineMap lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize(context, pos: d, lb: &lb);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "dim size bound cannot be negative");
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(idx: d);
      if (ShapedType::isDynamic(dValue: dimSize))
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb = AffineMap::get(/*dimCount=*/0, symbolCount: cstWithShapeBounds.getNumSymbolVars(),
                          /*result=*/getAffineConstantExpr(constant: 0, context));
    }
    // The bounding box size is the product of the per-dimension extents.
    numElements *= diffConstant;
    // Populate outputs if available.
    if (lbs)
      lbs->push_back(Elt: lb);
    if (shape)
      shape->push_back(Elt: diffConstant);
  }
  return numElements;
}
1123
1124void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
1125 AffineMap &ubMap) const {
1126 assert(pos < cst.getNumDimVars() && "invalid position");
1127 auto memRefType = cast<MemRefType>(Val: memref.getType());
1128 unsigned rank = memRefType.getRank();
1129
1130 assert(rank == cst.getNumDimVars() && "inconsistent memref region");
1131
1132 auto boundPairs = cst.getLowerAndUpperBound(
1133 pos, /*offset=*/0, /*num=*/rank, symStartPos: cst.getNumDimAndSymbolVars(),
1134 /*localExprs=*/{}, context: memRefType.getContext());
1135 lbMap = boundPairs.first;
1136 ubMap = boundPairs.second;
1137 assert(lbMap && "lower bound for a region must exist");
1138 assert(ubMap && "upper bound for a region must exist");
1139 assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
1140 assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
1141}
1142
1143LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
1144 assert(memref == other.memref);
1145 return cst.unionBoundingBox(other: *other.getConstraints());
1146}
1147
/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds, bool dropLocalVars,
                                    bool dropOuterIvs) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << "\ndepth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(op&: *op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(N: loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(accessMap: &accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(N: numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(N: operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols (deduplicated against
    // the access operands already present).
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(Range&: operands, Element: extraOperand)) {
        operands.push_back(Elt: extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed out.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(val: operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(Result: cst.addAffineForOpDomain(forOp: affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(val: operand)) {
      if (failed(Result: cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(value: operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(ofr: symbol))
        cst.addBound(type: BoundType::EQ, val: symbol, value: constVal.value());
    } else {
      LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      if (failed(Result: cst.addInductionVarOrTerminalSymbol(val: operand)))
        return failure();
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(values: sliceState->ivs, lbMaps: sliceState->lbs, ubMaps: sliceState->ubs,
                           operands: sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(Result: cst.composeMap(vMap: &accessValueMap))) {
    op->emitError(message: "getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(op&: *op, ivs&: enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(N: loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(start: cst.getNumDimVars(), end: cst.getNumDimAndSymbolVars(), values: &vars);
  for (auto en : llvm::enumerate(First&: vars)) {
    if ((isAffineInductionVar(val: en.value())) &&
        !llvm::is_contained(Range&: enclosingIVs, Element: en.value())) {
      if (dropOuterIvs) {
        // Remove the IV entirely (existentially quantify it away).
        cst.projectOut(val: en.value());
      } else {
        // Keep the IV's effect but demote it to a local (existential)
        // variable instead of a named symbol.
        unsigned varPosition;
        cst.findVar(val: en.value(), pos: &varPosition);
        auto varKind = cst.getVarKindAt(pos: varPosition);
        varPosition -= cst.getNumDimVars();
        cst.convertToLocal(kind: varKind, varStart: varPosition, varLimit: varPosition + 1);
      }
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs) if specified.
  if (dropLocalVars)
    cst.projectOut(pos: cst.getNumDimAndSymbolVars(), num: cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(Val: memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(type: BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(idx: r))
        continue;
      cst.addBound(type: BoundType::UB, /*pos=*/r, value: memRefType.getDimSize(idx: r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}
1327
1328std::optional<int64_t>
1329mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
1330 auto elementType = memRefType.getElementType();
1331
1332 unsigned sizeInBits;
1333 if (elementType.isIntOrFloat()) {
1334 sizeInBits = elementType.getIntOrFloatBitWidth();
1335 } else if (auto vectorType = dyn_cast<VectorType>(Val&: elementType)) {
1336 if (vectorType.getElementType().isIntOrFloat())
1337 sizeInBits =
1338 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
1339 else
1340 return std::nullopt;
1341 } else {
1342 return std::nullopt;
1343 }
1344 return llvm::divideCeil(Numerator: sizeInBits, Denominator: 8);
1345}
1346
1347// Returns the size of the region.
1348std::optional<int64_t> MemRefRegion::getRegionSize() {
1349 auto memRefType = cast<MemRefType>(Val: memref.getType());
1350
1351 if (!memRefType.getLayout().isIdentity()) {
1352 LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
1353 return false;
1354 }
1355
1356 // Compute the extents of the buffer.
1357 std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
1358 if (!numElements) {
1359 LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
1360 return std::nullopt;
1361 }
1362 auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
1363 if (!eltSize)
1364 return std::nullopt;
1365 return *eltSize * *numElements;
1366}
1367
1368/// Returns the size of memref data in bytes if it's statically shaped,
1369/// std::nullopt otherwise. If the element of the memref has vector type, takes
1370/// into account size of the vector as well.
1371// TODO: improve/complete this when we have target data.
1372std::optional<uint64_t>
1373mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
1374 if (!memRefType.hasStaticShape())
1375 return std::nullopt;
1376 auto elementType = memRefType.getElementType();
1377 if (!elementType.isIntOrFloat() && !isa<VectorType>(Val: elementType))
1378 return std::nullopt;
1379
1380 auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
1381 if (!sizeInBytes)
1382 return std::nullopt;
1383 for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
1384 sizeInBytes = *sizeInBytes * memRefType.getDimSize(idx: i);
1385 }
1386 return sizeInBytes;
1387}
1388
1389template <typename LoadOrStoreOp>
1390LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
1391 bool emitError) {
1392 static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
1393 AffineWriteOpInterface>::value,
1394 "argument should be either a AffineReadOpInterface or a "
1395 "AffineWriteOpInterface");
1396
1397 Operation *op = loadOrStoreOp.getOperation();
1398 MemRefRegion region(op->getLoc());
1399 if (failed(Result: region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
1400 /*addMemRefDimBounds=*/false)))
1401 return success();
1402
1403 LLVM_DEBUG(llvm::dbgs() << "Memory region");
1404 LLVM_DEBUG(region.getConstraints()->dump());
1405
1406 bool outOfBounds = false;
1407 unsigned rank = loadOrStoreOp.getMemRefType().getRank();
1408
1409 // For each dimension, check for out of bounds.
1410 for (unsigned r = 0; r < rank; r++) {
1411 FlatAffineValueConstraints ucst(*region.getConstraints());
1412
1413 // Intersect memory region with constraint capturing out of bounds (both out
1414 // of upper and out of lower), and check if the constraint system is
1415 // feasible. If it is, there is at least one point out of bounds.
1416 SmallVector<int64_t, 4> ineq(rank + 1, 0);
1417 int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
1418 // TODO: handle dynamic dim sizes.
1419 if (dimSize == -1)
1420 continue;
1421
1422 // Check for overflow: d_i >= memref dim size.
1423 ucst.addBound(type: BoundType::LB, pos: r, value: dimSize);
1424 outOfBounds = !ucst.isEmpty();
1425 if (outOfBounds && emitError) {
1426 loadOrStoreOp.emitOpError()
1427 << "memref out of upper bound access along dimension #" << (r + 1);
1428 }
1429
1430 // Check for a negative index.
1431 FlatAffineValueConstraints lcst(*region.getConstraints());
1432 llvm::fill(Range&: ineq, Value: 0);
1433 // d_i <= -1;
1434 lcst.addBound(type: BoundType::UB, pos: r, value: -1);
1435 outOfBounds = !lcst.isEmpty();
1436 if (outOfBounds && emitError) {
1437 loadOrStoreOp.emitOpError()
1438 << "memref out of lower bound access along dimension #" << (r + 1);
1439 }
1440 }
1441 return failure(IsFailure: outOfBounds);
1442}
1443
// Explicitly instantiate the template so that the compiler knows we need them!
// These are the only two argument types permitted by the static_assert inside
// boundCheckLoadOrStoreOp.
template LogicalResult
mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
                                      bool emitError);
template LogicalResult
mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
                                      bool emitError);
1451
1452// Returns in 'positions' the Block positions of 'op' in each ancestor
1453// Block from the Block containing operation, stopping at 'limitBlock'.
1454static void findInstPosition(Operation *op, Block *limitBlock,
1455 SmallVectorImpl<unsigned> *positions) {
1456 Block *block = op->getBlock();
1457 while (block != limitBlock) {
1458 // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1459 // rely on linear scans.
1460 int instPosInBlock = std::distance(first: block->begin(), last: op->getIterator());
1461 positions->push_back(Elt: instPosInBlock);
1462 op = block->getParentOp();
1463 block = op->getBlock();
1464 }
1465 std::reverse(first: positions->begin(), last: positions->end());
1466}
1467
// Returns the Operation in a possibly nested set of Blocks, where the
// position of the operation is represented by 'positions', which has a
// Block position for each level of nesting. Returns nullptr if no operation
// is found at the encoded position.
static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
                                    unsigned level, Block *block) {
  unsigned i = 0;
  for (auto &op : *block) {
    // Linearly scan until reaching the position recorded for this level.
    if (i != positions[level]) {
      ++i;
      continue;
    }
    // Last level: this op is the one being looked up.
    if (level == positions.size() - 1)
      return &op;
    // Descend directly into a loop body for affine.for.
    if (auto childAffineForOp = dyn_cast<AffineForOp>(Val&: op))
      return getInstAtPosition(positions, level: level + 1,
                               block: childAffineForOp.getBody());

    // Otherwise, search every block of every region held by this op.
    for (auto &region : op.getRegions()) {
      for (auto &b : region)
        if (auto *ret = getInstAtPosition(positions, level: level + 1, block: &b))
          return ret;
    }
    // Found the op at this level but nothing matched underneath it.
    return nullptr;
  }
  return nullptr;
}
1494
1495// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1496static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
1497 FlatAffineValueConstraints *cst) {
1498 for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1499 auto value = cst->getValue(pos: i);
1500 if (ivs.count(Ptr: value) == 0) {
1501 assert(isAffineForInductionVar(value));
1502 auto loop = getForInductionVarOwner(val: value);
1503 if (failed(Result: cst->addAffineForOpDomain(forOp: loop)))
1504 return failure();
1505 }
1506 }
1507 return success();
1508}
1509
1510/// Returns the innermost common loop depth for the set of operations in 'ops'.
1511// TODO: Move this to LoopUtils.
1512unsigned mlir::affine::getInnermostCommonLoopDepth(
1513 ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1514 unsigned numOps = ops.size();
1515 assert(numOps > 0 && "Expected at least one operation");
1516
1517 std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1518 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1519 for (unsigned i = 0; i < numOps; ++i) {
1520 getAffineForIVs(op&: *ops[i], loops: &loops[i]);
1521 loopDepthLimit =
1522 std::min(a: loopDepthLimit, b: static_cast<unsigned>(loops[i].size()));
1523 }
1524
1525 unsigned loopDepth = 0;
1526 for (unsigned d = 0; d < loopDepthLimit; ++d) {
1527 unsigned i;
1528 for (i = 1; i < numOps; ++i) {
1529 if (loops[i - 1][d] != loops[i][d])
1530 return loopDepth;
1531 }
1532 if (surroundingLoops)
1533 surroundingLoops->push_back(Elt: loops[i - 1][d]);
1534 ++loopDepth;
1535 }
1536 return loopDepth;
1537}
1538
/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
/// union was computed correctly, an appropriate failure otherwise.
SliceComputationResult
mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
                                ArrayRef<Operation *> opsB, unsigned loopDepth,
                                unsigned numCommonLoops, bool isBackwardSlice,
                                ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineValueConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (Operation *i : opsA) {
    MemRefAccess srcAccess(i);
    for (Operation *j : opsB) {
      MemRefAccess dstAccess(j);
      // Accesses to different memrefs cannot be dependent.
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(op: i)) ||
          (isBackwardSlice && loopDepth > getNestingDepth(op: j))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
        return SliceComputationResult::GenericFailure;
      }

      // Read-read pairs are tracked so that RAR dependences can be allowed in
      // the dependence check below.
      bool readReadAccesses = isa<AffineReadOpInterface>(Val: srcAccess.opInst) &&
                              isa<AffineReadOpInterface>(Val: dstAccess.opInst);
      FlatAffineValueConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          dependenceConstraints: &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
        return SliceComputationResult::GenericFailure;
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      dependentOpPairs.emplace_back(args&: i, args&: j);

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::affine::getComputationSliceState(depSourceOp: i, depSinkOp: j, dependenceConstraints,
                                             loopDepth, isBackwardSlice,
                                             sliceState: &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
        if (failed(Result: tmpSliceState.getAsConstraints(cst: &sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints\n");
          return SliceComputationResult::GenericFailure;
        }
        assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineValueConstraints tmpSliceCst;
      if (failed(Result: tmpSliceState.getAsConstraints(cst: &tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints\n");
        return SliceComputationResult::GenericFailure;
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
      if (!sliceUnionCst.areVarsAlignedWithOther(other: tmpSliceCst)) {

        // Pre-constraint var alignment: record loop IVs used in each constraint
        // system.
        SmallPtrSet<Value, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
          sliceUnionIVs.insert(Ptr: sliceUnionCst.getValue(pos: k));
        SmallPtrSet<Value, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
          tmpSliceIVs.insert(Ptr: tmpSliceCst.getValue(pos: k));

        sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, other: &tmpSliceCst);

        // Post-constraint var alignment: add loop IV bounds missing after
        // var alignment to constraint systems. This can occur if one constraint
        // system uses an loop IV that is not used by the other. The call
        // to unionBoundingBox below expects constraints for each Loop IV, even
        // if they are the unsliced full loop bounds added here.
        if (failed(Result: addMissingLoopIVBounds(ivs&: sliceUnionIVs, cst: &sliceUnionCst)))
          return SliceComputationResult::GenericFailure;
        if (failed(Result: addMissingLoopIVBounds(ivs&: tmpSliceIVs, cst: &tmpSliceCst)))
          return SliceComputationResult::GenericFailure;
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      // Bail out if either system has local (existential) variables; they are
      // not handled by the union computation here.
      if (sliceUnionCst.getNumLocalVars() > 0 ||
          tmpSliceCst.getNumLocalVars() > 0 ||
          failed(Result: sliceUnionCst.unionBoundingBox(other: tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return SliceComputationResult::GenericFailure;
      }
    }
  }

  // Empty union.
  if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
    LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n");
    return SliceComputationResult::GenericFailure;
  }

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(Elt: isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, surroundingLoops: &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return SliceComputationResult::GenericFailure;
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();

  // Convert any dst loop IVs which are symbol variables to dim variables.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(N: numSliceLoopIVs, NV: AffineMap());
  sliceUnion->ubs.resize(N: numSliceLoopIVs, NV: AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, num: numSliceLoopIVs,
                               context: opsA[0]->getContext(), lbMaps: &sliceUnion->lbs,
                               ubMaps: &sliceUnion->ubs);

  // Add slice bound operands of union.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getValues(start: numSliceLoopIVs,
                          end: sliceUnionCst.getNumDimAndSymbolVars(),
                          values: &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getValues(start: 0, end: numSliceLoopIVs, values: &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth' for forward
  // slices, while at the end for backward slices.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(x: surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
  sliceUnion->ubOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);

  // Check if the slice computed is valid. Return success only if it is verified
  // that the slice is valid, otherwise return appropriate failure status.
  std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
  if (!isSliceValid) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
    return SliceComputationResult::GenericFailure;
  }
  if (!*isSliceValid)
    return SliceComputationResult::IncorrectSliceFailure;

  return SliceComputationResult::Success;
}
1709
1710// TODO: extend this to handle multiple result maps.
1711static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1712 AffineMap ubMap) {
1713 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1714 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1715 assert(lbMap.getNumDims() == ubMap.getNumDims());
1716 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1717 AffineExpr lbExpr(lbMap.getResult(idx: 0));
1718 AffineExpr ubExpr(ubMap.getResult(idx: 0));
1719 auto loopSpanExpr = simplifyAffineExpr(expr: ubExpr - lbExpr, numDims: lbMap.getNumDims(),
1720 numSymbols: lbMap.getNumSymbols());
1721 auto cExpr = dyn_cast<AffineConstantExpr>(Val&: loopSpanExpr);
1722 if (!cExpr)
1723 return std::nullopt;
1724 return cExpr.getValue();
1725}
1726
1727// Builds a map 'tripCountMap' from AffineForOp to constant trip count for loop
1728// nest surrounding represented by slice loop bounds in 'slice'. Returns true
1729// on success, false otherwise (if a non-constant trip count was encountered).
1730// TODO: Make this work with non-unit step loops.
1731bool mlir::affine::buildSliceTripCountMap(
1732 const ComputationSliceState &slice,
1733 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1734 unsigned numSrcLoopIVs = slice.ivs.size();
1735 // Populate map from AffineForOp -> trip count
1736 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1737 AffineForOp forOp = getForInductionVarOwner(val: slice.ivs[i]);
1738 auto *op = forOp.getOperation();
1739 AffineMap lbMap = slice.lbs[i];
1740 AffineMap ubMap = slice.ubs[i];
1741 // If lower or upper bound maps are null or provide no results, it implies
1742 // that source loop was not at all sliced, and the entire loop will be a
1743 // part of the slice.
1744 if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1745 ubMap.getNumResults() == 0) {
1746 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1747 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1748 (*tripCountMap)[op] =
1749 forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1750 continue;
1751 }
1752 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1753 if (maybeConstTripCount.has_value()) {
1754 (*tripCountMap)[op] = *maybeConstTripCount;
1755 continue;
1756 }
1757 return false;
1758 }
1759 std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1760 // Slice bounds are created with a constant ub - lb difference.
1761 if (!tripCount.has_value())
1762 return false;
1763 (*tripCountMap)[op] = *tripCount;
1764 }
1765 return true;
1766}
1767
1768// Return the number of iterations in the given slice.
1769uint64_t mlir::affine::getSliceIterationCount(
1770 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1771 uint64_t iterCount = 1;
1772 for (const auto &count : sliceTripCountMap) {
1773 iterCount *= count.second;
1774 }
1775 return iterCount;
1776}
1777
// Attribute marking an 'affine.for' op as a barrier for slicing during fusion:
// slice bounds are cleared from the first loop carrying this attribute onward
// (see getComputationSliceState).
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
// bounds in 'sliceState' which represent the one loop nest's IVs in terms of
// the other loop nest's IVs, symbols and constants (using 'isBackwardsSlice').
void mlir::affine::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(op&: *depSourceOp, loops: &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(op&: *depSinkOp, loops: &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'. The src IVs
  // occupy the first 'numSrcLoopIVs' columns of the dependence system, so for
  // backward slices the dst IVs beyond 'loopDepth' are projected out, and for
  // forward slices the src IVs beyond 'loopDepth'.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  FlatAffineValueConstraints sliceCst(dependenceConstraints);
  sliceCst.projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  sliceCst.getValues(start: offset, end: offset + numSliceLoopIVs, values: &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(N: numSliceLoopIVs, NV: AffineMap());
  sliceState->ubs.resize(N: numSliceLoopIVs, NV: AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  sliceCst.getSliceBounds(offset, num: numSliceLoopIVs, context: depSourceOp->getContext(),
                          lbMaps: &sliceState->lbs, ubMaps: &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds: every
  // dim/symbol value except the slice IVs themselves.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = sliceCst.getNumDimAndSymbolVars();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs)
      sliceBoundOperands.push_back(Elt: sliceCst.getValue(pos: i));
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
  sliceState->ubOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);

  // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(x: srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(Val: depSourceOp) &&
      isa<AffineReadOpInterface>(Val: depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(forOp: isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       sequentialLoops: &sequentialLoops);
  }
  // Returns the i-th loop of the nest the slice is computed from.
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  auto isInnermostInsertion = [&]() {
    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                            : loopDepth >= dstLoopIVs.size());
  };
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(slice: *sliceState, tripCountMap: &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Clear all sliced loop bounds beginning at the first sequential loop, or
  // first loop with a slice fusion barrier attribute.

  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(V: iv) == 0 &&
        getSliceLoop(i)->getAttr(name: kSliceFusionBarrierAttrName) == nullptr)
      continue;
    // Skip reset of bounds of reduction loop inserted in the destination loop
    // that meets the following conditions:
    // 1. Slice is single trip count.
    // 2. Loop bounds of the source and destination match.
    // 3. Is being inserted at the innermost insertion point.
    std::optional<bool> isMaximal = sliceState->isMaximal();
    if (isLoopParallelAndContainsReduction(forOp: getSliceLoop(i)) &&
        isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}
1882
/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'. Returns the root of the cloned
/// slice loop nest, or a null AffineForOp on failure.
// TODO: extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
// TODO: the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from destination index set and check for emptiness --- this is one
// solution.
AffineForOp mlir::affine::insertBackwardComputationSlice(
    Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
    ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(op&: *srcOpInst, loops: &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(op&: *dstOpInst, loops: &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  // NOTE(review): this assumes dstLoopDepth >= 1; a zero depth would underflow
  // the 'dstLoopIVs[dstLoopDepth - 1]' access below — confirm callers ensure
  // this.
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError(message: "invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(op: srcOpInst, limitBlock: srcLoopIVs[0]->getBlock(), positions: &positions);

  // Clone src loop nest and insert it at the beginning of the operation block
  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(Val: b.clone(op&: *srcLoopIVs[0].getOperation()));

  // Locate the clone of 'srcOpInst' inside the cloned nest using the recorded
  // block positions.
  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, block: sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getAffineForIVs(op&: *sliceInst, loops: &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'. Only non-null maps in
  // 'sliceState' override the cloned loops' original bounds.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(operands: sliceState->lbOperands[i], map: lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(operands: sliceState->ubOperands[i], map: ubMap);
  }
  return sliceLoopNest;
}
1948
1949// Constructs MemRefAccess populating it with the memref, its indices and
1950// opinst from 'loadOrStoreOpInst'.
1951MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1952 if (auto loadOp = dyn_cast<AffineReadOpInterface>(Val: loadOrStoreOpInst)) {
1953 memref = loadOp.getMemRef();
1954 opInst = loadOrStoreOpInst;
1955 llvm::append_range(C&: indices, R: loadOp.getMapOperands());
1956 } else {
1957 assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1958 "Affine read/write op expected");
1959 auto storeOp = cast<AffineWriteOpInterface>(Val: loadOrStoreOpInst);
1960 opInst = loadOrStoreOpInst;
1961 memref = storeOp.getMemRef();
1962 llvm::append_range(C&: indices, R: storeOp.getMapOperands());
1963 }
1964}
1965
1966unsigned MemRefAccess::getRank() const {
1967 return cast<MemRefType>(Val: memref.getType()).getRank();
1968}
1969
/// Returns true if this access corresponds to an affine write op.
bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(Val: opInst);
}
1973
1974/// Returns the nesting depth of this statement, i.e., the number of loops
1975/// surrounding this statement.
1976unsigned mlir::affine::getNestingDepth(Operation *op) {
1977 Operation *currOp = op;
1978 unsigned depth = 0;
1979 while ((currOp = currOp->getParentOp())) {
1980 if (isa<AffineForOp>(Val: currOp))
1981 depth++;
1982 if (auto parOp = dyn_cast<AffineParallelOp>(Val: currOp))
1983 depth += parOp.getNumDims();
1984 }
1985 return depth;
1986}
1987
1988/// Equal if both affine accesses are provably equivalent (at compile
1989/// time) when considering the memref, the affine maps and their respective
1990/// operands. The equality of access functions + operands is checked by
1991/// subtracting fully composed value maps, and then simplifying the difference
1992/// using the expression flattener.
1993/// TODO: this does not account for aliasing of memrefs.
1994bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1995 if (memref != rhs.memref)
1996 return false;
1997
1998 AffineValueMap diff, thisMap, rhsMap;
1999 getAccessMap(accessMap: &thisMap);
2000 rhs.getAccessMap(accessMap: &rhsMap);
2001 return thisMap == rhsMap;
2002}
2003
2004void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
2005 auto *currOp = op.getParentOp();
2006 AffineForOp currAffineForOp;
2007 // Traverse up the hierarchy collecting all 'affine.for' and affine.parallel
2008 // operation while skipping over 'affine.if' operations.
2009 while (currOp) {
2010 if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(Val: currOp))
2011 ivs.push_back(Elt: currAffineForOp.getInductionVar());
2012 else if (auto parOp = dyn_cast<AffineParallelOp>(Val: currOp))
2013 llvm::append_range(C&: ivs, R: parOp.getIVs());
2014 currOp = currOp->getParentOp();
2015 }
2016 std::reverse(first: ivs.begin(), last: ivs.end());
2017}
2018
2019/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
2020/// where each lists loops from outer-most to inner-most in loop nest.
2021unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
2022 Operation &b) {
2023 SmallVector<Value, 4> loopsA, loopsB;
2024 getAffineIVs(op&: a, ivs&: loopsA);
2025 getAffineIVs(op&: b, ivs&: loopsB);
2026
2027 unsigned minNumLoops = std::min(a: loopsA.size(), b: loopsB.size());
2028 unsigned numCommonLoops = 0;
2029 for (unsigned i = 0; i < minNumLoops; ++i) {
2030 if (loopsA[i] != loopsB[i])
2031 break;
2032 ++numCommonLoops;
2033 }
2034 return numCommonLoops;
2035}
2036
/// Accumulates, over all affine read/write ops in ['start', 'end') of 'block',
/// the union of accessed memref regions, and returns the total footprint in
/// bytes; returns std::nullopt if any region cannot be computed or sized.
// NOTE(review): 'memorySpace' is currently unused here — regions of all memory
// spaces are counted. Confirm whether filtering by memory space was intended.
static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                      Block::iterator start,
                                                      Block::iterator end,
                                                      int memorySpace) {
  // One region per distinct memref, owning the unioned bounding box.
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this 'affine.for' operation to gather all memory regions.
  auto result = block.walk(begin: start, end, callback: [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(Val: opInst)) {
      // Neither load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(args: opInst->getLoc());
    if (failed(
            Result: region->compute(op: opInst,
                            /*loopDepth=*/getNestingDepth(op: &*block.begin())))) {
      LLVM_DEBUG(opInst->emitError("error obtaining memory region"));
      // A failed result interrupts the walk (checked via wasInterrupted below).
      return failure();
    }

    // Merge regions of the same memref by unioning their bounding boxes.
    auto [it, inserted] = regions.try_emplace(Key: region->memref);
    if (inserted) {
      it->second = std::move(region);
    } else if (failed(Result: it->second->unionBoundingBox(other: *region))) {
      LLVM_DEBUG(opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region"));
      return failure();
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return std::nullopt;

  // Sum the sizes of all unioned regions; bail out if any size is unknown.
  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    std::optional<int64_t> size = region.second->getRegionSize();
    if (!size.has_value())
      return std::nullopt;
    totalSizeInBytes += *size;
  }
  return totalSizeInBytes;
}
2082
2083std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
2084 int memorySpace) {
2085 auto *forInst = forOp.getOperation();
2086 return ::getMemoryFootprintBytes(
2087 block&: *forInst->getBlock(), start: Block::iterator(forInst),
2088 end: std::next(x: Block::iterator(forInst)), memorySpace);
2089}
2090
2091/// Returns whether a loop is parallel and contains a reduction loop.
2092bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
2093 SmallVector<LoopReduction> reductions;
2094 if (!isLoopParallel(forOp, parallelReductions: &reductions))
2095 return false;
2096 return !reductions.empty();
2097}
2098
2099/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
2100/// at 'forOp'.
2101void mlir::affine::getSequentialLoops(
2102 AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
2103 forOp->walk(callback: [&](Operation *op) {
2104 if (auto innerFor = dyn_cast<AffineForOp>(Val: op))
2105 if (!isLoopParallel(forOp: innerFor))
2106 sequentialLoops->insert(V: innerFor.getInductionVar());
2107 });
2108}
2109
2110IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
2111 FlatAffineValueConstraints fac(set);
2112 if (fac.isEmpty())
2113 return IntegerSet::getEmptySet(numDims: set.getNumDims(), numSymbols: set.getNumSymbols(),
2114 context: set.getContext());
2115 fac.removeTrivialRedundancy();
2116
2117 auto simplifiedSet = fac.getAsIntegerSet(context: set.getContext());
2118 assert(simplifiedSet && "guaranteed to succeed while roundtripping");
2119 return simplifiedSet;
2120}
2121
2122static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
2123 SmallVector<Value> &target) {
2124 target =
2125 llvm::to_vector<4>(Range: llvm::map_range(C&: source, F: [](std::optional<Value> val) {
2126 return val.has_value() ? *val : Value();
2127 }));
2128}
2129
2130/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2131/// constraints drawn from an affine map. Before adding the constraint, the
2132/// dimensions/symbols of the affine map are aligned with `constraints`.
2133/// `operands` are the SSA Value operands used with the affine map.
2134/// Note: This function adds a new symbol column to the `constraints` for each
2135/// dimension/symbol that exists in the affine map but not in `constraints`.
2136static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2137 BoundType type, unsigned pos,
2138 AffineMap map, ValueRange operands) {
2139 SmallVector<Value> dims, syms, newSyms;
2140 unpackOptionalValues(source: constraints.getMaybeValues(kind: VarKind::SetDim), target&: dims);
2141 unpackOptionalValues(source: constraints.getMaybeValues(kind: VarKind::Symbol), target&: syms);
2142
2143 AffineMap alignedMap =
2144 alignAffineMapWithValues(map, operands, dims, syms, newSyms: &newSyms);
2145 for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2146 constraints.appendSymbolVar(vals: newSyms[i]);
2147 return constraints.addBound(type, pos, boundMap: alignedMap);
2148}
2149
2150/// Add `val` to each result of `map`.
2151static AffineMap addConstToResults(AffineMap map, int64_t val) {
2152 SmallVector<AffineExpr> newResults;
2153 for (AffineExpr r : map.getResults())
2154 newResults.push_back(Elt: r + val);
2155 return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults,
2156 context: map.getContext());
2157}
2158
2159// Attempt to simplify the given min/max operation by proving that its value is
2160// bounded by the same lower and upper bound.
2161//
2162// Bounds are computed by FlatAffineValueConstraints. Invariants required for
2163// finding/proving bounds should be supplied via `constraints`.
2164//
2165// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2166// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2167// case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2168// `op` but are not part of `constraints`, are added as extra symbols.
2169// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2170// * If `isMin`: r_i >= opBound
2171// * If `isMax`: r_i <= opBound
2172// If this is the case, ub(op) == lb(op).
2173// 4. Replace `op` with `opBound`.
2174//
2175// In summary, the following constraints are added throughout this function.
2176// Note: `invar` are dimensions added by the caller to express the invariants.
2177// (Showing only the case where `isMin`.)
2178//
2179// invar | op | opBound | r_i | extra syms... | const | eq/ineq
2180// ------+-------+---------+-----+---------------+-------+-------------------
2181// (various eq./ineq. constraining `invar`, added by the caller)
2182// ... | 0 | 0 | 0 | 0 | ... | ...
2183// ------+-------+---------+-----+---------------+-------+-------------------
2184// (various ineq. constraining `op` in terms of `op` operands (`invar` and
2185// extra `op` operands "extra syms" that are not in `invar`)).
2186// ... | -1 | 0 | 0 | ... | ... | >= 0
2187// ------+-------+---------+-----+---------------+-------+-------------------
2188// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2189// ... | 0 | -1 | 0 | ... | ... | = 0
2190// ------+-------+---------+-----+---------------+-------+-------------------
2191// (for each `op` map result r_i: set r_i to corresponding map result,
2192// prove that r_i >= minOpUb via contradiction)
2193// ... | 0 | 0 | -1 | ... | ... | = 0
2194// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2195//
FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
    Operation *op, FlatAffineValueConstraints constraints) {
  bool isMin = isa<AffineMinOp>(Val: op);
  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
  MLIRContext *ctx = op->getContext();
  // Builder used below to construct replacement constant/dim/symbol exprs.
  Builder builder(ctx);
  AffineMap map =
      isMin ? cast<AffineMinOp>(Val: op).getMap() : cast<AffineMaxOp>(Val: op).getMap();
  ValueRange operands = op->getOperands();
  unsigned numResults = map.getNumResults();

  // Add a few extra dimensions.
  unsigned dimOp = constraints.appendDimVar(); // `op`
  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);

  // Add an inequality for each result expr_i of map:
  // isMin: op <= expr_i, !isMin: op >= expr_i
  auto boundType = isMin ? BoundType::UB : BoundType::LB;
  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
  AffineMap mapLbUb = isMin ? addConstToResults(map, val: 1) : map;
  if (failed(
          Result: alignAndAddBound(constraints, type: boundType, pos: dimOp, map: mapLbUb, operands)))
    return failure();

  // Try to compute a lower/upper bound for op, expressed in terms of the other
  // `dims` and extra symbols.
  SmallVector<AffineMap> opLb(1), opUb(1);
  constraints.getSliceBounds(offset: dimOp, num: 1, context: ctx, lbMaps: &opLb, ubMaps: &opUb);
  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
  // a TODO of `getSliceBounds` and not handled here.
  if (!sliceBound || sliceBound.getNumResults() != 1)
    return failure(); // No or multiple bounds found.
  // Recover the inclusive UB in the case of an `affine.min`.
  AffineMap boundMap = isMin ? addConstToResults(map: sliceBound, val: -1) : sliceBound;

  // Add an equality: Set dimOpBound to computed bound.
  // Add back dimension for op. (Was removed by `getSliceBounds`.)
  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
  if (failed(Result: constraints.addBound(type: BoundType::EQ, pos: dimOpBound, boundMap: alignedBoundMap)))
    return failure();

  // If the constraint system is empty, there is an inconsistency. (E.g., this
  // can happen if loop lb > ub.)
  if (constraints.isEmpty())
    return failure();

  // In the case of `isMin` (`!isMin` is inversed):
  // Prove that each result of `map` has a lower bound that is equal to (or
  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
  // can be replaced with the bound. I.e., prove that for each result
  // expr_i (represented by dimension r_i):
  //
  // r_i >= opBound
  //
  // To prove this inequality, add its negation to the constraint set and prove
  // that the constraint set is empty.
  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
    // Work on a copy so each result's proof is independent.
    FlatAffineValueConstraints newConstr(constraints);

    // Add an equality: r_i = expr_i
    // Note: These equalities could have been added earlier and used to express
    // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
    // computes minOpUb in terms of r_i dims, which is not desired.
    if (failed(Result: alignAndAddBound(constraints&: newConstr, type: BoundType::EQ, pos: i,
                                map: map.getSubMap(resultPos: {i - resultDimStart}), operands)))
      return failure();

    // If `isMin`: Add inequality: r_i < opBound
    //             equiv.: opBound - r_i - 1 >= 0
    // If `!isMin`: Add inequality: r_i > opBound
    //              equiv.: -opBound + r_i - 1 >= 0
    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
    ineq[dimOpBound] = isMin ? 1 : -1;
    ineq[i] = isMin ? -1 : 1;
    ineq[newConstr.getNumCols() - 1] = -1;
    newConstr.addInequality(inEq: ineq);
    // A non-empty system means the negation is satisfiable: proof failed.
    if (!newConstr.isEmpty())
      return failure();
  }

  // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
  AffineMap newMap = alignedBoundMap;
  SmallVector<Value> newOperands;
  unpackOptionalValues(source: constraints.getMaybeValues(), target&: newOperands);
  // If dims/symbols have known constant values, use those in order to simplify
  // the affine map further.
  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
    // Skip unused operands and operands that are already constants.
    if (!newOperands[i] || getConstantIntValue(ofr: newOperands[i]))
      continue;
    if (auto bound = constraints.getConstantBound64(type: BoundType::EQ, pos: i)) {
      AffineExpr expr =
          i < newMap.getNumDims()
              ? builder.getAffineDimExpr(position: i)
              : builder.getAffineSymbolExpr(position: i - newMap.getNumDims());
      newMap = newMap.replace(expr, replacement: builder.getAffineConstantExpr(constant: *bound),
                              numResultDims: newMap.getNumDims(), numResultSyms: newMap.getNumSymbols());
    }
  }
  affine::canonicalizeMapAndOperands(map: &newMap, operands: &newOperands);
  return AffineValueMap(newMap, newOperands);
}
2300
2301Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
2302 Operation *b) {
2303 Region *aScope = getAffineAnalysisScope(op: a);
2304 Region *bScope = getAffineAnalysisScope(op: b);
2305 if (aScope != bScope)
2306 return nullptr;
2307
2308 // Get the block ancestry of `op` while stopping at the affine scope `aScope`
2309 // and store them in `ancestry`.
2310 auto getBlockAncestry = [&](Operation *op,
2311 SmallVectorImpl<Block *> &ancestry) {
2312 Operation *curOp = op;
2313 do {
2314 ancestry.push_back(Elt: curOp->getBlock());
2315 if (curOp->getParentRegion() == aScope)
2316 break;
2317 curOp = curOp->getParentOp();
2318 } while (curOp);
2319 assert(curOp && "can't reach root op without passing through affine scope");
2320 std::reverse(first: ancestry.begin(), last: ancestry.end());
2321 };
2322
2323 SmallVector<Block *, 4> aAncestors, bAncestors;
2324 getBlockAncestry(a, aAncestors);
2325 getBlockAncestry(b, bAncestors);
2326 assert(!aAncestors.empty() && !bAncestors.empty() &&
2327 "at least one Block ancestor expected");
2328
2329 Block *innermostCommonBlock = nullptr;
2330 for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
2331 a < e && b < f; ++a, ++b) {
2332 if (aAncestors[a] != bAncestors[b])
2333 break;
2334 innermostCommonBlock = aAncestors[a];
2335 }
2336 return innermostCommonBlock;
2337}
2338

source code of mlir/lib/Dialect/Affine/Analysis/Utils.cpp