1 | //===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Support for piece-wise multi-affine functions. These are functions that are |
10 | // defined on a domain that is a union of IntegerPolyhedrons, and on each domain |
11 | // the value of the function is a tuple of integers, with each value in the |
12 | // tuple being an affine expression in the vars of the IntegerPolyhedron. |
13 | // |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | #ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H |
17 | #define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H |
18 | |
19 | #include "mlir/Analysis/Presburger/IntegerRelation.h" |
20 | #include "mlir/Analysis/Presburger/PresburgerRelation.h" |
21 | #include <optional> |
22 | |
23 | namespace mlir { |
24 | namespace presburger { |
25 | |
/// A binary comparison operator. It selects the ordering to test, e.g. in
/// MultiAffineFunction::getLexSet, which compares outputs lexicographically.
enum class OrderingKind {
  EQ, ///< Equal.
  NE, ///< Not equal.
  LT, ///< Less than.
  LE, ///< Less than or equal.
  GT, ///< Greater than.
  GE  ///< Greater than or equal.
};
29 | |
30 | /// This class represents a multi-affine function with the domain as Z^d, where |
31 | /// `d` is the number of domain variables of the function. For example: |
32 | /// |
33 | /// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y). |
34 | /// |
35 | /// The output expressions are represented as a matrix with one row for every |
36 | /// output, one column for each var including division variables, and an extra |
37 | /// column at the end for the constant term. |
38 | /// |
39 | /// Checking equality of two such functions is supported, as well as finding the |
40 | /// value of the function at a specified point. |
41 | class MultiAffineFunction { |
42 | public: |
43 | MultiAffineFunction(const PresburgerSpace &space, const IntMatrix &output) |
44 | : space(space), output(output), |
45 | divs(space.getNumVars() - space.getNumRangeVars()) { |
46 | assertIsConsistent(); |
47 | } |
48 | |
49 | MultiAffineFunction(const PresburgerSpace &space, const IntMatrix &output, |
50 | const DivisionRepr &divs) |
51 | : space(space), output(output), divs(divs) { |
52 | assertIsConsistent(); |
53 | } |
54 | |
55 | unsigned getNumDomainVars() const { return space.getNumDomainVars(); } |
56 | unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); } |
57 | unsigned getNumOutputs() const { return space.getNumRangeVars(); } |
58 | unsigned getNumDivs() const { return space.getNumLocalVars(); } |
59 | |
60 | /// Get the space of this function. |
61 | const PresburgerSpace &getSpace() const { return space; } |
62 | /// Get the domain/output space of the function. The returned space is a set |
63 | /// space. |
64 | PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); } |
65 | PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); } |
66 | |
67 | /// Get a matrix with each row representing row^th output expression. |
68 | const IntMatrix &getOutputMatrix() const { return output; } |
69 | /// Get the `i^th` output expression. |
70 | ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(row: i); } |
71 | |
72 | /// Get the divisions used in this function. |
73 | const DivisionRepr &getDivs() const { return divs; } |
74 | |
75 | /// Remove the specified range of outputs. |
76 | void removeOutputs(unsigned start, unsigned end); |
77 | |
78 | /// Given a MAF `other`, merges division variables such that both functions |
79 | /// have the union of the division vars that exist in the functions. |
80 | void mergeDivs(MultiAffineFunction &other); |
81 | |
82 | //// Return the output of the function at the given point. |
83 | SmallVector<MPInt, 8> valueAt(ArrayRef<MPInt> point) const; |
84 | SmallVector<MPInt, 8> valueAt(ArrayRef<int64_t> point) const { |
85 | return valueAt(point: getMPIntVec(range: point)); |
86 | } |
87 | |
88 | /// Return whether the `this` and `other` are equal when the domain is |
89 | /// restricted to `domain`. This is the case if they lie in the same space, |
90 | /// and their outputs are equal for every point in `domain`. |
91 | bool isEqual(const MultiAffineFunction &other) const; |
92 | bool isEqual(const MultiAffineFunction &other, |
93 | const IntegerPolyhedron &domain) const; |
94 | bool isEqual(const MultiAffineFunction &other, |
95 | const PresburgerSet &domain) const; |
96 | |
97 | void subtract(const MultiAffineFunction &other); |
98 | |
99 | /// Return the set of domain points where the output of `this` and `other` |
100 | /// are ordered lexicographically according to the given ordering. |
101 | /// For example, if the given comparison is `LT`, then the returned set |
102 | /// contains all points where the first output of `this` is lexicographically |
103 | /// less than `other`. |
104 | PresburgerSet getLexSet(OrderingKind comp, |
105 | const MultiAffineFunction &other) const; |
106 | |
107 | /// Get this function as a relation. |
108 | IntegerRelation getAsRelation() const; |
109 | |
110 | void print(raw_ostream &os) const; |
111 | void dump() const; |
112 | |
113 | private: |
114 | /// Assert that the MAF is consistent. |
115 | void assertIsConsistent() const; |
116 | |
117 | /// The space of this function. The domain variables are considered as the |
118 | /// input variables of the function. The range variables are considered as |
119 | /// the outputs. The symbols parametrize the function and locals are used to |
120 | /// represent divisions. Each local variable has a corressponding division |
121 | /// representation stored in `divs`. |
122 | PresburgerSpace space; |
123 | |
124 | /// The function's output is a tuple of integers, with the ith element of the |
125 | /// tuple defined by the affine expression given by the ith row of this output |
126 | /// matrix. |
127 | IntMatrix output; |
128 | |
129 | /// Storage for division representation for each local variable in space. |
130 | DivisionRepr divs; |
131 | }; |
132 | |
133 | /// This class represents a piece-wise MultiAffineFunction. This can be thought |
134 | /// of as a list of MultiAffineFunction with disjoint domains, with each having |
135 | /// their own affine expressions for their output tuples. For example, we could |
136 | /// have a function with two input variables (x, y), defined as |
137 | /// |
138 | /// f(x, y) = (2*x + y, y - 4) if x >= 0, y >= 0 |
139 | /// = (-2*x + y, y + 4) if x < 0, y < 0 |
140 | /// = (4, 1) if x < 0, y >= 0 |
141 | /// |
142 | /// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of |
143 | /// this class is undefined. The domains need not cover all possible points; |
144 | /// this represents a partial function and so could be undefined at some points. |
145 | /// |
146 | /// As in PresburgerSets, the input vars are partitioned into dimension vars and |
147 | /// symbolic vars. |
148 | /// |
149 | /// Support is provided to compare equality of two such functions as well as |
150 | /// finding the value of the function at a point. |
151 | class PWMAFunction { |
152 | public: |
153 | struct Piece { |
154 | PresburgerSet domain; |
155 | MultiAffineFunction output; |
156 | |
157 | bool isConsistent() const { |
158 | return domain.getSpace().isCompatible(other: output.getDomainSpace()); |
159 | } |
160 | }; |
161 | |
162 | PWMAFunction(const PresburgerSpace &space) : space(space) { |
163 | assert(space.getNumLocalVars() == 0 && |
164 | "PWMAFunction cannot have local vars." ); |
165 | } |
166 | |
167 | // Get the space of this function. |
168 | const PresburgerSpace &getSpace() const { return space; } |
169 | |
170 | // Add a piece ([domain, output] pair) to this function. |
171 | void addPiece(const Piece &piece); |
172 | |
173 | unsigned getNumPieces() const { return pieces.size(); } |
174 | unsigned getNumVarKind(VarKind kind) const { |
175 | return space.getNumVarKind(kind); |
176 | } |
177 | unsigned getNumDomainVars() const { return space.getNumDomainVars(); } |
178 | unsigned getNumOutputs() const { return space.getNumRangeVars(); } |
179 | unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); } |
180 | |
181 | /// Remove the specified range of outputs. |
182 | void removeOutputs(unsigned start, unsigned end); |
183 | |
184 | /// Get the domain/output space of the function. The returned space is a set |
185 | /// space. |
186 | PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); } |
187 | PresburgerSpace getOutputSpace() const { return space.getDomainSpace(); } |
188 | |
189 | /// Return the domain of this piece-wise MultiAffineFunction. This is the |
190 | /// union of the domains of all the pieces. |
191 | PresburgerSet getDomain() const; |
192 | |
193 | /// Return the output of the function at the given point. |
194 | std::optional<SmallVector<MPInt, 8>> valueAt(ArrayRef<MPInt> point) const; |
195 | std::optional<SmallVector<MPInt, 8>> valueAt(ArrayRef<int64_t> point) const { |
196 | return valueAt(point: getMPIntVec(range: point)); |
197 | } |
198 | |
199 | /// Return all the pieces of this piece-wise function. |
200 | ArrayRef<Piece> getAllPieces() const { return pieces; } |
201 | |
202 | /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether |
203 | /// they have the same dimensions, the same domain and they take the same |
204 | /// value at every point in the domain. |
205 | bool isEqual(const PWMAFunction &other) const; |
206 | |
207 | /// Return a function defined on the union of the domains of this and func, |
208 | /// such that when only one of the functions is defined, it outputs the same |
209 | /// as that function, and if both are defined, it outputs the lexmax/lexmin of |
210 | /// the two outputs. On points where neither function is defined, the returned |
211 | /// function is not defined either. |
212 | /// |
213 | /// Currently this does not support PWMAFunctions which have pieces containing |
214 | /// divisions. |
215 | /// TODO: Support division in pieces. |
216 | PWMAFunction unionLexMin(const PWMAFunction &func); |
217 | PWMAFunction unionLexMax(const PWMAFunction &func); |
218 | |
219 | void print(raw_ostream &os) const; |
220 | void dump() const; |
221 | |
222 | private: |
223 | /// Return a function defined on the union of the domains of `this` and |
224 | /// `func`, such that when only one of the functions is defined, it outputs |
225 | /// the same as that function, and if neither is defined, the returned |
226 | /// function is not defined either. |
227 | /// |
228 | /// The provided `tiebreak` function determines which of the two functions' |
229 | /// output should be used on inputs where both the functions are defined. More |
230 | /// precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak` |
231 | /// returns the subset of the intersection of the two functions' domains where |
232 | /// the output of `mafA` should be used. |
233 | /// |
234 | /// The PresburgerSet returned by `tiebreak` should be disjoint. |
235 | /// TODO: Remove this constraint of returning disjoint set. |
236 | PWMAFunction unionFunction( |
237 | const PWMAFunction &func, |
238 | llvm::function_ref<PresburgerSet(Piece mafA, Piece mafB)> tiebreak) const; |
239 | |
240 | /// The space of this function. The domain variables are considered as the |
241 | /// input variables of the function. The range variables are considered as |
242 | /// the outputs. The symbols paramterize the function. |
243 | PresburgerSpace space; |
244 | |
245 | // The pieces of the PWMAFunction. |
246 | SmallVector<Piece, 4> pieces; |
247 | }; |
248 | |
249 | } // namespace presburger |
250 | } // namespace mlir |
251 | |
252 | #endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H |
253 | |