//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (isa<AffineForOp>(op))
      forOps.push_back(cast<AffineForOp>(op));
    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
      hasNonAffineRegionOp = true;
    else if (isa<AffineReadOpInterface>(op))
      loadOpInsts.push_back(op);
    else if (isa<AffineWriteOpInterface>(op))
      storeOpInsts.push_back(op);
  });
}

// Returns the load op count for 'memref'.
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      ++loadOpCount;
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      ++storeOpCount;
  }
  return storeOpCount;
}

// Returns all store ops in 'storeOps' which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Returns all load ops in 'loadOps' which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in the
// block.
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (isa<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (isa<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (isa<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, we look
      // for non-affine ops on the path from source to destination, at which
      // point we check which memrefs if any are used in the region.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., they may
      // not be imperative ops.
      LLVM_DEBUG(llvm::dbgs()
                 << "MDG init failed; unknown region-holding op found!\n");
      return false;
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
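  // For example, if node 0 stores to %A and node 1 loads from %A, an edge
  // (0 -> 1) on %A is added below; two nodes that only load from %A get no
  // edge since there is no dependence between them.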
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Remove node 'id' (and its associated edges) from graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
  Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be in the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (outEdges.count(idAndIndex.first) == 0 ||
        idAndIndex.second == outEdges[idAndIndex.first].size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) {
  unsigned inEdgeCount = 0;
  if (inEdges.count(id) > 0)
    for (auto &inEdge : inEdges[id])
      if (inEdge.value == memref) {
        Node *srcNode = getNode(inEdge.id);
        // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
        // with a store.
        if (srcNode->getStoreOpCount(memref) > 0)
          ++inEdgeCount;
      }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
  unsigned outEdgeCount = 0;
  if (outEdges.count(id) > 0)
    for (auto &outEdge : outEdges[id])
      if (!memref || outEdge.value == memref)
        ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) {
  for (MemRefDependenceGraph::Edge edge : inEdges[id])
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) {
  if (outEdges.count(srcId) == 0)
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: a defining op with a user in the dst "
                  "loop has dependence from the src loop\n");
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges[srcId])
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges[dstId])
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  //   *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //      operation insertion point (or return null pointer if no such
  //      insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
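  // For example, if ops (P0, P1, P2) lie in the range (src, dst), P0 depends
  // on 'srcNode', and 'dstNode' depends on P2, then 'firstSrcDepPos' = 0 and
  // 'lastDstDepPos' = 2; since 0 <= 2, no insertion point can preserve both
  // dependences and nullptr is returned.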
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
//    private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (privateMemRefs.count(inEdge.value) == 0)
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on memrefs in 'privateMemRefs'
  // (which are being replaced by private memrefs). These edges could come
  // from nodes other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads' and 'stores' to node at 'id'.
void MemRefDependenceGraph::addToNode(
    unsigned id, const SmallVectorImpl<Operation *> &loads,
    const SmallVectorImpl<Operation *> &stores) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges[id], callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges[id], callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Skip if 'edge.id' is not a loop nest.
    if (!isa<AffineForOp>(getNode(edge.id)->op))
      continue;
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << " InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << " OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

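// Populates 'loops' with the affine.for ops surrounding 'op', ordered from
// outermost to innermost, stopping below the closest enclosing op with the
// AffineScope trait (e.g., a func.func).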
void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  AffineForOp currAffineForOp;
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the
// original domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent slice
// bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

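  // For example, a slice dimension with lbMap = (d0) -> (d0) and
  // ubMap = (d0) -> (d0 + 1), where d0 is a dst loop IV, covers exactly one
  // src iteration per dst IV value; for such dimensions, the src and dst loop
  // bounds are compared below.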
  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension for
    // now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following conditions
  // are verified to be true, slice is declared valid by the fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest from which the slice is
  // computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the 'projectOut'
  // utility below. Currently, aligning is not done assuming that there will be
  // no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return std::nullopt;
  }

  // Project out every dimension other than the 'ivs' to express the slice's
  // domain completely in terms of the source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the result
  // is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

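// Returns a constant upper bound on the number of elements in this region,
// and std::nullopt if no such bound can be determined. Optionally returns the
// constant shape, lower bounds, and lower bound divisors of the bounding box.
// For example, a region with constraints {0 <= d0 <= 63, %i <= d1 <= %i + 7}
// has shape [64, 8] and covers at most 512 elements.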
std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatAffineValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along each
  // dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "Dim size bound can't be negative");
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, then it can always be bounded by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == ShapedType::kDynamic)
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0);
      lbDivisor = 1;
    }
    numElements *= diffConstant;
    if (lbs) {
      lbs->push_back(lb);
      assert(lbDivisors && "both lbs and lbDivisor or none");
      lbDivisors->push_back(lbDivisor);
    }
    if (shape) {
      shape->push_back(diffConstant);
    }
  }
  return numElements;
}

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1
// will be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << "\ndepth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      cst.addInductionVarOrTerminalSymbol(operand);
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (Value var : vars) {
    if (isAffineInductionVar(var) && !llvm::is_contained(enclosingIVs, var)) {
      cst.projectOut(var);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

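// Returns the size in bytes of a single element of 'memRefType' if its
// element type is an int/float scalar or a vector thereof; for example, an
// f32 element takes 4 bytes and a vector<4xf32> element takes 16 bytes.
// Returns std::nullopt for any other element type.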
std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return std::nullopt;
  }

  // Indices to use for the DmaStart op.
  // Indices for the original memref being DMAed from/to.
  SmallVector<Value, 4> memIndices;
  // Indices for the faster buffer being DMAed into/from.
  SmallVector<Value, 4> bufIndices;

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the element of the memref has vector type, takes
/// into account size of the vector as well.
// TODO: improve/complete this when we have target data.
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}

template <typename LoadOrStoreOp>
LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                                    bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either an AffineReadOpInterface or an "
                "AffineWriteOpInterface");
  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LLVM_DEBUG(llvm::dbgs() << "Memory region");
  LLVM_DEBUG(region.getConstraints()->dump());

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineValueConstraints ucst(*region.getConstraints());

    // Intersect memory region with constraint capturing out of bounds (both
    // above the upper bound and below the lower bound), and check if the
    // constraint system is feasible. If it is, there is at least one point
    // out of bounds.
    SmallVector<int64_t, 4> ineq(rank + 1, 0);
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (dimSize == -1)
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addBound(BoundType::LB, r, dimSize);
    outOfBounds = !ucst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index.
    FlatAffineValueConstraints lcst(*region.getConstraints());
    std::fill(ineq.begin(), ineq.end(), 0);
    // d_i <= -1;
    lcst.addBound(BoundType::UB, r, -1);
    outOfBounds = !lcst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}

// Explicitly instantiate the template so that the compiler knows we need them!
template LogicalResult
mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
                                      bool emitError);
template LogicalResult
mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
                                      bool emitError);

// Returns in 'positions' the Block positions of 'op' in each ancestor
// Block from the Block containing 'op', stopping at 'limitBlock'.
static void findInstPosition(Operation *op, Block *limitBlock,
                             SmallVectorImpl<unsigned> *positions) {
  Block *block = op->getBlock();
  while (block != limitBlock) {
    // FIXME: This algorithm is unnecessarily O(n) and should be improved to
    // not rely on linear scans.
    int instPosInBlock = std::distance(block->begin(), op->getIterator());
    positions->push_back(instPosInBlock);
    op = block->getParentOp();
    block = op->getBlock();
  }
  std::reverse(positions->begin(), positions->end());
}

// Returns the Operation in a possibly nested set of Blocks, where the
// position of the operation is represented by 'positions', which has a
// Block position for each level of nesting.
static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
                                    unsigned level, Block *block) {
  unsigned i = 0;
  for (auto &op : *block) {
    if (i != positions[level]) {
      ++i;
      continue;
    }
    if (level == positions.size() - 1)
      return &op;
    if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
      return getInstAtPosition(positions, level + 1,
                               childAffineForOp.getBody());

    for (auto &region : op.getRegions()) {
      for (auto &b : region)
        if (auto *ret = getInstAtPosition(positions, level + 1, &b))
          return ret;
    }
    return nullptr;
  }
  return nullptr;
}

// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
                                            FlatAffineValueConstraints *cst) {
  for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
    auto value = cst->getValue(i);
    if (ivs.count(value) == 0) {
      assert(isAffineForInductionVar(value));
      auto loop = getForInductionVarOwner(value);
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }
  return success();
}

/// Returns the innermost common loop depth for the set of operations in 'ops'.
// TODO: Move this to LoopUtils.
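// For example, for two ops nested as (%i, %j) and (%i, %k) under a common
// outermost loop on %i, the innermost common loop depth is 1 and
// 'surroundingLoops', if provided, is populated with just the loop on %i.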
1385unsigned mlir::affine::getInnermostCommonLoopDepth(
1386 ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
1387 unsigned numOps = ops.size();
1388 assert(numOps > 0 && "Expected at least one operation");
1389
1390 std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
1391 unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
1392 for (unsigned i = 0; i < numOps; ++i) {
1393 getAffineForIVs(*ops[i], &loops[i]);
1394 loopDepthLimit =
1395 std::min(a: loopDepthLimit, b: static_cast<unsigned>(loops[i].size()));
1396 }
1397
1398 unsigned loopDepth = 0;
1399 for (unsigned d = 0; d < loopDepthLimit; ++d) {
1400 unsigned i;
1401 for (i = 1; i < numOps; ++i) {
1402 if (loops[i - 1][d] != loops[i][d])
1403 return loopDepth;
1404 }
1405 if (surroundingLoops)
1406 surroundingLoops->push_back(loops[i - 1][d]);
1407 ++loopDepth;
1408 }
1409 return loopDepth;
1410}
1411
1412/// Computes in 'sliceUnion' the union of all slice bounds computed at
1413/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
1414/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
1415/// the union was computed correctly, and an appropriate failure otherwise.
1416SliceComputationResult
1417mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
1418 ArrayRef<Operation *> opsB, unsigned loopDepth,
1419 unsigned numCommonLoops, bool isBackwardSlice,
1420 ComputationSliceState *sliceUnion) {
1421 // Compute the union of slice bounds between all pairs in 'opsA' and
1422 // 'opsB' in 'sliceUnionCst'.
1423 FlatAffineValueConstraints sliceUnionCst;
1424 assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
1425 std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
1426 for (auto *i : opsA) {
1427 MemRefAccess srcAccess(i);
1428 for (auto *j : opsB) {
1429 MemRefAccess dstAccess(j);
1430 if (srcAccess.memref != dstAccess.memref)
1431 continue;
1432 // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
1433 if ((!isBackwardSlice && loopDepth > getNestingDepth(op: i)) ||
1434 (isBackwardSlice && loopDepth > getNestingDepth(op: j))) {
1435 LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
1436 return SliceComputationResult::GenericFailure;
1437 }
1438
1439 bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
1440 isa<AffineReadOpInterface>(dstAccess.opInst);
1441 FlatAffineValueConstraints dependenceConstraints;
1442 // Check dependence between 'srcAccess' and 'dstAccess'.
1443 DependenceResult result = checkMemrefAccessDependence(
1444 srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
1445 dependenceConstraints: &dependenceConstraints, /*dependenceComponents=*/nullptr,
1446 /*allowRAR=*/readReadAccesses);
1447 if (result.value == DependenceResult::Failure) {
1448 LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
1449 return SliceComputationResult::GenericFailure;
1450 }
1451 if (result.value == DependenceResult::NoDependence)
1452 continue;
1453 dependentOpPairs.emplace_back(args&: i, args&: j);
1454
1455 // Compute slice bounds for 'srcAccess' and 'dstAccess'.
1456 ComputationSliceState tmpSliceState;
1457 mlir::affine::getComputationSliceState(depSourceOp: i, depSinkOp: j, dependenceConstraints: &dependenceConstraints,
1458 loopDepth, isBackwardSlice,
1459 sliceState: &tmpSliceState);
1460
1461 if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
1462 // Initialize 'sliceUnionCst' with the bounds computed in the previous step.
1463 if (failed(result: tmpSliceState.getAsConstraints(cst: &sliceUnionCst))) {
1464 LLVM_DEBUG(llvm::dbgs()
1465 << "Unable to compute slice bound constraints\n");
1466 return SliceComputationResult::GenericFailure;
1467 }
1468 assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
1469 continue;
1470 }
1471
1472 // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
1473 FlatAffineValueConstraints tmpSliceCst;
1474 if (failed(result: tmpSliceState.getAsConstraints(cst: &tmpSliceCst))) {
1475 LLVM_DEBUG(llvm::dbgs()
1476 << "Unable to compute slice bound constraints\n");
1477 return SliceComputationResult::GenericFailure;
1478 }
1479
1480 // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
1481 if (!sliceUnionCst.areVarsAlignedWithOther(other: tmpSliceCst)) {
1482
1483 // Pre-constraint var alignment: record loop IVs used in each constraint
1484 // system.
1485 SmallPtrSet<Value, 8> sliceUnionIVs;
1486 for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
1487 sliceUnionIVs.insert(Ptr: sliceUnionCst.getValue(pos: k));
1488 SmallPtrSet<Value, 8> tmpSliceIVs;
1489 for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
1490 tmpSliceIVs.insert(Ptr: tmpSliceCst.getValue(pos: k));
1491
1492 sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, other: &tmpSliceCst);
1493
1494 // Post-alignment: add bounds for any loop IVs that are still missing from
1495 // either constraint system. This can occur if one constraint system uses a
1496 // loop IV that is not used by the other. The call to unionBoundingBox below
1497 // expects constraints for each loop IV, even if they are just the unsliced
1498 // full loop bounds added here.
1499 if (failed(result: addMissingLoopIVBounds(ivs&: sliceUnionIVs, cst: &sliceUnionCst)))
1500 return SliceComputationResult::GenericFailure;
1501 if (failed(result: addMissingLoopIVBounds(ivs&: tmpSliceIVs, cst: &tmpSliceCst)))
1502 return SliceComputationResult::GenericFailure;
1503 }
1504 // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
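 // Note: if either constraint system contains local variables (e.g.,
 // introduced by mod/div constraints), the union is conservatively treated
 // as a failure.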
1505 if (sliceUnionCst.getNumLocalVars() > 0 ||
1506 tmpSliceCst.getNumLocalVars() > 0 ||
1507 failed(result: sliceUnionCst.unionBoundingBox(other: tmpSliceCst))) {
1508 LLVM_DEBUG(llvm::dbgs()
1509 << "Unable to compute union bounding box of slice bounds\n");
1510 return SliceComputationResult::GenericFailure;
1511 }
1512 }
1513 }
1514
1515 // Empty union.
1516 if (sliceUnionCst.getNumDimAndSymbolVars() == 0)
1517 return SliceComputationResult::GenericFailure;
1518
1519 // Gather loops surrounding ops from loop nest where slice will be inserted.
1520 SmallVector<Operation *, 4> ops;
1521 for (auto &dep : dependentOpPairs) {
1522 ops.push_back(Elt: isBackwardSlice ? dep.second : dep.first);
1523 }
1524 SmallVector<AffineForOp, 4> surroundingLoops;
1525 unsigned innermostCommonLoopDepth =
1526 getInnermostCommonLoopDepth(ops, &surroundingLoops);
1527 if (loopDepth > innermostCommonLoopDepth) {
1528 LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
1529 return SliceComputationResult::GenericFailure;
1530 }
1531
1532 // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
1533 unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();
1534
1535 // Convert any dst loop IVs which are symbol variables to dim variables.
1536 sliceUnionCst.convertLoopIVSymbolsToDims();
1537 sliceUnion->clearBounds();
1538 sliceUnion->lbs.resize(N: numSliceLoopIVs, NV: AffineMap());
1539 sliceUnion->ubs.resize(N: numSliceLoopIVs, NV: AffineMap());
1540
1541 // Get slice bounds from slice union constraints 'sliceUnionCst'.
1542 sliceUnionCst.getSliceBounds(/*offset=*/0, num: numSliceLoopIVs,
1543 context: opsA[0]->getContext(), lbMaps: &sliceUnion->lbs,
1544 ubMaps: &sliceUnion->ubs);
1545
1546 // Add slice bound operands of union.
1547 SmallVector<Value, 4> sliceBoundOperands;
1548 sliceUnionCst.getValues(start: numSliceLoopIVs,
1549 end: sliceUnionCst.getNumDimAndSymbolVars(),
1550 values: &sliceBoundOperands);
1551
1552 // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
1553 sliceUnion->ivs.clear();
1554 sliceUnionCst.getValues(start: 0, end: numSliceLoopIVs, values: &sliceUnion->ivs);
1555
1556 // Set the slice insertion point: the start of the loop body at 'loopDepth'
 // for backward slices, or just before that body's terminator for forward
 // slices.
1557 sliceUnion->insertPoint =
1558 isBackwardSlice
1559 ? surroundingLoops[loopDepth - 1].getBody()->begin()
1560 : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());
1561
1562 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1563 // canonicalization.
1564 sliceUnion->lbOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1565 sliceUnion->ubOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1566
1567 // Check whether the computed slice is valid. Return success only if the
1568 // slice is verified to be valid; otherwise return the appropriate failure.
1569 std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
1570 if (!isSliceValid) {
1571 LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
1572 return SliceComputationResult::GenericFailure;
1573 }
1574 if (!*isSliceValid)
1575 return SliceComputationResult::IncorrectSliceFailure;
1576
1577 return SliceComputationResult::Success;
1578}
1579
1580// TODO: extend this to handle multiple result maps.
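// For example, lbMap = (d0) -> (d0) with ubMap = (d0) -> (d0 + 4) yields 4,
// while ubMap = (d0) -> (2 * d0) yields std::nullopt since the difference,
// d0, is not constant.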
1581static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
1582 AffineMap ubMap) {
1583 assert(lbMap.getNumResults() == 1 && "expected single result bound map");
1584 assert(ubMap.getNumResults() == 1 && "expected single result bound map");
1585 assert(lbMap.getNumDims() == ubMap.getNumDims());
1586 assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
1587 AffineExpr lbExpr(lbMap.getResult(idx: 0));
1588 AffineExpr ubExpr(ubMap.getResult(idx: 0));
1589 auto loopSpanExpr = simplifyAffineExpr(expr: ubExpr - lbExpr, numDims: lbMap.getNumDims(),
1590 numSymbols: lbMap.getNumSymbols());
1591 auto cExpr = dyn_cast<AffineConstantExpr>(Val&: loopSpanExpr);
1592 if (!cExpr)
1593 return std::nullopt;
1594 return cExpr.getValue();
1595}
1596
1597// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
1598// loop nest represented by the slice loop bounds in 'slice'. Returns true on
1599// success, false otherwise (i.e., if a non-constant trip count was encountered).
1600// TODO: Make this work with non-unit step loops.
1601bool mlir::affine::buildSliceTripCountMap(
1602 const ComputationSliceState &slice,
1603 llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
1604 unsigned numSrcLoopIVs = slice.ivs.size();
1605 // Populate map from AffineForOp -> trip count
1606 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1607 AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
1608 auto *op = forOp.getOperation();
1609 AffineMap lbMap = slice.lbs[i];
1610 AffineMap ubMap = slice.ubs[i];
1611 // If the lower or upper bound maps are null or provide no results, it
1612 // implies that the source loop was not sliced at all, and the entire loop
1613 // will be part of the slice.
1614 if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
1615 ubMap.getNumResults() == 0) {
1616 // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
1617 if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
1618 (*tripCountMap)[op] =
1619 forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
1620 continue;
1621 }
1622 std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
1623 if (maybeConstTripCount.has_value()) {
1624 (*tripCountMap)[op] = *maybeConstTripCount;
1625 continue;
1626 }
1627 return false;
1628 }
1629 std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
1630 // Slice bounds are created with a constant ub - lb difference.
1631 if (!tripCount.has_value())
1632 return false;
1633 (*tripCountMap)[op] = *tripCount;
1634 }
1635 return true;
1636}
1637
1638// Return the number of iterations in the given slice.
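// For example, a slice whose loops have trip counts {16, 8} yields
// 16 * 8 = 128 iterations.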
1639uint64_t mlir::affine::getSliceIterationCount(
1640 const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
1641 uint64_t iterCount = 1;
1642 for (const auto &count : sliceTripCountMap) {
1643 iterCount *= count.second;
1644 }
1645 return iterCount;
1646}
1647
1648const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
1649// Computes slice bounds by projecting out any loop IVs from
1650// 'dependenceConstraints' at depth greater than 'loopDepth', and computes the
1651// slice bounds in 'sliceState', which represent one loop nest's IVs in terms
1652// of the other loop nest's IVs, symbols, and constants ('isBackwardSlice'
// selects which of the two nests is sliced).
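// For intuition, consider the following hypothetical IR (not taken from a
// test case):
//
//   affine.for %i = 0 to 100 {
//     affine.store %v, %A[%i] : memref<100xf32>
//   }
//   affine.for %j = 0 to 100 {
//     %u = affine.load %A[%j] : memref<100xf32>
//   }
//
// A backward slice of the store nest computed at 'loopDepth' = 1 of the load
// nest restricts the source IV %i to the single iteration [%j, %j + 1), i.e.
// lbs[0] = (d0) -> (d0) and ubs[0] = (d0) -> (d0 + 1) with bound operand %j.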
1653void mlir::affine::getComputationSliceState(
1654 Operation *depSourceOp, Operation *depSinkOp,
1655 FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
1656 bool isBackwardSlice, ComputationSliceState *sliceState) {
1657 // Get loop nest surrounding src operation.
1658 SmallVector<AffineForOp, 4> srcLoopIVs;
1659 getAffineForIVs(*depSourceOp, &srcLoopIVs);
1660 unsigned numSrcLoopIVs = srcLoopIVs.size();
1661
1662 // Get loop nest surrounding dst operation.
1663 SmallVector<AffineForOp, 4> dstLoopIVs;
1664 getAffineForIVs(*depSinkOp, &dstLoopIVs);
1665 unsigned numDstLoopIVs = dstLoopIVs.size();
1666
1667 assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
1668 (isBackwardSlice && loopDepth <= numDstLoopIVs));
1669
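 // In 'dependenceConstraints', the source loop IV dimensions precede the
 // destination loop IV dimensions; the position arithmetic below relies on
 // that column layout.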
1670 // Project out dimensions other than those up to 'loopDepth'.
1671 unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
1672 unsigned num =
1673 isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
1674 dependenceConstraints->projectOut(pos, num);
1675
1676 // Add slice loop IV values to 'sliceState'.
1677 unsigned offset = isBackwardSlice ? 0 : loopDepth;
1678 unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
1679 dependenceConstraints->getValues(start: offset, end: offset + numSliceLoopIVs,
1680 values: &sliceState->ivs);
1681
1682 // Set up lower/upper bound affine maps for the slice.
1683 sliceState->lbs.resize(N: numSliceLoopIVs, NV: AffineMap());
1684 sliceState->ubs.resize(N: numSliceLoopIVs, NV: AffineMap());
1685
1686 // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
1687 dependenceConstraints->getSliceBounds(offset, num: numSliceLoopIVs,
1688 context: depSourceOp->getContext(),
1689 lbMaps: &sliceState->lbs, ubMaps: &sliceState->ubs);
1690
1691 // Set up bound operands for the slice's lower and upper bounds.
1692 SmallVector<Value, 4> sliceBoundOperands;
1693 unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolVars();
1694 for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
1695 if (i < offset || i >= offset + numSliceLoopIVs) {
1696 sliceBoundOperands.push_back(Elt: dependenceConstraints->getValue(pos: i));
1697 }
1698 }
1699
1700 // Give each bound its own copy of 'sliceBoundOperands' for subsequent
1701 // canonicalization.
1702 sliceState->lbOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1703 sliceState->ubOperands.resize(new_size: numSliceLoopIVs, x: sliceBoundOperands);
1704
1705 // Set the slice insertion point: the start of the destination loop body at
 // 'loopDepth' for backward slices, or just before the source loop body's
 // terminator for forward slices.
1706 sliceState->insertPoint =
1707 isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
1708 : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());
1709
1710 llvm::SmallDenseSet<Value, 8> sequentialLoops;
1711 if (isa<AffineReadOpInterface>(depSourceOp) &&
1712 isa<AffineReadOpInterface>(depSinkOp)) {
1713 // For read-read access pairs, clear any slice bounds on sequential loops.
1714 // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
1715 getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
1716 &sequentialLoops);
1717 }
1718 auto getSliceLoop = [&](unsigned i) {
1719 return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
1720 };
1721 auto isInnermostInsertion = [&]() {
1722 return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
1723 : loopDepth >= dstLoopIVs.size());
1724 };
1725 llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
1726 auto srcIsUnitSlice = [&]() {
1727 return (buildSliceTripCountMap(slice: *sliceState, tripCountMap: &sliceTripCountMap) &&
1728 (getSliceIterationCount(sliceTripCountMap) == 1));
1729 };
1730 // Clear all sliced loop bounds beginning at the first sequential loop or the
1731 // first loop with a slice fusion barrier attribute.
1733 for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
1734 Value iv = getSliceLoop(i).getInductionVar();
1735 if (sequentialLoops.count(V: iv) == 0 &&
1736 getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
1737 continue;
1738 // Skip resetting the bounds of a reduction loop inserted in the destination
1739 // loop nest when all of the following conditions are met:
1740 // 1. The slice has a single trip count.
1741 // 2. The loop bounds of the source and the destination match.
1742 // 3. The slice is being inserted at the innermost insertion point.
1743 std::optional<bool> isMaximal = sliceState->isMaximal();
1744 if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
1745 isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
1746 continue;
1747 for (unsigned j = i; j < numSliceLoopIVs; ++j) {
1748 sliceState->lbs[j] = AffineMap();
1749 sliceState->ubs[j] = AffineMap();
1750 }
1751 break;
1752 }
1753}
1754
1755/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
1756/// updates the slice loop bounds with any non-null bound maps specified in
1757/// 'sliceState', and inserts this slice into the loop nest surrounding
1758/// 'dstOpInst' at loop depth 'dstLoopDepth'.
1759// TODO: extend the slicing utility to compute slices that
1760// aren't necessarily a one-to-one relation between the source and destination. The
1761// relation between the source and destination could be many-to-many in general.
1762// TODO: the slice computation is incorrect in the cases
1763// where the dependence from the source to the destination does not cover the
1764// entire destination index set. Subtract out the dependent destination
1765// iterations from the destination index set and check for emptiness --- this is one
1766// solution.
1767AffineForOp mlir::affine::insertBackwardComputationSlice(
1768 Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
1769 ComputationSliceState *sliceState) {
1770 // Get loop nest surrounding src operation.
1771 SmallVector<AffineForOp, 4> srcLoopIVs;
1772 getAffineForIVs(*srcOpInst, &srcLoopIVs);
1773 unsigned numSrcLoopIVs = srcLoopIVs.size();
1774
1775 // Get loop nest surrounding dst operation.
1776 SmallVector<AffineForOp, 4> dstLoopIVs;
1777 getAffineForIVs(*dstOpInst, &dstLoopIVs);
1778 unsigned dstLoopIVsSize = dstLoopIVs.size();
1779 if (dstLoopDepth > dstLoopIVsSize) {
1780 dstOpInst->emitError(message: "invalid destination loop depth");
1781 return AffineForOp();
1782 }
1783
1784 // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
1785 SmallVector<unsigned, 4> positions;
1786 // TODO: This code is incorrect since 'srcLoopIVs' can be empty (0-d nest).
1787 findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);
1788
1789 // Clone the src loop nest and insert it at the beginning of the operation
1790 // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
1791 auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
1792 OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
1793 auto sliceLoopNest =
1794 cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));
1795
1796 Operation *sliceInst =
1797 getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
1798 // Get loop nest surrounding 'sliceInst'.
1799 SmallVector<AffineForOp, 4> sliceSurroundingLoops;
1800 getAffineForIVs(*sliceInst, &sliceSurroundingLoops);
1801
1802 // Sanity check.
1803 unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
1804 (void)sliceSurroundingLoopsSize;
1805 assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
1806 unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
1807 (void)sliceLoopLimit;
1808 assert(sliceLoopLimit >= sliceSurroundingLoopsSize);
1809
1810 // Update loop bounds for loops in 'sliceLoopNest'.
1811 for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
1812 auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
1813 if (AffineMap lbMap = sliceState->lbs[i])
1814 forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
1815 if (AffineMap ubMap = sliceState->ubs[i])
1816 forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
1817 }
1818 return sliceLoopNest;
1819}
1820
1821// Constructs a MemRefAccess, populating it with the memref, the access
1822// indices, and the op obtained from 'loadOrStoreOpInst'.
1823MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
1824 if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
1825 memref = loadOp.getMemRef();
1826 opInst = loadOrStoreOpInst;
1827 llvm::append_range(indices, loadOp.getMapOperands());
1828 } else {
1829 assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
1830 "Affine read/write op expected");
1831 auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
1832 opInst = loadOrStoreOpInst;
1833 memref = storeOp.getMemRef();
1834 llvm::append_range(indices, storeOp.getMapOperands());
1835 }
1836}
1837
1838unsigned MemRefAccess::getRank() const {
1839 return cast<MemRefType>(memref.getType()).getRank();
1840}
1841
1842bool MemRefAccess::isStore() const {
1843 return isa<AffineWriteOpInterface>(opInst);
1844}
1845
1846/// Returns the nesting depth of this statement, i.e., the number of loops
1847/// surrounding this statement.
1848unsigned mlir::affine::getNestingDepth(Operation *op) {
1849 Operation *currOp = op;
1850 unsigned depth = 0;
1851 while ((currOp = currOp->getParentOp())) {
1852 if (isa<AffineForOp>(Val: currOp))
1853 depth++;
1854 }
1855 return depth;
1856}
1857
1858/// Equal if both affine accesses are provably equivalent (at compile
1859/// time) when considering the memref, the affine maps and their respective
1860/// operands. The equality of access functions + operands is checked by
1861/// subtracting fully composed value maps, and then simplifying the difference
1862/// using the expression flattener.
1863/// TODO: this does not account for aliasing of memrefs.
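/// For example, two affine.load ops on the same memref with access map
/// (d0) -> (d0) and the same IV operand compare equal, whereas access maps
/// (d0) -> (d0) and (d0) -> (d0 + 1) over the same operand do not.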
1864bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
1865 if (memref != rhs.memref)
1866 return false;
1867
1868 AffineValueMap diff, thisMap, rhsMap;
1869 getAccessMap(accessMap: &thisMap);
1870 rhs.getAccessMap(accessMap: &rhsMap);
1871 AffineValueMap::difference(a: thisMap, b: rhsMap, res: &diff);
1872 return llvm::all_of(Range: diff.getAffineMap().getResults(),
1873 P: [](AffineExpr e) { return e == 0; });
1874}
1875
1876void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
1877 auto *currOp = op.getParentOp();
1878 AffineForOp currAffineForOp;
1879 // Traverse up the hierarchy, collecting the IVs of all 'affine.for' and
1880 // 'affine.parallel' operations while skipping over 'affine.if' operations.
1881 while (currOp) {
1882 if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
1883 ivs.push_back(Elt: currAffineForOp.getInductionVar());
1884 else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
1885 llvm::append_range(ivs, parOp.getIVs());
1886 currOp = currOp->getParentOp();
1887 }
1888 std::reverse(first: ivs.begin(), last: ivs.end());
1889}
1890
1891/// Returns the number of surrounding loops common to both 'a' and 'b',
1892/// counting from the outermost loop inwards.
1893unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
1894 Operation &b) {
1895 SmallVector<Value, 4> loopsA, loopsB;
1896 getAffineIVs(op&: a, ivs&: loopsA);
1897 getAffineIVs(op&: b, ivs&: loopsB);
1898
1899 unsigned minNumLoops = std::min(a: loopsA.size(), b: loopsB.size());
1900 unsigned numCommonLoops = 0;
1901 for (unsigned i = 0; i < minNumLoops; ++i) {
1902 if (loopsA[i] != loopsB[i])
1903 break;
1904 ++numCommonLoops;
1905 }
1906 return numCommonLoops;
1907}
1908
1909static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
1910 Block::iterator start,
1911 Block::iterator end,
1912 int memorySpace) {
1913 SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;
1914
1915 // Walk the block range ['start', 'end') to gather all memory regions.
1916 auto result = block.walk(begin: start, end, callback: [&](Operation *opInst) -> WalkResult {
1917 if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
1918 // Neither a load nor a store op.
1919 return WalkResult::advance();
1920 }
1921
1922 // Compute the memref region symbolic in any IVs enclosing this block.
1923 auto region = std::make_unique<MemRefRegion>(args: opInst->getLoc());
1924 if (failed(
1925 result: region->compute(op: opInst,
1926 /*loopDepth=*/getNestingDepth(op: &*block.begin())))) {
1927 return opInst->emitError(message: "error obtaining memory region\n");
1928 }
1929
1930 auto it = regions.find(Val: region->memref);
1931 if (it == regions.end()) {
1932 regions[region->memref] = std::move(region);
1933 } else if (failed(result: it->second->unionBoundingBox(other: *region))) {
1934 return opInst->emitWarning(
1935 message: "getMemoryFootprintBytes: unable to perform a union on a memory "
1936 "region");
1937 }
1938 return WalkResult::advance();
1939 });
1940 if (result.wasInterrupted())
1941 return std::nullopt;
1942
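 // The footprint is the sum of the sizes of the unioned regions, one region
 // per memref accessed in the walked range.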
1943 int64_t totalSizeInBytes = 0;
1944 for (const auto &region : regions) {
1945 std::optional<int64_t> size = region.second->getRegionSize();
1946 if (!size.has_value())
1947 return std::nullopt;
1948 totalSizeInBytes += *size;
1949 }
1950 return totalSizeInBytes;
1951}
1952
1953std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
1954 int memorySpace) {
1955 auto *forInst = forOp.getOperation();
1956 return ::getMemoryFootprintBytes(
1957 block&: *forInst->getBlock(), start: Block::iterator(forInst),
1958 end: std::next(x: Block::iterator(forInst)), memorySpace);
1959}
1960
1961/// Returns whether the loop is a parallel loop that also performs a reduction.
1962bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
1963 SmallVector<LoopReduction> reductions;
1964 if (!isLoopParallel(forOp, &reductions))
1965 return false;
1966 return !reductions.empty();
1967}
1968
1969/// Returns in 'sequentialLoops' all sequential loops in the loop nest rooted
1970/// at 'forOp'.
1971void mlir::affine::getSequentialLoops(
1972 AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
1973 forOp->walk([&](Operation *op) {
1974 if (auto innerFor = dyn_cast<AffineForOp>(op))
1975 if (!isLoopParallel(innerFor))
1976 sequentialLoops->insert(innerFor.getInductionVar());
1977 });
1978}
1979
1980IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
1981 FlatAffineValueConstraints fac(set);
1982 if (fac.isEmpty())
1983 return IntegerSet::getEmptySet(numDims: set.getNumDims(), numSymbols: set.getNumSymbols(),
1984 context: set.getContext());
1985 fac.removeTrivialRedundancy();
1986
1987 auto simplifiedSet = fac.getAsIntegerSet(context: set.getContext());
1988 assert(simplifiedSet && "guaranteed to succeed while roundtripping");
1989 return simplifiedSet;
1990}
1991
1992static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
1993 SmallVector<Value> &target) {
1994 target =
1995 llvm::to_vector<4>(Range: llvm::map_range(C&: source, F: [](std::optional<Value> val) {
1996 return val.has_value() ? *val : Value();
1997 }));
1998}
1999
2000/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
2001/// constraints drawn from an affine map. Before adding the constraint, the
2002/// dimensions/symbols of the affine map are aligned with `constraints`.
2003/// `operands` are the SSA Value operands used with the affine map.
2004/// Note: This function adds a new symbol column to the `constraints` for each
2005/// dimension/symbol that exists in the affine map but not in `constraints`.
2006static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
2007 BoundType type, unsigned pos,
2008 AffineMap map, ValueRange operands) {
2009 SmallVector<Value> dims, syms, newSyms;
2010 unpackOptionalValues(source: constraints.getMaybeValues(kind: VarKind::SetDim), target&: dims);
2011 unpackOptionalValues(source: constraints.getMaybeValues(kind: VarKind::Symbol), target&: syms);
2012
2013 AffineMap alignedMap =
2014 alignAffineMapWithValues(map, operands, dims, syms, newSyms: &newSyms);
2015 for (unsigned i = syms.size(); i < newSyms.size(); ++i)
2016 constraints.appendSymbolVar(vals: newSyms[i]);
2017 return constraints.addBound(type, pos, boundMap: alignedMap);
2018}
2019
2020/// Add `val` to each result of `map`.
2021static AffineMap addConstToResults(AffineMap map, int64_t val) {
2022 SmallVector<AffineExpr> newResults;
2023 for (AffineExpr r : map.getResults())
2024 newResults.push_back(Elt: r + val);
2025 return AffineMap::get(dimCount: map.getNumDims(), symbolCount: map.getNumSymbols(), results: newResults,
2026 context: map.getContext());
2027}
2028
2029// Attempt to simplify the given min/max operation by proving that its value is
2030// bounded by the same lower and upper bound.
2031//
2032// Bounds are computed by FlatAffineValueConstraints. Invariants required for
2033// finding/proving bounds should be supplied via `constraints`.
2034//
2035// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
2036// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
2037// case of `!isMin`) and bind it to `opBound`. SSA values that are used in
2038// `op` but are not part of `constraints`, are added as extra symbols.
2039// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
2040// * If `isMin`: r_i >= opBound
2041// * If `!isMin`: r_i <= opBound
2042// If this is the case, ub(op) == lb(op).
2043// 4. Replace `op` with `opBound`.
2044//
2045// In summary, the following constraints are added throughout this function.
2046// Note: `invar` are dimensions added by the caller to express the invariants.
2047// (Showing only the case where `isMin`.)
2048//
2049// invar | op | opBound | r_i | extra syms... | const | eq/ineq
2050// ------+-------+---------+-----+---------------+-------+-------------------
2051// (various eq./ineq. constraining `invar`, added by the caller)
2052// ... | 0 | 0 | 0 | 0 | ... | ...
2053// ------+-------+---------+-----+---------------+-------+-------------------
2054// (various ineq. constraining `op` in terms of `op` operands (`invar` and
2055// extra `op` operands "extra syms" that are not in `invar`)).
2056// ... | -1 | 0 | 0 | ... | ... | >= 0
2057// ------+-------+---------+-----+---------------+-------+-------------------
2058// (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
2059// ... | 0 | -1 | 0 | ... | ... | = 0
2060// ------+-------+---------+-----+---------------+-------+-------------------
2061// (for each `op` map result r_i: set r_i to corresponding map result,
2062// prove that r_i >= minOpUb via contradiction)
2063// ... | 0 | 0 | -1 | ... | ... | = 0
2064// 0 | 0 | 1 | -1 | 0 | -1 | >= 0
2065//
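// For intuition, a hypothetical example: given constraints 0 <= %iv <= 112
// and %iv mod 16 == 0, the op `affine.min affine_map<(d0) -> (16, 128 - d0)>(%iv)`
// is bounded above by 16 while each of its results is provably >= 16, so the
// op can be replaced by the constant 16.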
2066FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
2067 Operation *op, FlatAffineValueConstraints constraints) {
2068 bool isMin = isa<AffineMinOp>(Val: op);
2069 assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
2070 MLIRContext *ctx = op->getContext();
2071 Builder builder(ctx);
2072 AffineMap map =
2073 isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
2074 ValueRange operands = op->getOperands();
2075 unsigned numResults = map.getNumResults();
2076
2077 // Add a few extra dimensions.
2078 unsigned dimOp = constraints.appendDimVar(); // `op`
2079 unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
2080 unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);
2081
2082 // Add an inequality for each result expr_i of map:
2083 // isMin: op <= expr_i, !isMin: op >= expr_i
2084 auto boundType = isMin ? BoundType::UB : BoundType::LB;
2085 // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
2086 AffineMap mapLbUb = isMin ? addConstToResults(map, val: 1) : map;
2087 if (failed(
2088 result: alignAndAddBound(constraints, type: boundType, pos: dimOp, map: mapLbUb, operands)))
2089 return failure();
2090
2091 // Try to compute a lower/upper bound for op, expressed in terms of the other
2092 // `dims` and extra symbols.
2093 SmallVector<AffineMap> opLb(1), opUb(1);
2094 constraints.getSliceBounds(offset: dimOp, num: 1, context: ctx, lbMaps: &opLb, ubMaps: &opUb);
2095 AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
2096 // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
2097 // a TODO of `getSliceBounds` and not handled here.
2098 if (!sliceBound || sliceBound.getNumResults() != 1)
2099 return failure(); // No or multiple bounds found.
2100 // Recover the inclusive UB in the case of an `affine.min`.
2101 AffineMap boundMap = isMin ? addConstToResults(map: sliceBound, val: -1) : sliceBound;
2102
2103 // Add an equality: Set dimOpBound to computed bound.
2104 // Add back dimension for op. (Was removed by `getSliceBounds`.)
2105 AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
2106 if (failed(result: constraints.addBound(type: BoundType::EQ, pos: dimOpBound, boundMap: alignedBoundMap)))
2107 return failure();
2108
2109 // If the constraint system is empty, there is an inconsistency. (E.g., this
2110 // can happen if loop lb > ub.)
2111 if (constraints.isEmpty())
2112 return failure();
2113
2114 // In the case of `isMin` (`!isMin` is inversed):
2115 // Prove that each result of `map` has a lower bound that is equal to (or
2116 // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
2117 // can be replaced with the bound. I.e., prove that for each result
2118 // expr_i (represented by dimension r_i):
2119 //
2120 // r_i >= opBound
2121 //
2122 // To prove this inequality, add its negation to the constraint set and prove
2123 // that the constraint set is empty.
2124 for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
2125 FlatAffineValueConstraints newConstr(constraints);
2126
2127 // Add an equality: r_i = expr_i
2128 // Note: These equalities could have been added earlier and used to express
2129 // `op` <= expr_i (for the `isMin` case). However, then we run the risk that
2130 // `getSliceBounds` computes the bound of `op` in terms of the r_i dims,
 // which is not desired.
2131 if (failed(result: alignAndAddBound(constraints&: newConstr, type: BoundType::EQ, pos: i,
2132 map: map.getSubMap(resultPos: {i - resultDimStart}), operands)))
2133 return failure();
2134
2135 // If `isMin`: Add inequality: r_i < opBound
2136 // equiv.: opBound - r_i - 1 >= 0
2137 // If `!isMin`: Add inequality: r_i > opBound
2138 // equiv.: -opBound + r_i - 1 >= 0
2139 SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
2140 ineq[dimOpBound] = isMin ? 1 : -1;
2141 ineq[i] = isMin ? -1 : 1;
2142 ineq[newConstr.getNumCols() - 1] = -1;
2143 newConstr.addInequality(inEq: ineq);
2144 if (!newConstr.isEmpty())
2145 return failure();
2146 }
2147
2148 // Lower and upper bound of `op` are equal. Replace `op` with its bound.
2149 AffineMap newMap = alignedBoundMap;
2150 SmallVector<Value> newOperands;
2151 unpackOptionalValues(source: constraints.getMaybeValues(), target&: newOperands);
2152 // If dims/symbols have known constant values, use those in order to simplify
2153 // the affine map further.
2154 for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
2155 // Skip unused operands and operands that are already constants.
2156 if (!newOperands[i] || getConstantIntValue(ofr: newOperands[i]))
2157 continue;
2158 if (auto bound = constraints.getConstantBound64(type: BoundType::EQ, pos: i)) {
2159 AffineExpr expr =
2160 i < newMap.getNumDims()
2161 ? builder.getAffineDimExpr(position: i)
2162 : builder.getAffineSymbolExpr(position: i - newMap.getNumDims());
2163 newMap = newMap.replace(expr, replacement: builder.getAffineConstantExpr(constant: *bound),
2164 numResultDims: newMap.getNumDims(), numResultSyms: newMap.getNumSymbols());
2165 }
2166 }
2167 affine::canonicalizeMapAndOperands(map: &newMap, operands: &newOperands);
2168 return AffineValueMap(newMap, newOperands);
2169}
2170
