1 | //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <optional>
21 | |
22 | using namespace mlir; |
23 | using namespace presburger; |
24 | |
25 | void MultiAffineFunction::assertIsConsistent() const { |
26 | assert(space.getNumVars() - space.getNumRangeVars() + 1 == |
27 | output.getNumColumns() && |
28 | "Inconsistent number of output columns" ); |
29 | assert(space.getNumDomainVars() + space.getNumSymbolVars() == |
30 | divs.getNumNonDivs() && |
31 | "Inconsistent number of non-division variables in divs" ); |
32 | assert(space.getNumRangeVars() == output.getNumRows() && |
33 | "Inconsistent number of output rows" ); |
34 | assert(space.getNumLocalVars() == divs.getNumDivs() && |
35 | "Inconsistent number of divisions." ); |
36 | assert(divs.hasAllReprs() && "All divisions should have a representation" ); |
37 | } |
38 | |
39 | // Return the result of subtracting the two given vectors pointwise. |
40 | // The vectors must be of the same size. |
41 | // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. |
42 | static SmallVector<DynamicAPInt, 8> subtractExprs(ArrayRef<DynamicAPInt> vecA, |
43 | ArrayRef<DynamicAPInt> vecB) { |
44 | assert(vecA.size() == vecB.size() && |
45 | "Cannot subtract vectors of differing lengths!" ); |
46 | SmallVector<DynamicAPInt, 8> result; |
47 | result.reserve(N: vecA.size()); |
48 | for (unsigned i = 0, e = vecA.size(); i < e; ++i) |
49 | result.emplace_back(Args: vecA[i] - vecB[i]); |
50 | return result; |
51 | } |
52 | |
53 | PresburgerSet PWMAFunction::getDomain() const { |
54 | PresburgerSet domain = PresburgerSet::getEmpty(space: getDomainSpace()); |
55 | for (const Piece &piece : pieces) |
56 | domain.unionInPlace(set: piece.domain); |
57 | return domain; |
58 | } |
59 | |
60 | void MultiAffineFunction::print(raw_ostream &os) const { |
61 | space.print(os); |
62 | os << "Division Representation:\n" ; |
63 | divs.print(os); |
64 | os << "Output:\n" ; |
65 | output.print(os); |
66 | } |
67 | |
68 | void MultiAffineFunction::dump() const { print(os&: llvm::errs()); } |
69 | |
70 | SmallVector<DynamicAPInt, 8> |
71 | MultiAffineFunction::valueAt(ArrayRef<DynamicAPInt> point) const { |
72 | assert(point.size() == getNumDomainVars() + getNumSymbolVars() && |
73 | "Point has incorrect dimensionality!" ); |
74 | |
75 | SmallVector<DynamicAPInt, 8> pointHomogenous{llvm::to_vector(Range&: point)}; |
76 | // Get the division values at this point. |
77 | SmallVector<std::optional<DynamicAPInt>, 8> divValues = |
78 | divs.divValuesAt(point); |
79 | // The given point didn't include the values of the divs which the output is a |
80 | // function of; we have computed one possible set of values and use them here. |
81 | pointHomogenous.reserve(N: pointHomogenous.size() + divValues.size()); |
82 | for (const std::optional<DynamicAPInt> &divVal : divValues) |
83 | pointHomogenous.emplace_back(Args: *divVal); |
84 | // The matrix `output` has an affine expression in the ith row, corresponding |
85 | // to the expression for the ith value in the output vector. The last column |
86 | // of the matrix contains the constant term. Let v be the input point with |
87 | // a 1 appended at the end. We can see that output * v gives the desired |
88 | // output vector. |
89 | pointHomogenous.emplace_back(Args: 1); |
90 | SmallVector<DynamicAPInt, 8> result = |
91 | output.postMultiplyWithColumn(colVec: pointHomogenous); |
92 | assert(result.size() == getNumOutputs()); |
93 | return result; |
94 | } |
95 | |
96 | bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { |
97 | assert(space.isCompatible(other.space) && |
98 | "Spaces should be compatible for equality check." ); |
99 | return getAsRelation().isEqual(other: other.getAsRelation()); |
100 | } |
101 | |
102 | bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, |
103 | const IntegerPolyhedron &domain) const { |
104 | assert(space.isCompatible(other.space) && |
105 | "Spaces should be compatible for equality check." ); |
106 | IntegerRelation restrictedThis = getAsRelation(); |
107 | restrictedThis.intersectDomain(poly: domain); |
108 | |
109 | IntegerRelation restrictedOther = other.getAsRelation(); |
110 | restrictedOther.intersectDomain(poly: domain); |
111 | |
112 | return restrictedThis.isEqual(other: restrictedOther); |
113 | } |
114 | |
115 | bool MultiAffineFunction::isEqual(const MultiAffineFunction &other, |
116 | const PresburgerSet &domain) const { |
117 | assert(space.isCompatible(other.space) && |
118 | "Spaces should be compatible for equality check." ); |
119 | return llvm::all_of(Range: domain.getAllDisjuncts(), |
120 | P: [&](const IntegerRelation &disjunct) { |
121 | return isEqual(other, domain: IntegerPolyhedron(disjunct)); |
122 | }); |
123 | } |
124 | |
125 | void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) { |
126 | assert(end <= getNumOutputs() && "Invalid range" ); |
127 | |
128 | if (start >= end) |
129 | return; |
130 | |
131 | space.removeVarRange(kind: VarKind::Range, varStart: start, varLimit: end); |
132 | output.removeRows(pos: start, count: end - start); |
133 | } |
134 | |
135 | void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) { |
136 | assert(space.isCompatible(other.space) && "Functions should be compatible" ); |
137 | |
138 | unsigned nDivs = getNumDivs(); |
139 | unsigned divOffset = divs.getDivOffset(); |
140 | |
141 | other.divs.insertDiv(pos: 0, num: nDivs); |
142 | |
143 | SmallVector<DynamicAPInt, 8> div(other.divs.getNumVars() + 1); |
144 | for (unsigned i = 0; i < nDivs; ++i) { |
145 | // Zero fill. |
146 | std::fill(first: div.begin(), last: div.end(), value: 0); |
147 | // Fill div with dividend from `divs`. Do not fill the constant. |
148 | std::copy(first: divs.getDividend(i).begin(), last: divs.getDividend(i).end() - 1, |
149 | result: div.begin()); |
150 | // Fill constant. |
151 | div.back() = divs.getDividend(i).back(); |
152 | other.divs.setDiv(i, dividend: div, divisor: divs.getDenom(i)); |
153 | } |
154 | |
155 | other.space.insertVar(kind: VarKind::Local, pos: 0, num: nDivs); |
156 | other.output.insertColumns(pos: divOffset, count: nDivs); |
157 | |
158 | auto merge = [&](unsigned i, unsigned j) { |
159 | // We only merge from local at pos j to local at pos i, where j > i. |
160 | if (i >= j) |
161 | return false; |
162 | |
163 | // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we |
164 | // do not want to merge duplicates in `this`, we ignore this call. |
165 | if (j < nDivs) |
166 | return false; |
167 | |
168 | // Merge things in space and output. |
169 | other.space.removeVarRange(kind: VarKind::Local, varStart: j, varLimit: j + 1); |
170 | other.output.addToColumn(sourceColumn: divOffset + i, targetColumn: divOffset + j, scale: 1); |
171 | other.output.removeColumn(pos: divOffset + j); |
172 | return true; |
173 | }; |
174 | |
175 | other.divs.removeDuplicateDivs(merge); |
176 | |
177 | unsigned newDivs = other.divs.getNumDivs() - nDivs; |
178 | |
179 | space.insertVar(kind: VarKind::Local, pos: nDivs, num: newDivs); |
180 | output.insertColumns(pos: divOffset + nDivs, count: newDivs); |
181 | divs = other.divs; |
182 | |
183 | // Check consistency. |
184 | assertIsConsistent(); |
185 | other.assertIsConsistent(); |
186 | } |
187 | |
188 | PresburgerSet |
189 | MultiAffineFunction::getLexSet(OrderingKind comp, |
190 | const MultiAffineFunction &other) const { |
191 | assert(getSpace().isCompatible(other.getSpace()) && |
192 | "Output space of funcs should be compatible" ); |
193 | |
194 | // Create copies of functions and merge their local space. |
195 | MultiAffineFunction funcA = *this; |
196 | MultiAffineFunction funcB = other; |
197 | funcA.mergeDivs(other&: funcB); |
198 | |
199 | // We first create the set `result`, corresponding to the set where output |
200 | // of funcA is lexicographically larger/smaller than funcB. This is done by |
201 | // creating a PresburgerSet with the following constraints: |
202 | // |
203 | // (outA[0] > outB[0]) U |
204 | // (outA[0] = outB[0], outA[1] > outA[1]) U |
205 | // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U |
206 | // ... |
207 | // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) |
208 | // |
209 | // where `n` is the number of outputs. |
210 | // If `lexMin` is set, the complement inequality is used: |
211 | // |
212 | // (outA[0] < outB[0]) U |
213 | // (outA[0] = outB[0], outA[1] < outA[1]) U |
214 | // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U |
215 | // ... |
216 | // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) |
217 | PresburgerSpace resultSpace = funcA.getDomainSpace(); |
218 | PresburgerSet result = |
219 | PresburgerSet::getEmpty(space: resultSpace.getSpaceWithoutLocals()); |
220 | IntegerPolyhedron levelSet( |
221 | /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(), |
222 | /*numReservedEqualities=*/funcA.getNumOutputs(), |
223 | /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace); |
224 | |
225 | // Add division inequalities to `levelSet`. |
226 | for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) { |
227 | levelSet.addInequality(inEq: getDivUpperBound(dividend: funcA.divs.getDividend(i), |
228 | divisor: funcA.divs.getDenom(i), |
229 | localVarIdx: funcA.divs.getDivOffset() + i)); |
230 | levelSet.addInequality(inEq: getDivLowerBound(dividend: funcA.divs.getDividend(i), |
231 | divisor: funcA.divs.getDenom(i), |
232 | localVarIdx: funcA.divs.getDivOffset() + i)); |
233 | } |
234 | |
235 | for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) { |
236 | // Create the expression `outA - outB` for this level. |
237 | SmallVector<DynamicAPInt, 8> subExpr = |
238 | subtractExprs(vecA: funcA.getOutputExpr(i: level), vecB: funcB.getOutputExpr(i: level)); |
239 | |
240 | // TODO: Implement all comparison cases. |
241 | switch (comp) { |
242 | case OrderingKind::LT: |
243 | // For less than, we add an upper bound of -1: |
244 | // outA - outB <= -1 |
245 | // outA <= outB - 1 |
246 | // outA < outB |
247 | levelSet.addBound(type: BoundType::UB, expr: subExpr, value: DynamicAPInt(-1)); |
248 | break; |
249 | case OrderingKind::GT: |
250 | // For greater than, we add a lower bound of 1: |
251 | // outA - outB >= 1 |
252 | // outA > outB + 1 |
253 | // outA > outB |
254 | levelSet.addBound(type: BoundType::LB, expr: subExpr, value: DynamicAPInt(1)); |
255 | break; |
256 | case OrderingKind::GE: |
257 | case OrderingKind::LE: |
258 | case OrderingKind::EQ: |
259 | case OrderingKind::NE: |
260 | assert(false && "Not implemented case" ); |
261 | } |
262 | |
263 | // Union the set with the result. |
264 | result.unionInPlace(disjunct: levelSet); |
265 | // The last inequality in `levelSet` is the bound we inserted. We remove |
266 | // that for next iteration. |
267 | levelSet.removeInequality(pos: levelSet.getNumInequalities() - 1); |
268 | // Add equality `outA - outB == 0` for this level for next iteration. |
269 | levelSet.addEquality(eq: subExpr); |
270 | } |
271 | |
272 | return result; |
273 | } |
274 | |
275 | /// Two PWMAFunctions are equal if they have the same dimensionalities, |
276 | /// the same domain, and take the same value at every point in the domain. |
277 | bool PWMAFunction::isEqual(const PWMAFunction &other) const { |
278 | if (!space.isCompatible(other: other.space)) |
279 | return false; |
280 | |
281 | if (!this->getDomain().isEqual(set: other.getDomain())) |
282 | return false; |
283 | |
284 | // Check if, whenever the domains of a piece of `this` and a piece of `other` |
285 | // overlap, they take the same output value. If `this` and `other` have the |
286 | // same domain (checked above), then this check passes iff the two functions |
287 | // have the same output at every point in the domain. |
288 | return llvm::all_of(Range: this->pieces, P: [&other](const Piece &pieceA) { |
289 | return llvm::all_of(Range: other.pieces, P: [&pieceA](const Piece &pieceB) { |
290 | PresburgerSet commonDomain = pieceA.domain.intersect(set: pieceB.domain); |
291 | return pieceA.output.isEqual(other: pieceB.output, domain: commonDomain); |
292 | }); |
293 | }); |
294 | } |
295 | |
296 | void PWMAFunction::addPiece(const Piece &piece) { |
297 | assert(piece.isConsistent() && "Piece should be consistent" ); |
298 | assert(piece.domain.intersect(getDomain()).isIntegerEmpty() && |
299 | "Piece should be disjoint from the function" ); |
300 | pieces.emplace_back(Args: piece); |
301 | } |
302 | |
303 | void PWMAFunction::print(raw_ostream &os) const { |
304 | space.print(os); |
305 | os << getNumPieces() << " pieces:\n" ; |
306 | for (const Piece &piece : pieces) { |
307 | os << "Domain of piece:\n" ; |
308 | piece.domain.print(os); |
309 | os << "Output of piece\n" ; |
310 | piece.output.print(os); |
311 | } |
312 | } |
313 | |
314 | void PWMAFunction::dump() const { print(os&: llvm::errs()); } |
315 | |
316 | PWMAFunction PWMAFunction::unionFunction( |
317 | const PWMAFunction &func, |
318 | llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const { |
319 | assert(getNumOutputs() == func.getNumOutputs() && |
320 | "Ranges of functions should be same." ); |
321 | assert(getSpace().isCompatible(func.getSpace()) && |
322 | "Space is not compatible." ); |
323 | |
324 | // The algorithm used here is as follows: |
325 | // - Add the output of pieceB for the part of the domain where both pieceA and |
326 | // pieceB are defined, and `tiebreak` chooses the output of pieceB. |
327 | // - Add the output of pieceA, where pieceB is not defined or `tiebreak` |
328 | // chooses |
329 | // pieceA over pieceB. |
330 | // - Add the output of pieceB, where pieceA is not defined. |
331 | |
332 | // Add parts of the common domain where pieceB's output is used. Also |
333 | // add all the parts where pieceA's output is used, both common and |
334 | // non-common. |
335 | PWMAFunction result(getSpace()); |
336 | for (const Piece &pieceA : pieces) { |
337 | PresburgerSet dom(pieceA.domain); |
338 | for (const Piece &pieceB : func.pieces) { |
339 | PresburgerSet better = tiebreak(pieceB, pieceA); |
340 | // Add the output of pieceB, where it is better than output of pieceA. |
341 | // The disjuncts in "better" will be disjoint as tiebreak should gurantee |
342 | // that. |
343 | result.addPiece(piece: {.domain: better, .output: pieceB.output}); |
344 | dom = dom.subtract(set: better); |
345 | } |
346 | // Add output of pieceA, where it is better than pieceB, or pieceB is not |
347 | // defined. |
348 | // |
349 | // `dom` here is guranteed to be disjoint from already added pieces |
350 | // because the pieces added before are either: |
351 | // - Subsets of the domain of other MAFs in `this`, which are guranteed |
352 | // to be disjoint from `dom`, or |
353 | // - They are one of the pieces added for `pieceB`, and we have been |
354 | // subtracting all such pieces from `dom`, so `dom` is disjoint from those |
355 | // pieces as well. |
356 | result.addPiece(piece: {.domain: dom, .output: pieceA.output}); |
357 | } |
358 | |
359 | // Add parts of pieceB which are not shared with pieceA. |
360 | PresburgerSet dom = getDomain(); |
361 | for (const Piece &pieceB : func.pieces) |
362 | result.addPiece(piece: {.domain: pieceB.domain.subtract(set: dom), .output: pieceB.output}); |
363 | |
364 | return result; |
365 | } |
366 | |
367 | /// A tiebreak function which breaks ties by comparing the outputs |
368 | /// lexicographically based on the given comparison operator. |
369 | /// This is templated since it is passed as a lambda. |
370 | template <OrderingKind comp> |
371 | static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA, |
372 | const PWMAFunction::Piece &pieceB) { |
373 | PresburgerSet result = pieceA.output.getLexSet(comp, other: pieceB.output); |
374 | result = result.intersect(set: pieceA.domain).intersect(set: pieceB.domain); |
375 | |
376 | return result; |
377 | } |
378 | |
379 | PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { |
380 | return unionFunction(func, tiebreak: tiebreakLex</*comp=*/OrderingKind::LT>); |
381 | } |
382 | |
383 | PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { |
384 | return unionFunction(func, tiebreak: tiebreakLex</*comp=*/OrderingKind::GT>); |
385 | } |
386 | |
387 | void MultiAffineFunction::subtract(const MultiAffineFunction &other) { |
388 | assert(space.isCompatible(other.space) && |
389 | "Spaces should be compatible for subtraction." ); |
390 | |
391 | MultiAffineFunction copyOther = other; |
392 | mergeDivs(other&: copyOther); |
393 | for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) |
394 | output.addToRow(row: i, rowVec: copyOther.getOutputExpr(i), scale: DynamicAPInt(-1)); |
395 | |
396 | // Check consistency. |
397 | assertIsConsistent(); |
398 | } |
399 | |
400 | /// Adds division constraints corresponding to local variables, given a |
401 | /// relation and division representations of the local variables in the |
402 | /// relation. |
403 | static void addDivisionConstraints(IntegerRelation &rel, |
404 | const DivisionRepr &divs) { |
405 | assert(divs.hasAllReprs() && |
406 | "All divisions in divs should have a representation" ); |
407 | assert(rel.getNumVars() == divs.getNumVars() && |
408 | "Relation and divs should have the same number of vars" ); |
409 | assert(rel.getNumLocalVars() == divs.getNumDivs() && |
410 | "Relation and divs should have the same number of local vars" ); |
411 | |
412 | for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) { |
413 | rel.addInequality(inEq: getDivUpperBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i), |
414 | localVarIdx: divs.getDivOffset() + i)); |
415 | rel.addInequality(inEq: getDivLowerBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i), |
416 | localVarIdx: divs.getDivOffset() + i)); |
417 | } |
418 | } |
419 | |
420 | IntegerRelation MultiAffineFunction::getAsRelation() const { |
421 | // Create a relation corressponding to the input space plus the divisions |
422 | // used in outputs. |
423 | IntegerRelation result(PresburgerSpace::getRelationSpace( |
424 | numDomain: space.getNumDomainVars(), numRange: 0, numSymbols: space.getNumSymbolVars(), |
425 | numLocals: space.getNumLocalVars())); |
426 | // Add division constraints corresponding to divisions used in outputs. |
427 | addDivisionConstraints(rel&: result, divs); |
428 | // The outputs are represented as range variables in the relation. We add |
429 | // range variables for the outputs. |
430 | result.insertVar(kind: VarKind::Range, pos: 0, num: getNumOutputs()); |
431 | |
432 | // Add equalities such that the i^th range variable is equal to the i^th |
433 | // output expression. |
434 | SmallVector<DynamicAPInt, 8> eq(result.getNumCols()); |
435 | for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) { |
436 | // TODO: Add functions to get VarKind offsets in output in MAF and use them |
437 | // here. |
438 | // The output expression does not contain range variables, while the |
439 | // equality does. So, we need to copy all variables and mark all range |
440 | // variables as 0 in the equality. |
441 | ArrayRef<DynamicAPInt> expr = getOutputExpr(i); |
442 | // Copy domain variables in `expr` to domain variables in `eq`. |
443 | std::copy(first: expr.begin(), last: expr.begin() + getNumDomainVars(), result: eq.begin()); |
444 | // Fill the range variables in `eq` as zero. |
445 | std::fill(first: eq.begin() + result.getVarKindOffset(kind: VarKind::Range), |
446 | last: eq.begin() + result.getVarKindEnd(kind: VarKind::Range), value: 0); |
447 | // Copy remaining variables in `expr` to the remaining variables in `eq`. |
448 | std::copy(first: expr.begin() + getNumDomainVars(), last: expr.end(), |
449 | result: eq.begin() + result.getVarKindEnd(kind: VarKind::Range)); |
450 | |
451 | // Set the i^th range var to -1 in `eq` to equate the output expression to |
452 | // this range var. |
453 | eq[result.getVarKindOffset(kind: VarKind::Range) + i] = -1; |
454 | // Add the equality `rangeVar_i = output[i]`. |
455 | result.addEquality(eq); |
456 | } |
457 | |
458 | return result; |
459 | } |
460 | |
461 | void PWMAFunction::removeOutputs(unsigned start, unsigned end) { |
462 | space.removeVarRange(kind: VarKind::Range, varStart: start, varLimit: end); |
463 | for (Piece &piece : pieces) |
464 | piece.output.removeOutputs(start, end); |
465 | } |
466 | |
467 | std::optional<SmallVector<DynamicAPInt, 8>> |
468 | PWMAFunction::valueAt(ArrayRef<DynamicAPInt> point) const { |
469 | assert(point.size() == getNumDomainVars() + getNumSymbolVars()); |
470 | |
471 | for (const Piece &piece : pieces) |
472 | if (piece.domain.containsPoint(point)) |
473 | return piece.output.valueAt(point); |
474 | return std::nullopt; |
475 | } |
476 | |