//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"

#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
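//
// For example, in a hypothetical nest like:
//
//   affine.for %i = 0 to 16 {
//     %v = affine.load %A[%i] : memref<16xf32>
//     affine.store %v, %B[%i] : memref<16xf32>
//     memref.store %v, %C[%i] : memref<16xf32>
//   }
//
// the affine accesses are recorded in `loadOpInsts`/`storeOpInsts`, while the
// non-affine `memref.store` is recorded in `memrefStores` via its declared
// memory effects.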
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      forOps.push_back(forOp);
    } else if (isa<AffineReadOpInterface>(op)) {
      loadOpInsts.push_back(op);
    } else if (isa<AffineWriteOpInterface>(op)) {
      storeOpInsts.push_back(op);
    } else {
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(op);
      if (!memInterface) {
        if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
          // This op itself is memory-effect free.
          return;
        // Check operands. Eg. ops like the `call` op are handled here.
        for (Value v : op->getOperands()) {
          if (!isa<MemRefType>(v.getType()))
            continue;
          // Conservatively, we assume the memref is read and written to.
          memrefLoads.push_back(op);
          memrefStores.push_back(op);
        }
      } else {
        // Non-affine loads and stores.
        if (hasEffect<MemoryEffects::Read>(op))
          memrefLoads.push_back(op);
        if (hasEffect<MemoryEffects::Write>(op))
          memrefStores.push_back(op);
        if (hasEffect<MemoryEffects::Free>(op))
          memrefFrees.push_back(op);
      }
    }
  });
}

unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    // Common case: affine reads.
    if (auto affineLoad = dyn_cast<AffineReadOpInterface>(loadOp)) {
      if (memref == affineLoad.getMemRef())
        ++loadOpCount;
    } else if (hasEffect<MemoryEffects::Read>(loadOp, memref)) {
      ++loadOpCount;
    }
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (auto *storeOp : llvm::concat<Operation *const>(stores, memrefStores)) {
    // Common case: affine writes.
    if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
      if (memref == affineStore.getMemRef())
        ++storeOpCount;
    } else if (hasEffect<MemoryEffects::Write>(const_cast<Operation *>(storeOp),
                                               memref)) {
      ++storeOpCount;
    }
  }
  return storeOpCount;
}

// Returns true if there exists an operation with a write memory effect on
// 'memref' in this node.
unsigned Node::hasStore(Value memref) const {
  return llvm::any_of(
      llvm::concat<Operation *const>(stores, memrefStores),
      [&](Operation *storeOp) {
        if (auto affineStore = dyn_cast<AffineWriteOpInterface>(storeOp)) {
          if (memref == affineStore.getMemRef())
            return true;
        } else if (hasEffect<MemoryEffects::Write>(storeOp, memref)) {
          return true;
        }
        return false;
      });
}

unsigned Node::hasFree(Value memref) const {
  return llvm::any_of(memrefFrees, [&](Operation *freeOp) {
    return hasEffect<MemoryEffects::Free>(freeOp, memref);
  });
}

// Returns all store ops in 'storeOps' which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Returns all load ops in 'loadOps' which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

/// Returns the values that this op has a memref effect of type `EffectTys` on,
/// not considering recursive effects.
template <typename... EffectTys>
static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
  auto memOp = dyn_cast<MemoryEffectOpInterface>(op);
  if (!memOp) {
    if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
      // No effects.
      return;
    // Memref operands have to be considered as being affected.
    for (Value operand : op->getOperands()) {
      if (isa<MemRefType>(operand.getType()))
        values.push_back(operand);
    }
    return;
  }
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
  memOp.getEffects(effects);
  for (auto &effect : effects) {
    Value effectVal = effect.getValue();
    if (isa<EffectTys...>(effect.getEffect()) && effectVal &&
        isa<MemRefType>(effectVal.getType()))
      values.push_back(effectVal);
  }
}

/// Add `op` to the MDG, creating a new node and adding its memory accesses
/// (affine or non-affine) to the 'memrefAccesses' (memref -> list of nodes
/// with accesses) map.
static Node *
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
             DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
  auto &nodes = mdg.nodes;
  // Create graph node 'id' to represent top-level 'forOp' and record
  // all loads and store accesses it contains.
  LoopNestStateCollector collector;
  collector.collect(nodeOp);
  unsigned newNodeId = mdg.nextNodeId++;
  Node &node = nodes.insert({newNodeId, Node(newNodeId, nodeOp)}).first->second;
  for (Operation *op : collector.loadOpInsts) {
    node.loads.push_back(op);
    auto memref = cast<AffineReadOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.storeOpInsts) {
    node.stores.push_back(op);
    auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
    memrefAccesses[memref].insert(node.id);
  }
  for (Operation *op : collector.memrefLoads) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Read>(op, effectedValues);
    if (llvm::any_of(((ValueRange)effectedValues).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      // We do not know the interaction here.
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefLoads.push_back(op);
  }
  for (Operation *op : collector.memrefStores) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Write>(op, effectedValues);
    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefStores.push_back(op);
  }
  for (Operation *op : collector.memrefFrees) {
    SmallVector<Value> effectedValues;
    getEffectedValues<MemoryEffects::Free>(op, effectedValues);
    if (llvm::any_of((ValueRange(effectedValues)).getTypes(),
                     [](Type type) { return !isa<MemRefType>(type); }))
      return nullptr;
    for (Value memref : effectedValues)
      memrefAccesses[memref].insert(node.id);
    node.memrefFrees.push_back(op);
  }

  return &node;
}

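// As a rough illustrative sketch (hypothetical, not tied to a specific test
// case): for a block with two top-level loop nests, where the first nest
// stores to %A and the second nest loads from %A, `init()` creates one graph
// node per nest and a dependence edge from the first node to the second keyed
// on the memref %A. SSA def-use edges between top-level defining ops and loop
// nests using their results are added similarly.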
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  // Create graph nodes.
  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (auto forOp = dyn_cast<AffineForOp>(op)) {
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
      forToNodeMap[&op] = node->id;
    } else if (isa<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (isa<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, edges
      // to/from these ops get looked at.
      Node *node = addNodeToMDG(&op, *this, memrefAccesses);
      if (!node)
        return false;
    } else if (op.getNumRegions() != 0 && !isa<RegionBranchOpInterface>(op)) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., they may
      // not be imperative ops.
      LLVM_DEBUG(llvm::dbgs()
                 << "MDG init failed; unknown region-holding op found!\n");
      return false;
    }
    // We aren't creating nodes for memory-effect free ops either with no
    // regions (unless they have results being used) or those with the branch
    // op interface.
  }

  LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n");

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    Value srcMemRef = memrefAndList.first;
    // Add edges between all dependent pairs among the node IDs on this memref.
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      Node *srcNode = getNode(srcId);
      bool srcHasStoreOrFree =
          srcNode->hasStore(srcMemRef) || srcNode->hasFree(srcMemRef);
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        Node *dstNode = getNode(dstId);
        bool dstHasStoreOrFree =
            dstNode->hasStore(srcMemRef) || dstNode->hasFree(srcMemRef);
        if (srcHasStoreOrFree || dstHasStoreOrFree)
          addEdge(srcId, dstId, srcMemRef);
      }
    }
  }
  return true;
}

// Returns the graph node for 'id'.
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Remove node 'id' (and its associated edges) from graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.contains(id)) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
  const Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) const {
  if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
                                              unsigned dstId) const {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (!outEdges.contains(idAndIndex.first) ||
        idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) const {
  unsigned inEdgeCount = 0;
  for (const Edge &inEdge : inEdges.lookup(id)) {
    if (inEdge.value == memref) {
      const Node *srcNode = getNode(inEdge.id);
      // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'.
      if (srcNode->getStoreOpCount(memref) > 0)
        ++inEdgeCount;
    }
  }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
                                                Value memref) const {
  unsigned outEdgeCount = 0;
  for (const auto &outEdge : outEdges.lookup(id))
    if (!memref || outEdge.value == memref)
      ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) const {
  for (const Edge &edge : inEdges.lookup(id))
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
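//
// Illustrative sketch (hypothetical ordering, not from a specific test case):
// if the block contains, in order, the source nest, an op A on which 'dstId'
// depends, an op B that depends on 'srcId', and the destination nest, then
// 'lastDstDepPos' is A's position and 'firstSrcDepPos' is B's position; since
// firstSrcDepPos > lastDstDepPos, B is returned as the insertion point.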
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) const {
  if (!outEdges.contains(srcId))
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: a defining op with a user in the dst "
                  "loop has dependence from the src loop\n");
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges.lookup(srcId))
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges.lookup(dstId))
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //    operation insertion point (or return null pointer if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in the node at 'dstId' by
//    private memrefs.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (!privateMemRefs.contains(inEdge.value))
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]' to node 'id'
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads' and 'stores' to node at 'id'.
void MemRefDependenceGraph::addToNode(unsigned id, ArrayRef<Operation *> loads,
                                      ArrayRef<Operation *> stores,
                                      ArrayRef<Operation *> memrefLoads,
                                      ArrayRef<Operation *> memrefStores,
                                      ArrayRef<Operation *> memrefFrees) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
  llvm::append_range(node->memrefLoads, memrefLoads);
  llvm::append_range(node->memrefStores, memrefStores);
  llvm::append_range(node->memrefFrees, memrefFrees);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges[id], callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges[id], callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Skip if 'edge.id' is not a loop nest.
    if (!isa<AffineForOp>(getNode(edge.id)->op))
      continue;
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << "  InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << "  OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  AffineForOp currAffineForOp;
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent original
// domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent slice bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion, and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true if
/// each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if the src and dst loops don't have the same bounds. Returns std::nullopt
/// if none of the above can be proven.
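///
/// For example (hypothetical values): a slice bound pair lb = (d0) -> (d0),
/// ub = (d0) -> (d0 + 1) describes a single-iteration slice along a dst IV
/// 'd0'; if the corresponding src and dst loops both run from 0 to the same
/// constant with the same step, the fast check returns true.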
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

/// Returns true if it is deterministically verified that the original iteration
/// space of the slice is contained within the new iteration space that is
/// created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest using which slice is computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return std::nullopt;
  }

  // Projecting out every dimension other than the 'ivs' to express slice's
  // domain completely in terms of source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

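// For example (hypothetical region): for a region of a memref<256x256xf32>
// with constraints {32 <= d0 <= 63, 0 <= d1 <= 127}, the bounding shape is
// [32, 128] and the returned element count is 32 * 128 = 4096.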
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, SmallVectorImpl<AffineMap> *lbs) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  MLIRContext *context = memref.getContext();
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatLinearValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  for (unsigned d = 0; d < rank; d++) {
    AffineMap lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize(context, d, &lb);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "dim size bound cannot be negative");
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (ShapedType::isDynamic(dimSize))
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb = AffineMap::get(/*dimCount=*/0, cstWithShapeBounds.getNumSymbolVars(),
                          /*result=*/getAffineConstantExpr(0, context));
    }
    numElements *= diffConstant;
    // Populate outputs if available.
    if (lbs)
      lbs->push_back(lb);
    if (shape)
      shape->push_back(diffConstant);
  }
  return numElements;
}

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1 will
// be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds, bool dropLocalVars,
                                    bool dropOuterIvs) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << "\ndepth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed out.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      if (failed(cst.addInductionVarOrTerminalSymbol(operand)))
        return failure();
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (auto en : llvm::enumerate(vars)) {
    if ((isAffineInductionVar(en.value())) &&
        !llvm::is_contained(enclosingIVs, en.value())) {
      if (dropOuterIvs) {
        cst.projectOut(en.value());
      } else {
        unsigned varPosition;
        cst.findVar(en.value(), &varPosition);
        auto varKind = cst.getVarKindAt(varPosition);
        varPosition -= cst.getNumDimVars();
        cst.convertToLocal(varKind, varPosition, varPosition + 1);
      }
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs) if specified.
  if (dropLocalVars)
    cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

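// For example, getMemRefIntOrFloatEltSizeInBytes returns 4 for an f32 element
// type and 16 for a vector<4xf32> element type (4 elements x 4 bytes each).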
std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return false;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the element of the memref has vector type, takes
/// into account size of the vector as well.
// TODO: improve/complete this when we have target data.
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}

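// Bound-check example (hypothetical IR): for an access like
//   affine.load %A[%i + 16] : memref<16xf32>   inside   affine.for %i = 0 to 16
// the accessed region along dimension 0 is [16, 31], which intersects the
// out-of-bounds half-space d0 >= 16, so boundCheckLoadOrStoreOp reports an
// out-of-upper-bound access along dimension #1.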
1392template <typename LoadOrStoreOp>
1393LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
1394 bool emitError) {
1395 static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
1396 AffineWriteOpInterface>::value,
1397 "argument should be either a AffineReadOpInterface or a "
1398 "AffineWriteOpInterface");
1399
1400 Operation *op = loadOrStoreOp.getOperation();
1401 MemRefRegion region(op->getLoc());
1402 if (failed(Result: region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
1403 /*addMemRefDimBounds=*/false)))
1404 return success();
1405
1406 LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
1407 LLVM_DEBUG(region.getConstraints()->dump());
1408
1409 bool outOfBounds = false;
1410 unsigned rank = loadOrStoreOp.getMemRefType().getRank();
1411
1412 // For each dimension, check for out of bounds.
1413 for (unsigned r = 0; r < rank; r++) {
1414 FlatAffineValueConstraints ucst(*region.getConstraints());
1415
1416 // Intersect the memory region with a constraint capturing an out-of-bounds
1417 // access (beyond the upper bound or below the lower bound), and check if the
1418 // constraint system is feasible. If it is, at least one point is out of bounds.
1420 int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
1421 // TODO: handle dynamic dim sizes.
1422 if (ShapedType::isDynamic(dimSize))
1423 continue;
1424
1425 // Check for overflow: d_i >= memref dim size.
1426 ucst.addBound(type: BoundType::LB, pos: r, value: dimSize);
1427 bool upperOutOfBounds = !ucst.isEmpty();
1428 if (upperOutOfBounds && emitError) {
1429 loadOrStoreOp.emitOpError()
1430 << "memref out of upper bound access along dimension #" << (r + 1);
1431 }
1432
1433 // Check for a negative index.
1434 FlatAffineValueConstraints lcst(*region.getConstraints());
1436 // d_i <= -1;
1437 lcst.addBound(type: BoundType::UB, pos: r, value: -1);
1438 bool lowerOutOfBounds = !lcst.isEmpty();
1439 if (lowerOutOfBounds && emitError) {
1440 loadOrStoreOp.emitOpError()
1441 << "memref out of lower bound access along dimension #" << (r + 1);
1442 }
outOfBounds |= upperOutOfBounds || lowerOutOfBounds;
1443 }
1444 return failure(IsFailure: outOfBounds);
1445}
1446
1447// Explicitly instantiate the template so that the compiler knows we need them!
1448template LogicalResult
1449mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
1450 bool emitError);
1451template LogicalResult
1452mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
1453 bool emitError);
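// Illustrative example (assumed IR, not taken from this file): for
//   %A = memref.alloc() : memref<9xf32>
//   affine.for %i = 0 to 12 {
//     %v = affine.load %A[%i] : memref<9xf32>
//   }
// the access region along dimension 0 is [0, 11]; intersecting it with the
// out-of-bounds constraint d_0 >= 9 yields a non-empty system, so the load is
// reported as an out-of-upper-bound access along dimension #1.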
1454
1455 // Returns in 'positions' the block positions of 'op' within each ancestor
1456 // Block, walking up from the Block containing 'op' and stopping at 'limitBlock'.
1457static void findInstPosition(Operation *op, Block *limitBlock,
1458 SmallVectorImpl<unsigned> *positions) {
1459 Block *block = op->getBlock();
1460 while (block != limitBlock) {
1461 // FIXME: This algorithm is unnecessarily O(n) and should be improved to not
1462 // rely on linear scans.
1463 int instPosInBlock = std::distance(block->begin(), op->getIterator());
1464 positions->push_back(Elt: instPosInBlock);
1465 op = block->getParentOp();
1466 block = op->getBlock();
1467 }
1468 std::reverse(first: positions->begin(), last: positions->end());
1469}
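// For illustration (hypothetical nesting): for an 'op' that is the third
// operation in the body of an inner affine.for, which is in turn the first
// operation in the body of an outer affine.for held by 'limitBlock', the
// computed positions are {0, 2}: position 0 in the outer loop body, position 2
// in the inner loop body (positions within 'limitBlock' itself are not
// recorded).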
1470
1471// Returns the Operation in a possibly nested set of Blocks, where the
1472// position of the operation is represented by 'positions', which has a
1473// Block position for each level of nesting.
1474static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
1475 unsigned level, Block *block) {
1476 unsigned i = 0;
1477 for (auto &op : *block) {
1478 if (i != positions[level]) {
1479 ++i;
1480 continue;
1481 }
1482 if (level == positions.size() - 1)
1483 return &op;
1484 if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
1485 return getInstAtPosition(positions, level + 1,
1486 childAffineForOp.getBody());
1487
1488 for (auto &region : op.getRegions()) {
1489 for (auto &b : region)
1490 if (auto *ret = getInstAtPosition(positions, level: level + 1, block: &b))
1491 return ret;
1492 }
1493 return nullptr;
1494 }
1495 return nullptr;
1496}
1497
1498// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
1499static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
1500 FlatAffineValueConstraints *cst) {
1501 for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
1502 auto value = cst->getValue(pos: i);
1503 if (ivs.count(Ptr: value) == 0) {
1504 assert(isAffineForInductionVar(value));
1505 auto loop = getForInductionVarOwner(value);
1506 if (failed(cst->addAffineForOpDomain(forOp: loop)))
1507 return failure();
1508 }
1509 }
1510 return success();
1511}
1512
1513/// Returns the innermost common loop depth for the set of operations in 'ops'.
1514// TODO: Move this to LoopUtils.
1515unsigned mlir::affine::getInnermostCommonLoopDepth(
1516 ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1517 unsigned numOps = ops.size();
1518 assert(numOps > 0 && "Expected at least one operation");
1519
1520 std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1521 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1522 for (unsigned i = 0; i < numOps; ++i) {
1523 getAffineForIVs(*ops[i], &loops[i]);
1524 loopDepthLimit =
1525 std::min(a: loopDepthLimit, b: static_cast<unsigned>(loops[i].size()));
1526 }
1527
1528 unsigned loopDepth = 0;
1529 for (unsigned d = 0; d < loopDepthLimit; ++d) {
1530 unsigned i;
1531 for (i = 1; i < numOps; ++i) {
1532 if (loops[i - 1][d] != loops[i][d])
1533 return loopDepth;
1534 }
1535 if (surroundingLoops)
1536 surroundingLoops->push_back(loops[i - 1][d]);
1537 ++loopDepth;
1538 }
1539 return loopDepth;
1540}
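// Illustrative example (hypothetical nest): if opA is surrounded by loops
// (%i, %j) and opB by (%i, %k) with a shared %i loop, only the %i loop is
// common to both, so the returned depth is 1 and 'surroundingLoops' (if
// provided) holds just the %i loop.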
1541
1542/// Computes in 'sliceUnion' the union of all slice bounds computed at
1543/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1544/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1545/// union was computed correctly, an appropriate failure otherwise.
1546SliceComputationResult
1547mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
1548 ArrayRef<Operation *> opsB, unsigned loopDepth,
1549 unsigned numCommonLoops, bool isBackwardSlice,
1550 ComputationSliceState *sliceUnion) {
1551 // Compute the union of slice bounds between all pairs in 'opsA' and
1552 // 'opsB' in 'sliceUnionCst'.
1553 FlatAffineValueConstraints sliceUnionCst;
1554 assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1555 std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1556 for (Operation *i : opsA) {
1557 MemRefAccess srcAccess(i);
1558 for (Operation *j : opsB) {
1559 MemRefAccess dstAccess(j);
1560 if (srcAccess.memref != dstAccess.memref)
1561 continue;
1562 // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1563 if ((!isBackwardSlice && loopDepth > getNestingDepth(op: i)) ||
1564 (isBackwardSlice && loopDepth > getNestingDepth(op: j))) {
1565 LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
1566 return SliceComputationResult::GenericFailure;
1567 }
1568
1569 bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1570 isa<AffineReadOpInterface>(dstAccess.opInst);
1571 FlatAffineValueConstraints dependenceConstraints;
1572 // Check dependence between 'srcAccess' and 'dstAccess'.
1573 DependenceResult result = checkMemrefAccessDependence(
1574 srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1575 dependenceConstraints: &dependenceConstraints, /*dependenceComponents=*/nullptr,
1576 /*allowRAR=*/readReadAccesses);
1577 if (result.value == DependenceResult::Failure) {
1578 LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
1579 return SliceComputationResult::GenericFailure;
1580 }
1581 if (result.value == DependenceResult::NoDependence)
1582 continue;
1583 dependentOpPairs.emplace_back(args&: i, args&: j);
1584
1585 // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1586 ComputationSliceState tmpSliceState;
1587 mlir::affine::getComputationSliceState(depSourceOp: i, depSinkOp: j, dependenceConstraints,
1588 loopDepth, isBackwardSlice,
1589 sliceState: &tmpSliceState);
1590
1591 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1592 // Initialize 'sliceUnionCst' with the bounds computed in the previous step.
1593 if (failed(Result: tmpSliceState.getAsConstraints(cst: &sliceUnionCst))) {
1594 LLVM_DEBUG(llvm::dbgs()
1595 << "Unable to compute slice bound constraints\n");
1596 return SliceComputationResult::GenericFailure;
1597 }
1598 assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1599 continue;
1600 }
1601
1602 // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1603 FlatAffineValueConstraints tmpSliceCst;
1604 if (failed(Result: tmpSliceState.getAsConstraints(cst: &tmpSliceCst))) {
1605 LLVM_DEBUG(llvm::dbgs()
1606 << "Unable to compute slice bound constraints\n");
1607 return SliceComputationResult::GenericFailure;
1608 }
1609
1610 // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1611 if (!sliceUnionCst.areVarsAlignedWithOther(other: tmpSliceCst)) {
1612
1613 // Pre-constraint var alignment: record loop IVs used in each constraint
1614 // system.
1615 SmallPtrSet<Value, 8> sliceUnionIVs;
1616 for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1617 sliceUnionIVs.insert(Ptr: sliceUnionCst.getValue(pos: k));
1618 SmallPtrSet<Value, 8> tmpSliceIVs;
1619 for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1620 tmpSliceIVs.insert(Ptr: tmpSliceCst.getValue(pos: k));
1621
1622 sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, other: &tmpSliceCst);
1623
1624 // Post-constraint var alignment: add loop IV bounds missing after
1625 // var alignment to the constraint systems. This can occur if one constraint
1626 // system uses a loop IV that is not used by the other. The call
1627 // to unionBoundingBox below expects constraints for each loop IV, even
1628 // if they are the unsliced full loop bounds added here.
1629 if (failed(Result: addMissingLoopIVBounds(ivs&: sliceUnionIVs, cst: &sliceUnionCst)))
1630 return SliceComputationResult::GenericFailure;
1631 if (failed(Result: addMissingLoopIVBounds(ivs&: tmpSliceIVs, cst: &tmpSliceCst)))
1632 return SliceComputationResult::GenericFailure;
1633 }
1634 // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
1635 if (sliceUnionCst.getNumLocalVars() > 0 ||
1636 tmpSliceCst.getNumLocalVars() > 0 ||
1637 failed(Result: sliceUnionCst.unionBoundingBox(other: tmpSliceCst))) {
1638 LLVM_DEBUG(llvm::dbgs()
1639 << "Unable to compute union bounding box of slice bounds\n");
1640 return SliceComputationResult::GenericFailure;
1641 }
1642 }
1643 }
1644
1645 // Empty union.
1646 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1647 LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n");
1648 return SliceComputationResult::GenericFailure;
1649 }
1650
1651 // Gather loops surrounding ops from loop nest where slice will be inserted.
1652 SmallVector<Operation *, 4> ops;
1653 for (auto &dep : dependentOpPairs) {
1654 ops.push_back(Elt: isBackwardSlice ? dep.second : dep.first);
1655 }
1656 SmallVector<AffineForOp, 4> surroundingLoops;
1657 unsigned innermostCommonLoopDepth =
1658 getInnermostCommonLoopDepth(ops, &surroundingLoops);
1659 if (loopDepth > innermostCommonLoopDepth) {
1660 LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
1661 return SliceComputationResult::GenericFailure;
1662 }
1663
1664 // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1665 unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1666
1667 // Convert any dst loop IVs which are symbol variables to dim variables.
1668 sliceUnionCst.convertLoopIVSymbolsToDims();
1669 sliceUnion->clearBounds();
1670 sliceUnion->lbs.resize(N: numSliceLoopIVs, NV: AffineMap());
1671 sliceUnion->ubs.resize(N: numSliceLoopIVs, NV: AffineMap());
1672
1673 // Get slice bounds from slice union constraints 'sliceUnionCst'.
1674 sliceUnionCst.getSliceBounds(/*offset=*/0, num: numSliceLoopIVs,
1675 context: opsA[0]->getContext(), lbMaps: &sliceUnion->lbs,
1676 ubMaps: &sliceUnion->ubs);
1677
1678 // Add slice bound operands of union.
1679 SmallVector<Value, 4> sliceBoundOperands;
1680 sliceUnionCst.getValues(start: numSliceLoopIVs,
1681 end: sliceUnionCst.getNumDimAndSymbolVars(),
1682 values: &sliceBoundOperands);
1683
1684 // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1685 sliceUnion->ivs.clear();
1686 sliceUnionCst.getValues(start: 0, end: numSliceLoopIVs, values: &sliceUnion->ivs);
1687
1688 // Set the loop nest insertion point: at the start of the body of the loop at
1689 // 'loopDepth' for backward slices, and at its end for forward slices.
1690 sliceUnion->insertPoint =
1691 isBackwardSlice
1692 ? surroundingLoops[loopDepth - 1].getBody()->begin()
1693 : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1694
1695 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1696 // canonicalization.
1697 sliceUnion->lbOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1698 sliceUnion->ubOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1699
1700 // Check whether the computed slice is valid. Return success only if the slice
1701 // is verified to be valid; otherwise return the appropriate failure status.
1702 std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1703 if (!isSliceValid) {
1704 LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
1705 return SliceComputationResult::GenericFailure;
1706 }
1707 if (!*isSliceValid)
1708 return SliceComputationResult::IncorrectSliceFailure;
1709
1710 return SliceComputationResult::Success;
1711}
1712
1713// TODO: extend this to handle multiple result maps.
1714static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1715 AffineMap ubMap) {
1716 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1717 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1718 assert(lbMap.getNumDims() == ubMap.getNumDims());
1719 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1720 AffineExpr lbExpr(lbMap.getResult(idx: 0));
1721 AffineExpr ubExpr(ubMap.getResult(idx: 0));
1722 auto loopSpanExpr = simplifyAffineExpr(expr: ubExpr - lbExpr, numDims: lbMap.getNumDims(),
1723 numSymbols: lbMap.getNumSymbols());
1724 auto cExpr = dyn_cast<AffineConstantExpr>(Val&: loopSpanExpr);
1725 if (!cExpr)
1726 return std::nullopt;
1727 return cExpr.getValue();
1728}
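// For illustration: with lbMap = (d0) -> (d0) and ubMap = (d0) -> (d0 + 4), the
// difference simplifies to the constant 4, whereas lbMap = (d0)[s0] -> (d0) and
// ubMap = (d0)[s0] -> (d0 + s0) leave a symbolic difference, so std::nullopt is
// returned.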
1729
1730 // Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
1731 // loop nest represented by the slice loop bounds in 'slice'. Returns true
1732 // on success, false otherwise (if a non-constant trip count was encountered).
1733// TODO: Make this work with non-unit step loops.
1734bool mlir::affine::buildSliceTripCountMap(
1735 const ComputationSliceState &slice,
1736 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1737 unsigned numSrcLoopIVs = slice.ivs.size();
1738 // Populate map from AffineForOp -> trip count
1739 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1740 AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1741 auto *op = forOp.getOperation();
1742 AffineMap lbMap = slice.lbs[i];
1743 AffineMap ubMap = slice.ubs[i];
1744 // If the lower or upper bound maps are null or provide no results, it implies
1745 // that the source loop was not sliced at all, and the entire loop is
1746 // part of the slice.
1747 if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1748 ubMap.getNumResults() == 0) {
1749 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1750 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1751 (*tripCountMap)[op] =
1752 forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1753 continue;
1754 }
1755 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1756 if (maybeConstTripCount.has_value()) {
1757 (*tripCountMap)[op] = *maybeConstTripCount;
1758 continue;
1759 }
1760 return false;
1761 }
1762 std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1763 // Slice bounds are created with a constant ub - lb difference.
1764 if (!tripCount.has_value())
1765 return false;
1766 (*tripCountMap)[op] = *tripCount;
1767 }
1768 return true;
1769}
1770
1771// Return the number of iterations in the given slice.
1772uint64_t mlir::affine::getSliceIterationCount(
1773 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1774 uint64_t iterCount = 1;
1775 for (const auto &count : sliceTripCountMap) {
1776 iterCount *= count.second;
1777 }
1778 return iterCount;
1779}
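// For example (illustrative): a trip count map of {loop_i: 4, loop_j: 8}
// yields 4 * 8 = 32 slice iterations; an empty map yields 1.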
1780
1781const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
1782 // Computes slice bounds by projecting out any loop IVs from
1783 // 'dependenceConstraints' at depth greater than 'loopDepth', and stores slice
1784 // bounds in 'sliceState' which represent one loop nest's IVs in terms of
1785 // the other loop nest's IVs, symbols, and constants (per 'isBackwardSlice').
1786void mlir::affine::getComputationSliceState(
1787 Operation *depSourceOp, Operation *depSinkOp,
1788 const FlatAffineValueConstraints &dependenceConstraints, unsigned loopDepth,
1789 bool isBackwardSlice, ComputationSliceState *sliceState) {
1790 // Get loop nest surrounding src operation.
1791 SmallVector<AffineForOp, 4> srcLoopIVs;
1792 getAffineForIVs(*depSourceOp, &srcLoopIVs);
1793 unsigned numSrcLoopIVs = srcLoopIVs.size();
1794
1795 // Get loop nest surrounding dst operation.
1796 SmallVector<AffineForOp, 4> dstLoopIVs;
1797 getAffineForIVs(*depSinkOp, &dstLoopIVs);
1798 unsigned numDstLoopIVs = dstLoopIVs.size();
1799
1800 assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1801 (isBackwardSlice && loopDepth <= numDstLoopIVs));
1802
1803 // Project out dimensions other than those up to 'loopDepth'.
1804 unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1805 unsigned num =
1806 isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1807 FlatAffineValueConstraints sliceCst(dependenceConstraints);
1808 sliceCst.projectOut(pos, num);
1809
1810 // Add slice loop IV values to 'sliceState'.
1811 unsigned offset = isBackwardSlice ? 0 : loopDepth;
1812 unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1813 sliceCst.getValues(start: offset, end: offset + numSliceLoopIVs, values: &sliceState->ivs);
1814
1815 // Set up lower/upper bound affine maps for the slice.
1816 sliceState->lbs.resize(N: numSliceLoopIVs, NV: AffineMap());
1817 sliceState->ubs.resize(N: numSliceLoopIVs, NV: AffineMap());
1818
1819 // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1820 sliceCst.getSliceBounds(offset, num: numSliceLoopIVs, context: depSourceOp->getContext(),
1821 lbMaps: &sliceState->lbs, ubMaps: &sliceState->ubs);
1822
1823 // Set up bound operands for the slice's lower and upper bounds.
1824 SmallVector<Value, 4> sliceBoundOperands;
1825 unsigned numDimsAndSymbols = sliceCst.getNumDimAndSymbolVars();
1826 for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1827 if (i < offset || i >= offset + numSliceLoopIVs)
1828 sliceBoundOperands.push_back(Elt: sliceCst.getValue(pos: i));
1829 }
1830
1831 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1832 // canonicalization.
1833 sliceState->lbOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1834 sliceState->ubOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1835
1836 // Set the insertion point: dst loop body start ('loopDepth') for backward slices; src loop body end for forward slices.
1837 sliceState->insertPoint =
1838 isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1839 : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1840
1841 llvm::SmallDenseSet<Value, 8> sequentialLoops;
1842 if (isa<AffineReadOpInterface>(depSourceOp) &&
1843 isa<AffineReadOpInterface>(depSinkOp)) {
1844 // For read-read access pairs, clear any slice bounds on sequential loops.
1845 // Get sequential loops in the loop nest rooted at 'srcLoopIVs[0]' for backward slices, and at 'dstLoopIVs[0]' otherwise.
1846 getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1847 &sequentialLoops);
1848 }
1849 auto getSliceLoop = [&](unsigned i) {
1850 return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1851 };
1852 auto isInnermostInsertion = [&]() {
1853 return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1854 : loopDepth >= dstLoopIVs.size());
1855 };
1856 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
1857 auto srcIsUnitSlice = [&]() {
1858 return (buildSliceTripCountMap(slice: *sliceState, tripCountMap: &sliceTripCountMap) &&
1859 (getSliceIterationCount(sliceTripCountMap) == 1));
1860 };
1861 // Clear all sliced loop bounds beginning at the first sequential loop or the
1862 // first loop with a slice fusion barrier attribute.
1864 for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1865 Value iv = getSliceLoop(i).getInductionVar();
1866 if (sequentialLoops.count(V: iv) == 0 &&
1867 getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1868 continue;
1869 // Skip resetting the bounds of a reduction loop inserted in the destination
1870 // loop nest when all of the following conditions are met:
1871 // 1. The slice has a single trip count.
1872 // 2. The loop bounds of the source and destination match (the slice is maximal).
1873 // 3. The slice is being inserted at the innermost insertion point.
1874 std::optional<bool> isMaximal = sliceState->isMaximal();
1875 if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1876 isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1877 continue;
1878 for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1879 sliceState->lbs[j] = AffineMap();
1880 sliceState->ubs[j] = AffineMap();
1881 }
1882 break;
1883 }
1884}
1885
1886/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1887/// updates the slice loop bounds with any non-null bound maps specified in
1888/// 'sliceState', and inserts this slice into the loop nest surrounding
1889/// 'dstOpInst' at loop depth 'dstLoopDepth'.
1890 // TODO: extend the slicing utility to compute slices that
1891 // aren't necessarily a one-to-one relation between the source and destination.
1892 // The relation between the source and destination could be many-to-many in general.
1893// TODO: the slice computation is incorrect in the cases
1894// where the dependence from the source to the destination does not cover the
1895// entire destination index set. Subtract out the dependent destination
1896// iterations from destination index set and check for emptiness --- this is one
1897// solution.
1898AffineForOp mlir::affine::insertBackwardComputationSlice(
1899 Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1900 ComputationSliceState *sliceState) {
1901 // Get loop nest surrounding src operation.
1902 SmallVector<AffineForOp, 4> srcLoopIVs;
1903 getAffineForIVs(*srcOpInst, &srcLoopIVs);
1904 unsigned numSrcLoopIVs = srcLoopIVs.size();
1905
1906 // Get loop nest surrounding dst operation.
1907 SmallVector<AffineForOp, 4> dstLoopIVs;
1908 getAffineForIVs(*dstOpInst, &dstLoopIVs);
1909 unsigned dstLoopIVsSize = dstLoopIVs.size();
1910 if (dstLoopDepth > dstLoopIVsSize) {
1911 dstOpInst->emitError(message: "invalid destination loop depth");
1912 return AffineForOp();
1913 }
1914
1915 // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
1916 SmallVector<unsigned, 4> positions;
1917 // TODO: This code is incorrect since srcLoopIVs can be 0-d.
1918 findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
1919
1920 // Clone the src loop nest and insert it at the beginning of the operation block
1921 // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
1922 auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
1923 OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
1924 auto sliceLoopNest =
1925 cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
1926
1927 Operation *sliceInst =
1928 getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
1929 // Get loop nest surrounding 'sliceInst'.
1930 SmallVector<AffineForOp, 4> sliceSurroundingLoops;
1931 getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
1932
1933 // Sanity check.
1934 unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
1935 (void)sliceSurroundingLoopsSize;
1936 assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
1937 unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
1938 (void)sliceLoopLimit;
1939 assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
1940
1941 // Update loop bounds for loops in 'sliceLoopNest'.
1942 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1943 auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
1944 if (AffineMap lbMap = sliceState->lbs[i])
1945 forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
1946 if (AffineMap ubMap = sliceState->ubs[i])
1947 forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
1948 }
1949 return sliceLoopNest;
1950}
1951
1952 // Constructs a MemRefAccess, populating it with the memref, its indices, and
1953 // the op, all taken from 'loadOrStoreOpInst'.
1954MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1955 if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1956 memref = loadOp.getMemRef();
1957 opInst = loadOrStoreOpInst;
1958 llvm::append_range(indices, loadOp.getMapOperands());
1959 } else {
1960 assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1961 "Affine read/write op expected");
1962 auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1963 opInst = loadOrStoreOpInst;
1964 memref = storeOp.getMemRef();
1965 llvm::append_range(indices, storeOp.getMapOperands());
1966 }
1967}
1968
1969unsigned MemRefAccess::getRank() const {
1970 return cast<MemRefType>(memref.getType()).getRank();
1971}
1972
1973bool MemRefAccess::isStore() const {
1974 return isa<AffineWriteOpInterface>(opInst);
1975}
1976
1977/// Returns the nesting depth of this statement, i.e., the number of loops
1978/// surrounding this statement.
1979unsigned mlir::affine::getNestingDepth(Operation *op) {
1980 Operation *currOp = op;
1981 unsigned depth = 0;
1982 while ((currOp = currOp->getParentOp())) {
1983 if (isa<AffineForOp>(Val: currOp))
1984 depth++;
1985 if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
1986 depth += parOp.getNumDims();
1987 }
1988 return depth;
1989}
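// For instance (illustrative): an operation nested inside a two-dimensional
// affine.parallel (%i, %j) that is itself nested inside affine.for %k has
// nesting depth 1 + 2 = 3.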
1990
1991/// Equal if both affine accesses are provably equivalent (at compile
1992/// time) when considering the memref, the affine maps and their respective
1993/// operands. The equality of access functions + operands is checked by
1994/// subtracting fully composed value maps, and then simplifying the difference
1995/// using the expression flattener.
1996/// TODO: this does not account for aliasing of memrefs.
1997bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1998 if (memref != rhs.memref)
1999 return false;
2000
2001 AffineValueMap thisMap, rhsMap;
2002 getAccessMap(accessMap: &thisMap);
2003 rhs.getAccessMap(accessMap: &rhsMap);
2004 return thisMap == rhsMap;
2005}
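// Illustrative example: an affine.load from %A at [%i + 1] and an affine.store
// to %A at [%i + 1] (same %i) compare equal under this operator, since only the
// memref and the composed access function are considered, not the access kind.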
2006
2007void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
2008 auto *currOp = op.getParentOp();
2010 // Traverse up the hierarchy collecting all 'affine.for' and 'affine.parallel'
2011 // operation IVs while skipping over 'affine.if' operations.
2012 while (currOp) {
2013 if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
2014 ivs.push_back(Elt: currAffineForOp.getInductionVar());
2015 else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
2016 llvm::append_range(ivs, parOp.getIVs());
2017 currOp = currOp->getParentOp();
2018 }
2019 std::reverse(first: ivs.begin(), last: ivs.end());
2020}
2021
2022 /// Returns the number of surrounding loops common to both 'a' and 'b', counting
2023 /// the IVs of common loops from the outermost to the innermost.
2024unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
2025 Operation &b) {
2026 SmallVector<Value, 4> loopsA, loopsB;
2027 getAffineIVs(op&: a, ivs&: loopsA);
2028 getAffineIVs(op&: b, ivs&: loopsB);
2029
2030 unsigned minNumLoops = std::min(a: loopsA.size(), b: loopsB.size());
2031 unsigned numCommonLoops = 0;
2032 for (unsigned i = 0; i < minNumLoops; ++i) {
2033 if (loopsA[i] != loopsB[i])
2034 break;
2035 ++numCommonLoops;
2036 }
2037 return numCommonLoops;
2038}
2039
2040static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
2041 Block::iterator start,
2042 Block::iterator end,
2043 int memorySpace) {
2044 SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
2045
2046 // Walk the block range ['start', 'end') to gather all memory regions.
2047 auto result = block.walk(begin: start, end, callback: [&](Operation *opInst) -> WalkResult {
2048 if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
2049 // Neither load nor a store op.
2050 return WalkResult::advance();
2051 }
2052
2053 // Compute the memref region symbolic in any IVs enclosing this block.
2054 auto region = std::make_unique<MemRefRegion>(args: opInst->getLoc());
2055 if (failed(
2056 Result: region->compute(op: opInst,
2057 /*loopDepth=*/getNestingDepth(op: &*block.begin())))) {
2058 LLVM_DEBUG(opInst->emitError("error obtaining memory region"));
2059 return failure();
2060 }
2061
2062 auto [it, inserted] = regions.try_emplace(Key: region->memref);
2063 if (inserted) {
2064 it->second = std::move(region);
2065 } else if (failed(Result: it->second->unionBoundingBox(other: *region))) {
2066 LLVM_DEBUG(opInst->emitWarning(
2067 "getMemoryFootprintBytes: unable to perform a union on a memory "
2068 "region"));
2069 return failure();
2070 }
2071 return WalkResult::advance();
2072 });
2073 if (result.wasInterrupted())
2074 return std::nullopt;
2075
2076 int64_t totalSizeInBytes = 0;
2077 for (const auto &region : regions) {
2078 std::optional<int64_t> size = region.second->getRegionSize();
2079 if (!size.has_value())
2080 return std::nullopt;
2081 totalSizeInBytes += *size;
2082 }
2083 return totalSizeInBytes;
2084}
2085
2086std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
2087 int memorySpace) {
2088 auto *forInst = forOp.getOperation();
2089 return ::getMemoryFootprintBytes(
2090 block&: *forInst->getBlock(), start: Block::iterator(forInst),
2091 end: std::next(x: Block::iterator(forInst)), memorySpace);
2092}
2093
2094 /// Returns whether 'forOp' is a parallel loop that contains at least one reduction.
2095bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
2096 SmallVector<LoopReduction> reductions;
2097 if (!isLoopParallel(forOp, &reductions))
2098 return false;
2099 return !reductions.empty();
2100}
2101
2102/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
2103/// at 'forOp'.
2104void mlir::affine::getSequentialLoops(
2105 AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
2106 forOp->walk([&](Operation *op) {
2107 if (auto innerFor = dyn_cast<AffineForOp>(op))
2108 if (!isLoopParallel(innerFor))
2109 sequentialLoops->insert(innerFor.getInductionVar());
2110 });
2111}
2112
2113IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
2114 FlatAffineValueConstraints fac(set);
2115 if (fac.isEmpty())
2116 return IntegerSet::getEmptySet(numDims: set.getNumDims(), numSymbols: set.getNumSymbols(),
2117 context: set.getContext());
2118 fac.removeTrivialRedundancy();
2119
2120 auto simplifiedSet = fac.getAsIntegerSet(context: set.getContext());
2121 assert(simplifiedSet && "guaranteed to succeed while roundtripping");
2122 return simplifiedSet;
2123}
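// For illustration: the set (d0) : (d0 >= 0, -d0 + 4 >= 0, d0 - 10 >= 0) has no
// integer points, so the canonical empty set is returned; otherwise only
// trivially redundant constraints (e.g. a duplicated d0 >= 0) are removed.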
2124
2125static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
2126 SmallVector<Value> &target) {
2127 target =
2128 llvm::to_vector<4>(Range: llvm::map_range(C&: source, F: [](std::optional<Value> val) {
2129 return val.has_value() ? *val : Value();
2130 }));
2131}
2132
2133/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2134/// constraints drawn from an affine map. Before adding the constraint, the
2135/// dimensions/symbols of the affine map are aligned with `constraints`.
2136/// `operands` are the SSA Value operands used with the affine map.
2137/// Note: This function adds a new symbol column to the `constraints` for each
2138/// dimension/symbol that exists in the affine map but not in `constraints`.
2139static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2140 BoundType type, unsigned pos,
2141 AffineMap map, ValueRange operands) {
2142 SmallVector<Value> dims, syms, newSyms;
2143 unpackOptionalValues(source: constraints.getMaybeValues(kind: VarKind::SetDim), target&: dims);
2144 unpackOptionalValues(source: constraints.getMaybeValues(kind: VarKind::Symbol), target&: syms);
2145
2146 AffineMap alignedMap =
2147 alignAffineMapWithValues(map, operands, dims, syms, newSyms: &newSyms);
2148 for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2149 constraints.appendSymbolVar(vals: newSyms[i]);
2150 return constraints.addBound(type, pos, boundMap: alignedMap);
2151}
2152
2153/// Add `val` to each result of `map`.
2154static AffineMap addConstToResults(AffineMap map, int64_t val) {
2155 SmallVector<AffineExpr> newResults;
2156 for (AffineExpr r : map.getResults())
2157 newResults.push_back(Elt: r + val);
2158 return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults,
2159 context: map.getContext());
2160}
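// For example (illustrative): applying this with val = 1 to (d0) -> (d0, 8)
// yields (d0) -> (d0 + 1, 9).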
2161
2162// Attempt to simplify the given min/max operation by proving that its value is
2163// bounded by the same lower and upper bound.
2164//
2165// Bounds are computed by FlatAffineValueConstraints. Invariants required for
2166// finding/proving bounds should be supplied via `constraints`.
2167//
2168// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2169// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2170// case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2171// `op` but are not part of `constraints`, are added as extra symbols.
2172// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2173// * If `isMin`: r_i >= opBound
2174// * If `isMax`: r_i <= opBound
2175// If this is the case, ub(op) == lb(op).
2176// 4. Replace `op` with `opBound`.
2177//
2178// In summary, the following constraints are added throughout this function.
2179// Note: `invar` are dimensions added by the caller to express the invariants.
2180// (Showing only the case where `isMin`.)
2181//
2182// invar | op | opBound | r_i | extra syms... | const | eq/ineq
2183// ------+-------+---------+-----+---------------+-------+-------------------
2184// (various eq./ineq. constraining `invar`, added by the caller)
2185// ... | 0 | 0 | 0 | 0 | ... | ...
2186// ------+-------+---------+-----+---------------+-------+-------------------
2187// (various ineq. constraining `op` in terms of `op` operands (`invar` and
2188// extra `op` operands "extra syms" that are not in `invar`)).
2189// ... | -1 | 0 | 0 | ... | ... | >= 0
2190// ------+-------+---------+-----+---------------+-------+-------------------
2191// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2192// ... | 0 | -1 | 0 | ... | ... | = 0
2193// ------+-------+---------+-----+---------------+-------+-------------------
2194// (for each `op` map result r_i: set r_i to corresponding map result,
2195// prove that r_i >= minOpUb via contradiction)
2196// ... | 0 | 0 | -1 | ... | ... | = 0
2197// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2198//
2199FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
2200 Operation *op, FlatAffineValueConstraints constraints) {
2201 bool isMin = isa<AffineMinOp>(Val: op);
2202 assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2203 MLIRContext *ctx = op->getContext();
2204 Builder builder(ctx);
2205 AffineMap map =
2206 isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2207 ValueRange operands = op->getOperands();
2208 unsigned numResults = map.getNumResults();
2209
2210 // Add a few extra dimensions.
2211 unsigned dimOp = constraints.appendDimVar(); // `op`
2212 unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2213 unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2214
2215 // Add an inequality for each result expr_i of map:
2216 // isMin: op <= expr_i, !isMin: op >= expr_i
2217 auto boundType = isMin ? BoundType::UB : BoundType::LB;
2218 // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2219 AffineMap mapLbUb = isMin ? addConstToResults(map, val: 1) : map;
2220 if (failed(
2221 Result: alignAndAddBound(constraints, type: boundType, pos: dimOp, map: mapLbUb, operands)))
2222 return failure();
2223
2224 // Try to compute a lower/upper bound for op, expressed in terms of the other
2225 // `dims` and extra symbols.
2226 SmallVector<AffineMap> opLb(1), opUb(1);
2227 constraints.getSliceBounds(offset: dimOp, num: 1, context: ctx, lbMaps: &opLb, ubMaps: &opUb);
2228 AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2229 // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2230 // a TODO of `getSliceBounds` and not handled here.
2231 if (!sliceBound || sliceBound.getNumResults() != 1)
2232 return failure(); // No or multiple bounds found.
2233 // Recover the inclusive UB in the case of an `affine.min`.
2234 AffineMap boundMap = isMin ? addConstToResults(map: sliceBound, val: -1) : sliceBound;
2235
2236 // Add an equality: Set dimOpBound to computed bound.
2237 // Add back dimension for op. (Was removed by `getSliceBounds`.)
2238 AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2239 if (failed(Result: constraints.addBound(type: BoundType::EQ, pos: dimOpBound, boundMap: alignedBoundMap)))
2240 return failure();
2241
2242 // If the constraint system is empty, there is an inconsistency. (E.g., this
2243 // can happen if loop lb > ub.)
2244 if (constraints.isEmpty())
2245 return failure();
2246
2247 // In the case of `isMin` (the reasoning is reversed for `!isMin`):
2248 // Prove that each result of `map` has a lower bound that is equal to (or
2249 // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2250 // can be replaced with the bound. I.e., prove that for each result
2251 // expr_i (represented by dimension r_i):
2252 //
2253 // r_i >= opBound
2254 //
2255 // To prove this inequality, add its negation to the constraint set and prove
2256 // that the constraint set is empty.
2257 for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
2258 FlatAffineValueConstraints newConstr(constraints);
2259
2260 // Add an equality: r_i = expr_i
2261 // Note: These equalities could have been added earlier and used to express
2262 // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
2263 // computes minOpUb in terms of r_i dims, which is not desired.
2264 if (failed(Result: alignAndAddBound(constraints&: newConstr, type: BoundType::EQ, pos: i,
2265 map: map.getSubMap(resultPos: {i - resultDimStart}), operands)))
2266 return failure();
2267
2268 // If `isMin`: Add inequality: r_i < opBound
2269 // equiv.: opBound - r_i - 1 >= 0
2270 // If `!isMin`: Add inequality: r_i > opBound
2271 // equiv.: -opBound + r_i - 1 >= 0
2272 SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2273 ineq[dimOpBound] = isMin ? 1 : -1;
2274 ineq[i] = isMin ? -1 : 1;
2275 ineq[newConstr.getNumCols() - 1] = -1;
2276 newConstr.addInequality(inEq: ineq);
2277 if (!newConstr.isEmpty())
2278 return failure();
2279 }
2280
2281 // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
2282 AffineMap newMap = alignedBoundMap;
2283 SmallVector<Value> newOperands;
2284 unpackOptionalValues(source: constraints.getMaybeValues(), target&: newOperands);
2285 // If dims/symbols have known constant values, use those in order to simplify
2286 // the affine map further.
2287 for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2288 // Skip unused operands and operands that are already constants.
2289 if (!newOperands[i] || getConstantIntValue(ofr: newOperands[i]))
2290 continue;
2291 if (auto bound = constraints.getConstantBound64(type: BoundType::EQ, pos: i)) {
2292 AffineExpr expr =
2293 i < newMap.getNumDims()
2294 ? builder.getAffineDimExpr(position: i)
2295 : builder.getAffineSymbolExpr(position: i - newMap.getNumDims());
2296 newMap = newMap.replace(expr, replacement: builder.getAffineConstantExpr(constant: *bound),
2297 numResultDims: newMap.getNumDims(), numResultSyms: newMap.getNumSymbols());
2298 }
2299 }
2300 affine::canonicalizeMapAndOperands(map: &newMap, operands: &newOperands);
2301 return AffineValueMap(newMap, newOperands);
2302}
2303
2304Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
2305 Operation *b) {
2306 Region *aScope = getAffineAnalysisScope(op: a);
2307 Region *bScope = getAffineAnalysisScope(op: b);
2308 if (aScope != bScope)
2309 return nullptr;
2310
2311 // Get the block ancestry of `op` while stopping at the affine scope `aScope`
2312 // and store it in `ancestry`.
2313 auto getBlockAncestry = [&](Operation *op,
2314 SmallVectorImpl<Block *> &ancestry) {
2315 Operation *curOp = op;
2316 do {
2317 ancestry.push_back(Elt: curOp->getBlock());
2318 if (curOp->getParentRegion() == aScope)
2319 break;
2320 curOp = curOp->getParentOp();
2321 } while (curOp);
2322 assert(curOp && "can't reach root op without passing through affine scope");
2323 std::reverse(first: ancestry.begin(), last: ancestry.end());
2324 };
2325
2326 SmallVector<Block *, 4> aAncestors, bAncestors;
2327 getBlockAncestry(a, aAncestors);
2328 getBlockAncestry(b, bAncestors);
2329 assert(!aAncestors.empty() && !bAncestors.empty() &&
2330 "at least one Block ancestor expected");
2331
2332 Block *innermostCommonBlock = nullptr;
2333 for (unsigned i = 0, j = 0, e = aAncestors.size(), f = bAncestors.size();
2334 i < e && j < f; ++i, ++j) {
2335 if (aAncestors[i] != bAncestors[j])
2336 break;
2337 innermostCommonBlock = aAncestors[i];
2338 }
2339 return innermostCommonBlock;
2340}
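// Illustrative example: if 'a' is nested inside an affine.for held by the entry
// block of a func.func and 'b' sits directly in that entry block, the two block
// ancestries share only the entry block, which is what gets returned.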
2341
