1//===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/Presburger/PresburgerSpace.h"
10#include "llvm/Support/ErrorHandling.h"
11#include "llvm/Support/raw_ostream.h"
12#include <algorithm>
13#include <cassert>
14
15using namespace mlir;
16using namespace presburger;
17
18bool Identifier::isEqual(const Identifier &other) const {
19 if (value == nullptr || other.value == nullptr)
20 return false;
21 assert(value != other.value ||
22 (value == other.value && idType == other.idType &&
23 "Values of Identifiers are equal but their types do not match."));
24 return value == other.value;
25}
26
27void Identifier::print(llvm::raw_ostream &os) const {
28 os << "Id<" << value << ">";
29}
30
31void Identifier::dump() const {
32 print(os&: llvm::errs());
33 llvm::errs() << "\n";
34}
35
36PresburgerSpace PresburgerSpace::getDomainSpace() const {
37 PresburgerSpace newSpace = *this;
38 newSpace.removeVarRange(kind: VarKind::Range, varStart: 0, varLimit: getNumRangeVars());
39 newSpace.convertVarKind(srcKind: VarKind::Domain, srcPos: 0, num: getNumDomainVars(),
40 dstKind: VarKind::SetDim, dstPos: 0);
41 return newSpace;
42}
43
44PresburgerSpace PresburgerSpace::getRangeSpace() const {
45 PresburgerSpace newSpace = *this;
46 newSpace.removeVarRange(kind: VarKind::Domain, varStart: 0, varLimit: getNumDomainVars());
47 return newSpace;
48}
49
50PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
51 PresburgerSpace space = *this;
52 space.removeVarRange(kind: VarKind::Local, varStart: 0, varLimit: getNumLocalVars());
53 return space;
54}
55
56unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
57 if (kind == VarKind::Domain)
58 return getNumDomainVars();
59 if (kind == VarKind::Range)
60 return getNumRangeVars();
61 if (kind == VarKind::Symbol)
62 return getNumSymbolVars();
63 if (kind == VarKind::Local)
64 return getNumLocalVars();
65 llvm_unreachable("VarKind does not exist!");
66}
67
68unsigned PresburgerSpace::getVarKindOffset(VarKind kind) const {
69 if (kind == VarKind::Domain)
70 return 0;
71 if (kind == VarKind::Range)
72 return getNumDomainVars();
73 if (kind == VarKind::Symbol)
74 return getNumDimVars();
75 if (kind == VarKind::Local)
76 return getNumDimAndSymbolVars();
77 llvm_unreachable("VarKind does not exist!");
78}
79
80unsigned PresburgerSpace::getVarKindEnd(VarKind kind) const {
81 return getVarKindOffset(kind) + getNumVarKind(kind);
82}
83
84unsigned PresburgerSpace::getVarKindOverlap(VarKind kind, unsigned varStart,
85 unsigned varLimit) const {
86 unsigned varRangeStart = getVarKindOffset(kind);
87 unsigned varRangeEnd = getVarKindEnd(kind);
88
89 // Compute number of elements in intersection of the ranges [varStart,
90 // varLimit) and [varRangeStart, varRangeEnd).
91 unsigned overlapStart = std::max(a: varStart, b: varRangeStart);
92 unsigned overlapEnd = std::min(a: varLimit, b: varRangeEnd);
93
94 if (overlapStart > overlapEnd)
95 return 0;
96 return overlapEnd - overlapStart;
97}
98
99VarKind PresburgerSpace::getVarKindAt(unsigned pos) const {
100 assert(pos < getNumVars() && "`pos` should represent a valid var position");
101 if (pos < getVarKindEnd(kind: VarKind::Domain))
102 return VarKind::Domain;
103 if (pos < getVarKindEnd(kind: VarKind::Range))
104 return VarKind::Range;
105 if (pos < getVarKindEnd(kind: VarKind::Symbol))
106 return VarKind::Symbol;
107 if (pos < getVarKindEnd(kind: VarKind::Local))
108 return VarKind::Local;
109 llvm_unreachable("`pos` should represent a valid var position");
110}
111
112unsigned PresburgerSpace::insertVar(VarKind kind, unsigned pos, unsigned num) {
113 assert(pos <= getNumVarKind(kind));
114
115 unsigned absolutePos = getVarKindOffset(kind) + pos;
116
117 if (kind == VarKind::Domain)
118 numDomain += num;
119 else if (kind == VarKind::Range)
120 numRange += num;
121 else if (kind == VarKind::Symbol)
122 numSymbols += num;
123 else
124 numLocals += num;
125
126 // Insert NULL identifiers if `usingIds` and variables inserted are
127 // not locals.
128 if (usingIds && kind != VarKind::Local)
129 identifiers.insert(I: identifiers.begin() + absolutePos, NumToInsert: num, Elt: Identifier());
130
131 return absolutePos;
132}
133
134void PresburgerSpace::removeVarRange(VarKind kind, unsigned varStart,
135 unsigned varLimit) {
136 assert(varLimit <= getNumVarKind(kind) && "invalid var limit");
137
138 if (varStart >= varLimit)
139 return;
140
141 unsigned numVarsEliminated = varLimit - varStart;
142 if (kind == VarKind::Domain)
143 numDomain -= numVarsEliminated;
144 else if (kind == VarKind::Range)
145 numRange -= numVarsEliminated;
146 else if (kind == VarKind::Symbol)
147 numSymbols -= numVarsEliminated;
148 else
149 numLocals -= numVarsEliminated;
150
151 // Remove identifiers if `usingIds` and variables removed are not
152 // locals.
153 if (usingIds && kind != VarKind::Local)
154 identifiers.erase(CS: identifiers.begin() + getVarKindOffset(kind) + varStart,
155 CE: identifiers.begin() + getVarKindOffset(kind) + varLimit);
156}
157
158void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
159 unsigned num, VarKind dstKind,
160 unsigned dstPos) {
161 assert(srcKind != dstKind && "cannot convert variables to the same kind");
162 assert(srcPos + num <= getNumVarKind(srcKind) &&
163 "invalid range for source variables");
164 assert(dstPos <= getNumVarKind(dstKind) &&
165 "invalid position for destination variables");
166
167 // Move identifiers if `usingIds` and variables moved are not locals.
168 unsigned srcOffset = getVarKindOffset(kind: srcKind) + srcPos;
169 unsigned dstOffset = getVarKindOffset(kind: dstKind) + dstPos;
170 if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
171 identifiers.insert(I: identifiers.begin() + dstOffset, NumToInsert: num, Elt: Identifier());
172 // Update srcOffset if insertion of new elements invalidates it.
173 if (dstOffset < srcOffset)
174 srcOffset += num;
175 std::move(first: identifiers.begin() + srcOffset,
176 last: identifiers.begin() + srcOffset + num,
177 result: identifiers.begin() + dstOffset);
178 identifiers.erase(CS: identifiers.begin() + srcOffset,
179 CE: identifiers.begin() + srcOffset + num);
180 } else if (isUsingIds() && srcKind != VarKind::Local) {
181 identifiers.erase(CS: identifiers.begin() + srcOffset,
182 CE: identifiers.begin() + srcOffset + num);
183 } else if (isUsingIds() && dstKind != VarKind::Local) {
184 identifiers.insert(I: identifiers.begin() + dstOffset, NumToInsert: num, Elt: Identifier());
185 }
186
187 auto addVars = [&](VarKind kind, int num) {
188 switch (kind) {
189 case VarKind::Domain:
190 numDomain += num;
191 break;
192 case VarKind::Range:
193 numRange += num;
194 break;
195 case VarKind::Symbol:
196 numSymbols += num;
197 break;
198 case VarKind::Local:
199 numLocals += num;
200 break;
201 }
202 };
203
204 addVars(srcKind, -(signed)num);
205 addVars(dstKind, num);
206}
207
208void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,
209 unsigned posB) {
210 if (!isUsingIds())
211 return;
212
213 if (kindA == VarKind::Local && kindB == VarKind::Local)
214 return;
215
216 if (kindA == VarKind::Local) {
217 setId(kind: kindB, pos: posB, id: Identifier());
218 return;
219 }
220
221 if (kindB == VarKind::Local) {
222 setId(kind: kindA, pos: posA, id: Identifier());
223 return;
224 }
225
226 std::swap(a&: identifiers[getVarKindOffset(kind: kindA) + posA],
227 b&: identifiers[getVarKindOffset(kind: kindB) + posB]);
228}
229
230bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
231 return getNumDomainVars() == other.getNumDomainVars() &&
232 getNumRangeVars() == other.getNumRangeVars() &&
233 getNumSymbolVars() == other.getNumSymbolVars();
234}
235
236bool PresburgerSpace::isEqual(const PresburgerSpace &other) const {
237 return isCompatible(other) && getNumLocalVars() == other.getNumLocalVars();
238}
239
240/// Checks if the number of ids of the given kind in the two spaces are
241/// equal and if the ids are equal. Assumes that both spaces are using
242/// ids.
243static bool areIdsEqual(const PresburgerSpace &spaceA,
244 const PresburgerSpace &spaceB, VarKind kind) {
245 assert(spaceA.isUsingIds() && spaceB.isUsingIds() &&
246 "Both spaces should be using ids");
247 if (spaceA.getNumVarKind(kind) != spaceB.getNumVarKind(kind))
248 return false;
249 if (kind == VarKind::Local)
250 return true; // No ids.
251 return spaceA.getIds(kind) == spaceB.getIds(kind);
252}
253
254bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
255 // If only one of the spaces is using identifiers, then they are
256 // not aligned.
257 if (isUsingIds() != other.isUsingIds())
258 return false;
259 // If both spaces are using identifiers, then they are aligned if
260 // their identifiers are equal. Identifiers being equal implies
261 // that the number of variables of each kind is same, which implies
262 // compatiblity, so we do not check for that.
263 if (isUsingIds())
264 return areIdsEqual(spaceA: *this, spaceB: other, kind: VarKind::Domain) &&
265 areIdsEqual(spaceA: *this, spaceB: other, kind: VarKind::Range) &&
266 areIdsEqual(spaceA: *this, spaceB: other, kind: VarKind::Symbol);
267 // If neither space is using identifiers, then they are aligned if
268 // they are compatible.
269 return isCompatible(other);
270}
271
272bool PresburgerSpace::isAligned(const PresburgerSpace &other,
273 VarKind kind) const {
274 // If only one of the spaces is using identifiers, then they are
275 // not aligned.
276 if (isUsingIds() != other.isUsingIds())
277 return false;
278 // If both spaces are using identifiers, then they are aligned if
279 // their identifiers are equal. Identifiers being equal implies
280 // that the number of variables of each kind is same, which implies
281 // compatiblity, so we do not check for that
282 if (isUsingIds())
283 return areIdsEqual(spaceA: *this, spaceB: other, kind);
284 // If neither space is using identifiers, then they are aligned if
285 // the number of variable kind is equal.
286 return getNumVarKind(kind) == other.getNumVarKind(kind);
287}
288
289void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
290 assert(newSymbolCount <= getNumDimAndSymbolVars() &&
291 "invalid separation position");
292 numRange = numRange + numSymbols - newSymbolCount;
293 numSymbols = newSymbolCount;
294 // We do not need to change `identifiers` since the ordering of
295 // `identifiers` remains same.
296}
297
298void PresburgerSpace::mergeAndAlignSymbols(PresburgerSpace &other) {
299 assert(usingIds && other.usingIds &&
300 "Both spaces need to have identifers to merge & align");
301
302 // First merge & align identifiers into `other` from `this`.
303 unsigned i = 0;
304 for (const Identifier identifier : getIds(kind: VarKind::Symbol)) {
305 // If the identifier exists in `other`, then align it; otherwise insert it
306 // assuming it is a new identifier. Search in `other` starting at position
307 // `i` since the left of `i` is aligned.
308 auto *findBegin = other.getIds(kind: VarKind::Symbol).begin() + i;
309 auto *findEnd = other.getIds(kind: VarKind::Symbol).end();
310 auto *itr = std::find(first: findBegin, last: findEnd, val: identifier);
311 if (itr != findEnd) {
312 std::swap(a&: findBegin, b&: itr);
313 } else {
314 other.insertVar(kind: VarKind::Symbol, pos: i);
315 other.setId(kind: VarKind::Symbol, pos: i, id: identifier);
316 }
317 ++i;
318 }
319
320 // Finally add identifiers that are in `other`, but not in `this` to `this`.
321 for (unsigned e = other.getNumVarKind(kind: VarKind::Symbol); i < e; ++i) {
322 insertVar(kind: VarKind::Symbol, pos: i);
323 setId(kind: VarKind::Symbol, pos: i, id: other.getId(kind: VarKind::Symbol, pos: i));
324 }
325}
326
327void PresburgerSpace::print(llvm::raw_ostream &os) const {
328 os << "Domain: " << getNumDomainVars() << ", "
329 << "Range: " << getNumRangeVars() << ", "
330 << "Symbols: " << getNumSymbolVars() << ", "
331 << "Locals: " << getNumLocalVars() << "\n";
332
333 if (isUsingIds()) {
334 auto printIds = [&](VarKind kind) {
335 os << " ";
336 for (Identifier id : getIds(kind)) {
337 if (id.hasValue())
338 id.print(os);
339 else
340 os << "None";
341 os << " ";
342 }
343 };
344
345 os << "(";
346 printIds(VarKind::Domain);
347 os << ") -> (";
348 printIds(VarKind::Range);
349 os << ") : [";
350 printIds(VarKind::Symbol);
351 os << "]";
352 }
353}
354
355void PresburgerSpace::dump() const {
356 print(os&: llvm::errs());
357 llvm::errs() << "\n";
358}
359

// Source code of mlir/lib/Analysis/Presburger/PresburgerSpace.cpp