1 | //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Analysis/Presburger/PWMAFunction.h" |
10 | #include "mlir/Analysis/Presburger/IntegerRelation.h" |
11 | #include "mlir/Analysis/Presburger/MPInt.h" |
12 | #include "mlir/Analysis/Presburger/PresburgerRelation.h" |
13 | #include "mlir/Analysis/Presburger/PresburgerSpace.h" |
14 | #include "mlir/Analysis/Presburger/Utils.h" |
15 | #include "mlir/Support/LLVM.h" |
16 | #include "llvm/ADT/STLExtras.h" |
17 | #include "llvm/ADT/STLFunctionalExtras.h" |
18 | #include "llvm/ADT/SmallVector.h" |
19 | #include "llvm/Support/raw_ostream.h" |
20 | #include <algorithm> |
21 | #include <cassert> |
22 | #include <optional> |
23 | |
24 | using namespace mlir; |
25 | using namespace presburger; |
26 | |
27 | void MultiAffineFunction::assertIsConsistent() const { |
28 | assert(space.getNumVars() - space.getNumRangeVars() + 1 == |
29 | output.getNumColumns() && |
30 | "Inconsistent number of output columns" ); |
31 | assert(space.getNumDomainVars() + space.getNumSymbolVars() == |
32 | divs.getNumNonDivs() && |
33 | "Inconsistent number of non-division variables in divs" ); |
34 | assert(space.getNumRangeVars() == output.getNumRows() && |
35 | "Inconsistent number of output rows" ); |
36 | assert(space.getNumLocalVars() == divs.getNumDivs() && |
37 | "Inconsistent number of divisions." ); |
38 | assert(divs.hasAllReprs() && "All divisions should have a representation" ); |
39 | } |
40 | |
41 | // Return the result of subtracting the two given vectors pointwise. |
42 | // The vectors must be of the same size. |
43 | // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. |
44 | static SmallVector<MPInt, 8> subtractExprs(ArrayRef<MPInt> vecA, |
45 | ArrayRef<MPInt> vecB) { |
46 | assert(vecA.size() == vecB.size() && |
47 | "Cannot subtract vectors of differing lengths!" ); |
48 | SmallVector<MPInt, 8> result; |
49 | result.reserve(N: vecA.size()); |
50 | for (unsigned i = 0, e = vecA.size(); i < e; ++i) |
51 | result.push_back(Elt: vecA[i] - vecB[i]); |
52 | return result; |
53 | } |
54 | |
55 | PresburgerSet PWMAFunction::getDomain() const { |
56 | PresburgerSet domain = PresburgerSet::getEmpty(space: getDomainSpace()); |
57 | for (const Piece &piece : pieces) |
58 | domain.unionInPlace(set: piece.domain); |
59 | return domain; |
60 | } |
61 | |
62 | void MultiAffineFunction::print(raw_ostream &os) const { |
63 | space.print(os); |
64 | os << "Division Representation:\n" ; |
65 | divs.print(os); |
66 | os << "Output:\n" ; |
67 | output.print(os); |
68 | } |
69 | |
70 | SmallVector<MPInt, 8> |
71 | MultiAffineFunction::valueAt(ArrayRef<MPInt> point) const { |
72 | assert(point.size() == getNumDomainVars() + getNumSymbolVars() && |
73 | "Point has incorrect dimensionality!" ); |
74 | |
75 | SmallVector<MPInt, 8> pointHomogenous{llvm::to_vector(Range&: point)}; |
76 | // Get the division values at this point. |
77 | SmallVector<std::optional<MPInt>, 8> divValues = divs.divValuesAt(point); |
78 | // The given point didn't include the values of the divs which the output is a |
79 | // function of; we have computed one possible set of values and use them here. |
80 | pointHomogenous.reserve(N: pointHomogenous.size() + divValues.size()); |
81 | for (const std::optional<MPInt> &divVal : divValues) |
82 | pointHomogenous.push_back(Elt: *divVal); |
83 | // The matrix `output` has an affine expression in the ith row, corresponding |
84 | // to the expression for the ith value in the output vector. The last column |
85 | // of the matrix contains the constant term. Let v be the input point with |
86 | // a 1 appended at the end. We can see that output * v gives the desired |
87 | // output vector. |
88 | pointHomogenous.emplace_back(Args: 1); |
89 | SmallVector<MPInt, 8> result = output.postMultiplyWithColumn(colVec: pointHomogenous); |
90 | assert(result.size() == getNumOutputs()); |
91 | return result; |
92 | } |
93 | |
94 | bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { |
95 | assert(space.isCompatible(other.space) && |
96 | "Spaces should be compatible for equality check." ); |
97 | return getAsRelation().isEqual(other: other.getAsRelation()); |
98 | } |
99 | |
100 | bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, |
101 | const IntegerPolyhedron &domain) const { |
102 | assert(space.isCompatible(other.space) && |
103 | "Spaces should be compatible for equality check." ); |
104 | IntegerRelation restrictedThis = getAsRelation(); |
105 | restrictedThis.intersectDomain(poly: domain); |
106 | |
107 | IntegerRelation restrictedOther = other.getAsRelation(); |
108 | restrictedOther.intersectDomain(poly: domain); |
109 | |
110 | return restrictedThis.isEqual(other: restrictedOther); |
111 | } |
112 | |
113 | bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, |
114 | const PresburgerSet &domain) const { |
115 | assert(space.isCompatible(other.space) && |
116 | "Spaces should be compatible for equality check." ); |
117 | return llvm::all_of(Range: domain.getAllDisjuncts(), |
118 | P: [&](const IntegerRelation &disjunct) { |
119 | return isEqual(other, domain: IntegerPolyhedron(disjunct)); |
120 | }); |
121 | } |
122 | |
123 | void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) { |
124 | assert(end <= getNumOutputs() && "Invalid range" ); |
125 | |
126 | if (start >= end) |
127 | return; |
128 | |
129 | space.removeVarRange(kind: VarKind::Range, varStart: start, varLimit: end); |
130 | output.removeRows(pos: start, count: end - start); |
131 | } |
132 | |
133 | void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { |
134 | assert(space.isCompatible(other.space) && "Functions should be compatible" ); |
135 | |
136 | unsigned nDivs = getNumDivs(); |
137 | unsigned divOffset = divs.getDivOffset(); |
138 | |
139 | other.divs.insertDiv(pos: 0, num: nDivs); |
140 | |
141 | SmallVector<MPInt, 8> div(other.divs.getNumVars() + 1); |
142 | for (unsigned i = 0; i < nDivs; ++i) { |
143 | // Zero fill. |
144 | std::fill(first: div.begin(), last: div.end(), value: 0); |
145 | // Fill div with dividend from `divs`. Do not fill the constant. |
146 | std::copy(first: divs.getDividend(i).begin(), last: divs.getDividend(i).end() - 1, |
147 | result: div.begin()); |
148 | // Fill constant. |
149 | div.back() = divs.getDividend(i).back(); |
150 | other.divs.setDiv(i, dividend: div, divisor: divs.getDenom(i)); |
151 | } |
152 | |
153 | other.space.insertVar(kind: VarKind::Local, pos: 0, num: nDivs); |
154 | other.output.insertColumns(pos: divOffset, count: nDivs); |
155 | |
156 | auto merge = [&](unsigned i, unsigned j) { |
157 | // We only merge from local at pos j to local at pos i, where j > i. |
158 | if (i >= j) |
159 | return false; |
160 | |
161 | // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we |
162 | // do not want to merge duplicates in `this`, we ignore this call. |
163 | if (j < nDivs) |
164 | return false; |
165 | |
166 | // Merge things in space and output. |
167 | other.space.removeVarRange(kind: VarKind::Local, varStart: j, varLimit: j + 1); |
168 | other.output.addToColumn(sourceColumn: divOffset + i, targetColumn: divOffset + j, scale: 1); |
169 | other.output.removeColumn(pos: divOffset + j); |
170 | return true; |
171 | }; |
172 | |
173 | other.divs.removeDuplicateDivs(merge); |
174 | |
175 | unsigned newDivs = other.divs.getNumDivs() - nDivs; |
176 | |
177 | space.insertVar(kind: VarKind::Local, pos: nDivs, num: newDivs); |
178 | output.insertColumns(pos: divOffset + nDivs, count: newDivs); |
179 | divs = other.divs; |
180 | |
181 | // Check consistency. |
182 | assertIsConsistent(); |
183 | other.assertIsConsistent(); |
184 | } |
185 | |
186 | PresburgerSet |
187 | MultiAffineFunction::getLexSet(OrderingKind comp, |
188 | const MultiAffineFunction &other) const { |
189 | assert(getSpace().isCompatible(other.getSpace()) && |
190 | "Output space of funcs should be compatible" ); |
191 | |
192 | // Create copies of functions and merge their local space. |
193 | MultiAffineFunction funcA = *this; |
194 | MultiAffineFunction funcB = other; |
195 | funcA.mergeDivs(other&: funcB); |
196 | |
197 | // We first create the set `result`, corresponding to the set where output |
198 | // of funcA is lexicographically larger/smaller than funcB. This is done by |
199 | // creating a PresburgerSet with the following constraints: |
200 | // |
201 | // (outA[0] > outB[0]) U |
202 | // (outA[0] = outB[0], outA[1] > outA[1]) U |
203 | // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U |
204 | // ... |
205 | // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) |
206 | // |
207 | // where `n` is the number of outputs. |
208 | // If `lexMin` is set, the complement inequality is used: |
209 | // |
210 | // (outA[0] < outB[0]) U |
211 | // (outA[0] = outB[0], outA[1] < outA[1]) U |
212 | // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U |
213 | // ... |
214 | // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) |
215 | PresburgerSpace resultSpace = funcA.getDomainSpace(); |
216 | PresburgerSet result = |
217 | PresburgerSet::getEmpty(space: resultSpace.getSpaceWithoutLocals()); |
218 | IntegerPolyhedron levelSet( |
219 | /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(), |
220 | /*numReservedEqualities=*/funcA.getNumOutputs(), |
221 | /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace); |
222 | |
223 | // Add division inequalities to `levelSet`. |
224 | for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) { |
225 | levelSet.addInequality(inEq: getDivUpperBound(dividend: funcA.divs.getDividend(i), |
226 | divisor: funcA.divs.getDenom(i), |
227 | localVarIdx: funcA.divs.getDivOffset() + i)); |
228 | levelSet.addInequality(inEq: getDivLowerBound(dividend: funcA.divs.getDividend(i), |
229 | divisor: funcA.divs.getDenom(i), |
230 | localVarIdx: funcA.divs.getDivOffset() + i)); |
231 | } |
232 | |
233 | for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) { |
234 | // Create the expression `outA - outB` for this level. |
235 | SmallVector<MPInt, 8> subExpr = |
236 | subtractExprs(vecA: funcA.getOutputExpr(i: level), vecB: funcB.getOutputExpr(i: level)); |
237 | |
238 | // TODO: Implement all comparison cases. |
239 | switch (comp) { |
240 | case OrderingKind::LT: |
241 | // For less than, we add an upper bound of -1: |
242 | // outA - outB <= -1 |
243 | // outA <= outB - 1 |
244 | // outA < outB |
245 | levelSet.addBound(type: BoundType::UB, expr: subExpr, value: MPInt(-1)); |
246 | break; |
247 | case OrderingKind::GT: |
248 | // For greater than, we add a lower bound of 1: |
249 | // outA - outB >= 1 |
250 | // outA > outB + 1 |
251 | // outA > outB |
252 | levelSet.addBound(type: BoundType::LB, expr: subExpr, value: MPInt(1)); |
253 | break; |
254 | case OrderingKind::GE: |
255 | case OrderingKind::LE: |
256 | case OrderingKind::EQ: |
257 | case OrderingKind::NE: |
258 | assert(false && "Not implemented case" ); |
259 | } |
260 | |
261 | // Union the set with the result. |
262 | result.unionInPlace(disjunct: levelSet); |
263 | // The last inequality in `levelSet` is the bound we inserted. We remove |
264 | // that for next iteration. |
265 | levelSet.removeInequality(pos: levelSet.getNumInequalities() - 1); |
266 | // Add equality `outA - outB == 0` for this level for next iteration. |
267 | levelSet.addEquality(eq: subExpr); |
268 | } |
269 | |
270 | return result; |
271 | } |
272 | |
273 | /// Two PWMAFunctions are equal if they have the same dimensionalities, |
274 | /// the same domain, and take the same value at every point in the domain. |
275 | bool PWMAFunction::isEqual(const PWMAFunction &other) const { |
276 | if (!space.isCompatible(other: other.space)) |
277 | return false; |
278 | |
279 | if (!this->getDomain().isEqual(set: other.getDomain())) |
280 | return false; |
281 | |
282 | // Check if, whenever the domains of a piece of `this` and a piece of `other` |
283 | // overlap, they take the same output value. If `this` and `other` have the |
284 | // same domain (checked above), then this check passes iff the two functions |
285 | // have the same output at every point in the domain. |
286 | return llvm::all_of(Range: this->pieces, P: [&other](const Piece &pieceA) { |
287 | return llvm::all_of(Range: other.pieces, P: [&pieceA](const Piece &pieceB) { |
288 | PresburgerSet commonDomain = pieceA.domain.intersect(set: pieceB.domain); |
289 | return pieceA.output.isEqual(other: pieceB.output, domain: commonDomain); |
290 | }); |
291 | }); |
292 | } |
293 | |
294 | void PWMAFunction::addPiece(const Piece &piece) { |
295 | assert(piece.isConsistent() && "Piece should be consistent" ); |
296 | assert(piece.domain.intersect(getDomain()).isIntegerEmpty() && |
297 | "Piece should be disjoint from the function" ); |
298 | pieces.push_back(Elt: piece); |
299 | } |
300 | |
301 | void PWMAFunction::print(raw_ostream &os) const { |
302 | space.print(os); |
303 | os << getNumPieces() << " pieces:\n" ; |
304 | for (const Piece &piece : pieces) { |
305 | os << "Domain of piece:\n" ; |
306 | piece.domain.print(os); |
307 | os << "Output of piece\n" ; |
308 | piece.output.print(os); |
309 | } |
310 | } |
311 | |
312 | void PWMAFunction::dump() const { print(os&: llvm::errs()); } |
313 | |
314 | PWMAFunction PWMAFunction::unionFunction( |
315 | const PWMAFunction &func, |
316 | llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const { |
317 | assert(getNumOutputs() == func.getNumOutputs() && |
318 | "Ranges of functions should be same." ); |
319 | assert(getSpace().isCompatible(func.getSpace()) && |
320 | "Space is not compatible." ); |
321 | |
322 | // The algorithm used here is as follows: |
323 | // - Add the output of pieceB for the part of the domain where both pieceA and |
324 | // pieceB are defined, and `tiebreak` chooses the output of pieceB. |
325 | // - Add the output of pieceA, where pieceB is not defined or `tiebreak` |
326 | // chooses |
327 | // pieceA over pieceB. |
328 | // - Add the output of pieceB, where pieceA is not defined. |
329 | |
330 | // Add parts of the common domain where pieceB's output is used. Also |
331 | // add all the parts where pieceA's output is used, both common and |
332 | // non-common. |
333 | PWMAFunction result(getSpace()); |
334 | for (const Piece &pieceA : pieces) { |
335 | PresburgerSet dom(pieceA.domain); |
336 | for (const Piece &pieceB : func.pieces) { |
337 | PresburgerSet better = tiebreak(pieceB, pieceA); |
338 | // Add the output of pieceB, where it is better than output of pieceA. |
339 | // The disjuncts in "better" will be disjoint as tiebreak should gurantee |
340 | // that. |
341 | result.addPiece(piece: {.domain: better, .output: pieceB.output}); |
342 | dom = dom.subtract(set: better); |
343 | } |
344 | // Add output of pieceA, where it is better than pieceB, or pieceB is not |
345 | // defined. |
346 | // |
347 | // `dom` here is guranteed to be disjoint from already added pieces |
348 | // because the pieces added before are either: |
349 | // - Subsets of the domain of other MAFs in `this`, which are guranteed |
350 | // to be disjoint from `dom`, or |
351 | // - They are one of the pieces added for `pieceB`, and we have been |
352 | // subtracting all such pieces from `dom`, so `dom` is disjoint from those |
353 | // pieces as well. |
354 | result.addPiece(piece: {.domain: dom, .output: pieceA.output}); |
355 | } |
356 | |
357 | // Add parts of pieceB which are not shared with pieceA. |
358 | PresburgerSet dom = getDomain(); |
359 | for (const Piece &pieceB : func.pieces) |
360 | result.addPiece(piece: {.domain: pieceB.domain.subtract(set: dom), .output: pieceB.output}); |
361 | |
362 | return result; |
363 | } |
364 | |
365 | /// A tiebreak function which breaks ties by comparing the outputs |
366 | /// lexicographically based on the given comparison operator. |
367 | /// This is templated since it is passed as a lambda. |
368 | template <OrderingKind comp> |
369 | static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, |
370 | const PWMAFunction::Piece &pieceB) { |
371 | PresburgerSet result = pieceA.output.getLexSet(comp, other: pieceB.output); |
372 | result = result.intersect(set: pieceA.domain).intersect(set: pieceB.domain); |
373 | |
374 | return result; |
375 | } |
376 | |
377 | PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { |
378 | return unionFunction(func, tiebreak: tiebreakLex</*comp=*/OrderingKind::LT>); |
379 | } |
380 | |
381 | PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { |
382 | return unionFunction(func, tiebreak: tiebreakLex</*comp=*/OrderingKind::GT>); |
383 | } |
384 | |
385 | void MultiAffineFunction::subtract(const MultiAffineFunction &other) { |
386 | assert(space.isCompatible(other.space) && |
387 | "Spaces should be compatible for subtraction." ); |
388 | |
389 | MultiAffineFunction copyOther = other; |
390 | mergeDivs(other&: copyOther); |
391 | for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) |
392 | output.addToRow(row: i, rowVec: copyOther.getOutputExpr(i), scale: MPInt(-1)); |
393 | |
394 | // Check consistency. |
395 | assertIsConsistent(); |
396 | } |
397 | |
398 | /// Adds division constraints corresponding to local variables, given a |
399 | /// relation and division representations of the local variables in the |
400 | /// relation. |
401 | static void addDivisionConstraints(IntegerRelation &rel, |
402 | const DivisionRepr &divs) { |
403 | assert(divs.hasAllReprs() && |
404 | "All divisions in divs should have a representation" ); |
405 | assert(rel.getNumVars() == divs.getNumVars() && |
406 | "Relation and divs should have the same number of vars" ); |
407 | assert(rel.getNumLocalVars() == divs.getNumDivs() && |
408 | "Relation and divs should have the same number of local vars" ); |
409 | |
410 | for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) { |
411 | rel.addInequality(inEq: getDivUpperBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i), |
412 | localVarIdx: divs.getDivOffset() + i)); |
413 | rel.addInequality(inEq: getDivLowerBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i), |
414 | localVarIdx: divs.getDivOffset() + i)); |
415 | } |
416 | } |
417 | |
418 | IntegerRelation MultiAffineFunction::getAsRelation() const { |
419 | // Create a relation corressponding to the input space plus the divisions |
420 | // used in outputs. |
421 | IntegerRelation result(PresburgerSpace::getRelationSpace( |
422 | numDomain: space.getNumDomainVars(), numRange: 0, numSymbols: space.getNumSymbolVars(), |
423 | numLocals: space.getNumLocalVars())); |
424 | // Add division constraints corresponding to divisions used in outputs. |
425 | addDivisionConstraints(rel&: result, divs); |
426 | // The outputs are represented as range variables in the relation. We add |
427 | // range variables for the outputs. |
428 | result.insertVar(kind: VarKind::Range, pos: 0, num: getNumOutputs()); |
429 | |
430 | // Add equalities such that the i^th range variable is equal to the i^th |
431 | // output expression. |
432 | SmallVector<MPInt, 8> eq(result.getNumCols()); |
433 | for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { |
434 | // TODO: Add functions to get VarKind offsets in output in MAF and use them |
435 | // here. |
436 | // The output expression does not contain range variables, while the |
437 | // equality does. So, we need to copy all variables and mark all range |
438 | // variables as 0 in the equality. |
439 | ArrayRef<MPInt> expr = getOutputExpr(i); |
440 | // Copy domain variables in `expr` to domain variables in `eq`. |
441 | std::copy(first: expr.begin(), last: expr.begin() + getNumDomainVars(), result: eq.begin()); |
442 | // Fill the range variables in `eq` as zero. |
443 | std::fill(first: eq.begin() + result.getVarKindOffset(kind: VarKind::Range), |
444 | last: eq.begin() + result.getVarKindEnd(kind: VarKind::Range), value: 0); |
445 | // Copy remaining variables in `expr` to the remaining variables in `eq`. |
446 | std::copy(first: expr.begin() + getNumDomainVars(), last: expr.end(), |
447 | result: eq.begin() + result.getVarKindEnd(kind: VarKind::Range)); |
448 | |
449 | // Set the i^th range var to -1 in `eq` to equate the output expression to |
450 | // this range var. |
451 | eq[result.getVarKindOffset(kind: VarKind::Range) + i] = -1; |
452 | // Add the equality `rangeVar_i = output[i]`. |
453 | result.addEquality(eq); |
454 | } |
455 | |
456 | return result; |
457 | } |
458 | |
459 | void PWMAFunction::removeOutputs(unsigned start, unsigned end) { |
460 | space.removeVarRange(kind: VarKind::Range, varStart: start, varLimit: end); |
461 | for (Piece &piece : pieces) |
462 | piece.output.removeOutputs(start, end); |
463 | } |
464 | |
465 | std::optional<SmallVector<MPInt, 8>> |
466 | PWMAFunction::valueAt(ArrayRef<MPInt> point) const { |
467 | assert(point.size() == getNumDomainVars() + getNumSymbolVars()); |
468 | |
469 | for (const Piece &piece : pieces) |
470 | if (piece.domain.containsPoint(point)) |
471 | return piece.output.valueAt(point); |
472 | return std::nullopt; |
473 | } |
474 | |