1//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/MPInt.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <optional>
23
24using namespace mlir;
25using namespace presburger;
26
27void MultiAffineFunction::assertIsConsistent() const {
28 assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
29 output.getNumColumns() &&
30 "Inconsistent number of output columns");
31 assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
32 divs.getNumNonDivs() &&
33 "Inconsistent number of non-division variables in divs");
34 assert(space.getNumRangeVars() == output.getNumRows() &&
35 "Inconsistent number of output rows");
36 assert(space.getNumLocalVars() == divs.getNumDivs() &&
37 "Inconsistent number of divisions.");
38 assert(divs.hasAllReprs() && "All divisions should have a representation");
39}
40
41// Return the result of subtracting the two given vectors pointwise.
42// The vectors must be of the same size.
43// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
44static SmallVector<MPInt, 8> subtractExprs(ArrayRef<MPInt> vecA,
45 ArrayRef<MPInt> vecB) {
46 assert(vecA.size() == vecB.size() &&
47 "Cannot subtract vectors of differing lengths!");
48 SmallVector<MPInt, 8> result;
49 result.reserve(N: vecA.size());
50 for (unsigned i = 0, e = vecA.size(); i < e; ++i)
51 result.push_back(Elt: vecA[i] - vecB[i]);
52 return result;
53}
54
55PresburgerSet PWMAFunction::getDomain() const {
56 PresburgerSet domain = PresburgerSet::getEmpty(space: getDomainSpace());
57 for (const Piece &piece : pieces)
58 domain.unionInPlace(set: piece.domain);
59 return domain;
60}
61
62void MultiAffineFunction::print(raw_ostream &os) const {
63 space.print(os);
64 os << "Division Representation:\n";
65 divs.print(os);
66 os << "Output:\n";
67 output.print(os);
68}
69
70SmallVector<MPInt, 8>
71MultiAffineFunction::valueAt(ArrayRef<MPInt> point) const {
72 assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
73 "Point has incorrect dimensionality!");
74
75 SmallVector<MPInt, 8> pointHomogenous{llvm::to_vector(Range&: point)};
76 // Get the division values at this point.
77 SmallVector<std::optional<MPInt>, 8> divValues = divs.divValuesAt(point);
78 // The given point didn't include the values of the divs which the output is a
79 // function of; we have computed one possible set of values and use them here.
80 pointHomogenous.reserve(N: pointHomogenous.size() + divValues.size());
81 for (const std::optional<MPInt> &divVal : divValues)
82 pointHomogenous.push_back(Elt: *divVal);
83 // The matrix `output` has an affine expression in the ith row, corresponding
84 // to the expression for the ith value in the output vector. The last column
85 // of the matrix contains the constant term. Let v be the input point with
86 // a 1 appended at the end. We can see that output * v gives the desired
87 // output vector.
88 pointHomogenous.emplace_back(Args: 1);
89 SmallVector<MPInt, 8> result = output.postMultiplyWithColumn(colVec: pointHomogenous);
90 assert(result.size() == getNumOutputs());
91 return result;
92}
93
94bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
95 assert(space.isCompatible(other.space) &&
96 "Spaces should be compatible for equality check.");
97 return getAsRelation().isEqual(other: other.getAsRelation());
98}
99
100bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
101 const IntegerPolyhedron &domain) const {
102 assert(space.isCompatible(other.space) &&
103 "Spaces should be compatible for equality check.");
104 IntegerRelation restrictedThis = getAsRelation();
105 restrictedThis.intersectDomain(poly: domain);
106
107 IntegerRelation restrictedOther = other.getAsRelation();
108 restrictedOther.intersectDomain(poly: domain);
109
110 return restrictedThis.isEqual(other: restrictedOther);
111}
112
113bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
114 const PresburgerSet &domain) const {
115 assert(space.isCompatible(other.space) &&
116 "Spaces should be compatible for equality check.");
117 return llvm::all_of(Range: domain.getAllDisjuncts(),
118 P: [&](const IntegerRelation &disjunct) {
119 return isEqual(other, domain: IntegerPolyhedron(disjunct));
120 });
121}
122
123void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
124 assert(end <= getNumOutputs() && "Invalid range");
125
126 if (start >= end)
127 return;
128
129 space.removeVarRange(kind: VarKind::Range, varStart: start, varLimit: end);
130 output.removeRows(pos: start, count: end - start);
131}
132
133void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
134 assert(space.isCompatible(other.space) && "Functions should be compatible");
135
136 unsigned nDivs = getNumDivs();
137 unsigned divOffset = divs.getDivOffset();
138
139 other.divs.insertDiv(pos: 0, num: nDivs);
140
141 SmallVector<MPInt, 8> div(other.divs.getNumVars() + 1);
142 for (unsigned i = 0; i < nDivs; ++i) {
143 // Zero fill.
144 std::fill(first: div.begin(), last: div.end(), value: 0);
145 // Fill div with dividend from `divs`. Do not fill the constant.
146 std::copy(first: divs.getDividend(i).begin(), last: divs.getDividend(i).end() - 1,
147 result: div.begin());
148 // Fill constant.
149 div.back() = divs.getDividend(i).back();
150 other.divs.setDiv(i, dividend: div, divisor: divs.getDenom(i));
151 }
152
153 other.space.insertVar(kind: VarKind::Local, pos: 0, num: nDivs);
154 other.output.insertColumns(pos: divOffset, count: nDivs);
155
156 auto merge = [&](unsigned i, unsigned j) {
157 // We only merge from local at pos j to local at pos i, where j > i.
158 if (i >= j)
159 return false;
160
161 // If i < nDivs, we are trying to merge duplicate divs in `this`. Since we
162 // do not want to merge duplicates in `this`, we ignore this call.
163 if (j < nDivs)
164 return false;
165
166 // Merge things in space and output.
167 other.space.removeVarRange(kind: VarKind::Local, varStart: j, varLimit: j + 1);
168 other.output.addToColumn(sourceColumn: divOffset + i, targetColumn: divOffset + j, scale: 1);
169 other.output.removeColumn(pos: divOffset + j);
170 return true;
171 };
172
173 other.divs.removeDuplicateDivs(merge);
174
175 unsigned newDivs = other.divs.getNumDivs() - nDivs;
176
177 space.insertVar(kind: VarKind::Local, pos: nDivs, num: newDivs);
178 output.insertColumns(pos: divOffset + nDivs, count: newDivs);
179 divs = other.divs;
180
181 // Check consistency.
182 assertIsConsistent();
183 other.assertIsConsistent();
184}
185
186PresburgerSet
187MultiAffineFunction::getLexSet(OrderingKind comp,
188 const MultiAffineFunction &other) const {
189 assert(getSpace().isCompatible(other.getSpace()) &&
190 "Output space of funcs should be compatible");
191
192 // Create copies of functions and merge their local space.
193 MultiAffineFunction funcA = *this;
194 MultiAffineFunction funcB = other;
195 funcA.mergeDivs(other&: funcB);
196
197 // We first create the set `result`, corresponding to the set where output
198 // of funcA is lexicographically larger/smaller than funcB. This is done by
199 // creating a PresburgerSet with the following constraints:
200 //
201 // (outA[0] > outB[0]) U
202 // (outA[0] = outB[0], outA[1] > outA[1]) U
203 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
204 // ...
205 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
206 //
207 // where `n` is the number of outputs.
208 // If `lexMin` is set, the complement inequality is used:
209 //
210 // (outA[0] < outB[0]) U
211 // (outA[0] = outB[0], outA[1] < outA[1]) U
212 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
213 // ...
214 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
215 PresburgerSpace resultSpace = funcA.getDomainSpace();
216 PresburgerSet result =
217 PresburgerSet::getEmpty(space: resultSpace.getSpaceWithoutLocals());
218 IntegerPolyhedron levelSet(
219 /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
220 /*numReservedEqualities=*/funcA.getNumOutputs(),
221 /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
222
223 // Add division inequalities to `levelSet`.
224 for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
225 levelSet.addInequality(inEq: getDivUpperBound(dividend: funcA.divs.getDividend(i),
226 divisor: funcA.divs.getDenom(i),
227 localVarIdx: funcA.divs.getDivOffset() + i));
228 levelSet.addInequality(inEq: getDivLowerBound(dividend: funcA.divs.getDividend(i),
229 divisor: funcA.divs.getDenom(i),
230 localVarIdx: funcA.divs.getDivOffset() + i));
231 }
232
233 for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
234 // Create the expression `outA - outB` for this level.
235 SmallVector<MPInt, 8> subExpr =
236 subtractExprs(vecA: funcA.getOutputExpr(i: level), vecB: funcB.getOutputExpr(i: level));
237
238 // TODO: Implement all comparison cases.
239 switch (comp) {
240 case OrderingKind::LT:
241 // For less than, we add an upper bound of -1:
242 // outA - outB <= -1
243 // outA <= outB - 1
244 // outA < outB
245 levelSet.addBound(type: BoundType::UB, expr: subExpr, value: MPInt(-1));
246 break;
247 case OrderingKind::GT:
248 // For greater than, we add a lower bound of 1:
249 // outA - outB >= 1
250 // outA > outB + 1
251 // outA > outB
252 levelSet.addBound(type: BoundType::LB, expr: subExpr, value: MPInt(1));
253 break;
254 case OrderingKind::GE:
255 case OrderingKind::LE:
256 case OrderingKind::EQ:
257 case OrderingKind::NE:
258 assert(false && "Not implemented case");
259 }
260
261 // Union the set with the result.
262 result.unionInPlace(disjunct: levelSet);
263 // The last inequality in `levelSet` is the bound we inserted. We remove
264 // that for next iteration.
265 levelSet.removeInequality(pos: levelSet.getNumInequalities() - 1);
266 // Add equality `outA - outB == 0` for this level for next iteration.
267 levelSet.addEquality(eq: subExpr);
268 }
269
270 return result;
271}
272
273/// Two PWMAFunctions are equal if they have the same dimensionalities,
274/// the same domain, and take the same value at every point in the domain.
275bool PWMAFunction::isEqual(const PWMAFunction &other) const {
276 if (!space.isCompatible(other: other.space))
277 return false;
278
279 if (!this->getDomain().isEqual(set: other.getDomain()))
280 return false;
281
282 // Check if, whenever the domains of a piece of `this` and a piece of `other`
283 // overlap, they take the same output value. If `this` and `other` have the
284 // same domain (checked above), then this check passes iff the two functions
285 // have the same output at every point in the domain.
286 return llvm::all_of(Range: this->pieces, P: [&other](const Piece &pieceA) {
287 return llvm::all_of(Range: other.pieces, P: [&pieceA](const Piece &pieceB) {
288 PresburgerSet commonDomain = pieceA.domain.intersect(set: pieceB.domain);
289 return pieceA.output.isEqual(other: pieceB.output, domain: commonDomain);
290 });
291 });
292}
293
294void PWMAFunction::addPiece(const Piece &piece) {
295 assert(piece.isConsistent() && "Piece should be consistent");
296 assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
297 "Piece should be disjoint from the function");
298 pieces.push_back(Elt: piece);
299}
300
301void PWMAFunction::print(raw_ostream &os) const {
302 space.print(os);
303 os << getNumPieces() << " pieces:\n";
304 for (const Piece &piece : pieces) {
305 os << "Domain of piece:\n";
306 piece.domain.print(os);
307 os << "Output of piece\n";
308 piece.output.print(os);
309 }
310}
311
312void PWMAFunction::dump() const { print(os&: llvm::errs()); }
313
314PWMAFunction PWMAFunction::unionFunction(
315 const PWMAFunction &func,
316 llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
317 assert(getNumOutputs() == func.getNumOutputs() &&
318 "Ranges of functions should be same.");
319 assert(getSpace().isCompatible(func.getSpace()) &&
320 "Space is not compatible.");
321
322 // The algorithm used here is as follows:
323 // - Add the output of pieceB for the part of the domain where both pieceA and
324 // pieceB are defined, and `tiebreak` chooses the output of pieceB.
325 // - Add the output of pieceA, where pieceB is not defined or `tiebreak`
326 // chooses
327 // pieceA over pieceB.
328 // - Add the output of pieceB, where pieceA is not defined.
329
330 // Add parts of the common domain where pieceB's output is used. Also
331 // add all the parts where pieceA's output is used, both common and
332 // non-common.
333 PWMAFunction result(getSpace());
334 for (const Piece &pieceA : pieces) {
335 PresburgerSet dom(pieceA.domain);
336 for (const Piece &pieceB : func.pieces) {
337 PresburgerSet better = tiebreak(pieceB, pieceA);
338 // Add the output of pieceB, where it is better than output of pieceA.
339 // The disjuncts in "better" will be disjoint as tiebreak should gurantee
340 // that.
341 result.addPiece(piece: {.domain: better, .output: pieceB.output});
342 dom = dom.subtract(set: better);
343 }
344 // Add output of pieceA, where it is better than pieceB, or pieceB is not
345 // defined.
346 //
347 // `dom` here is guranteed to be disjoint from already added pieces
348 // because the pieces added before are either:
349 // - Subsets of the domain of other MAFs in `this`, which are guranteed
350 // to be disjoint from `dom`, or
351 // - They are one of the pieces added for `pieceB`, and we have been
352 // subtracting all such pieces from `dom`, so `dom` is disjoint from those
353 // pieces as well.
354 result.addPiece(piece: {.domain: dom, .output: pieceA.output});
355 }
356
357 // Add parts of pieceB which are not shared with pieceA.
358 PresburgerSet dom = getDomain();
359 for (const Piece &pieceB : func.pieces)
360 result.addPiece(piece: {.domain: pieceB.domain.subtract(set: dom), .output: pieceB.output});
361
362 return result;
363}
364
365/// A tiebreak function which breaks ties by comparing the outputs
366/// lexicographically based on the given comparison operator.
367/// This is templated since it is passed as a lambda.
368template <OrderingKind comp>
369static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
370 const PWMAFunction::Piece &pieceB) {
371 PresburgerSet result = pieceA.output.getLexSet(comp, other: pieceB.output);
372 result = result.intersect(set: pieceA.domain).intersect(set: pieceB.domain);
373
374 return result;
375}
376
377PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
378 return unionFunction(func, tiebreak: tiebreakLex</*comp=*/OrderingKind::LT>);
379}
380
381PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
382 return unionFunction(func, tiebreak: tiebreakLex</*comp=*/OrderingKind::GT>);
383}
384
385void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
386 assert(space.isCompatible(other.space) &&
387 "Spaces should be compatible for subtraction.");
388
389 MultiAffineFunction copyOther = other;
390 mergeDivs(other&: copyOther);
391 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
392 output.addToRow(row: i, rowVec: copyOther.getOutputExpr(i), scale: MPInt(-1));
393
394 // Check consistency.
395 assertIsConsistent();
396}
397
398/// Adds division constraints corresponding to local variables, given a
399/// relation and division representations of the local variables in the
400/// relation.
401static void addDivisionConstraints(IntegerRelation &rel,
402 const DivisionRepr &divs) {
403 assert(divs.hasAllReprs() &&
404 "All divisions in divs should have a representation");
405 assert(rel.getNumVars() == divs.getNumVars() &&
406 "Relation and divs should have the same number of vars");
407 assert(rel.getNumLocalVars() == divs.getNumDivs() &&
408 "Relation and divs should have the same number of local vars");
409
410 for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
411 rel.addInequality(inEq: getDivUpperBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i),
412 localVarIdx: divs.getDivOffset() + i));
413 rel.addInequality(inEq: getDivLowerBound(dividend: divs.getDividend(i), divisor: divs.getDenom(i),
414 localVarIdx: divs.getDivOffset() + i));
415 }
416}
417
418IntegerRelation MultiAffineFunction::getAsRelation() const {
419 // Create a relation corressponding to the input space plus the divisions
420 // used in outputs.
421 IntegerRelation result(PresburgerSpace::getRelationSpace(
422 numDomain: space.getNumDomainVars(), numRange: 0, numSymbols: space.getNumSymbolVars(),
423 numLocals: space.getNumLocalVars()));
424 // Add division constraints corresponding to divisions used in outputs.
425 addDivisionConstraints(rel&: result, divs);
426 // The outputs are represented as range variables in the relation. We add
427 // range variables for the outputs.
428 result.insertVar(kind: VarKind::Range, pos: 0, num: getNumOutputs());
429
430 // Add equalities such that the i^th range variable is equal to the i^th
431 // output expression.
432 SmallVector<MPInt, 8> eq(result.getNumCols());
433 for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
434 // TODO: Add functions to get VarKind offsets in output in MAF and use them
435 // here.
436 // The output expression does not contain range variables, while the
437 // equality does. So, we need to copy all variables and mark all range
438 // variables as 0 in the equality.
439 ArrayRef<MPInt> expr = getOutputExpr(i);
440 // Copy domain variables in `expr` to domain variables in `eq`.
441 std::copy(first: expr.begin(), last: expr.begin() + getNumDomainVars(), result: eq.begin());
442 // Fill the range variables in `eq` as zero.
443 std::fill(first: eq.begin() + result.getVarKindOffset(kind: VarKind::Range),
444 last: eq.begin() + result.getVarKindEnd(kind: VarKind::Range), value: 0);
445 // Copy remaining variables in `expr` to the remaining variables in `eq`.
446 std::copy(first: expr.begin() + getNumDomainVars(), last: expr.end(),
447 result: eq.begin() + result.getVarKindEnd(kind: VarKind::Range));
448
449 // Set the i^th range var to -1 in `eq` to equate the output expression to
450 // this range var.
451 eq[result.getVarKindOffset(kind: VarKind::Range) + i] = -1;
452 // Add the equality `rangeVar_i = output[i]`.
453 result.addEquality(eq);
454 }
455
456 return result;
457}
458
459void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
460 space.removeVarRange(kind: VarKind::Range, varStart: start, varLimit: end);
461 for (Piece &piece : pieces)
462 piece.output.removeOutputs(start, end);
463}
464
465std::optional<SmallVector<MPInt, 8>>
466PWMAFunction::valueAt(ArrayRef<MPInt> point) const {
467 assert(point.size() == getNumDomainVars() + getNumSymbolVars());
468
469 for (const Piece &piece : pieces)
470 if (piece.domain.containsPoint(point))
471 return piece.output.valueAt(point);
472 return std::nullopt;
473}
474

source code of mlir/lib/Analysis/Presburger/PWMAFunction.cpp