1 | //===- Matrix.h - MLIR Matrix Class -----------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This is a simple 2D matrix class that supports reading, writing, resizing, |
10 | // swapping rows, and swapping columns. It can hold integers (MPInt) or rational |
11 | // numbers (Fraction). |
12 | // |
13 | //===----------------------------------------------------------------------===// |
14 | |
15 | #ifndef MLIR_ANALYSIS_PRESBURGER_MATRIX_H |
16 | #define MLIR_ANALYSIS_PRESBURGER_MATRIX_H |
17 | |
18 | #include "mlir/Analysis/Presburger/Fraction.h" |
19 | #include "mlir/Support/LLVM.h" |
20 | #include "llvm/ADT/ArrayRef.h" |
21 | #include "llvm/Support/raw_ostream.h" |
22 | |
23 | #include <bitset> |
24 | #include <cassert> |
25 | |
26 | namespace mlir { |
27 | namespace presburger { |
28 | |
29 | /// This is a class to represent a resizable matrix. |
30 | /// |
31 | /// More columns and rows can be reserved than are currently used. The data is |
32 | /// stored as a single 1D array, viewed as a 2D matrix with nRows rows and |
33 | /// nReservedColumns columns, stored in row major form. Thus the element at |
34 | /// (i, j) is stored at data[i*nReservedColumns + j]. The reserved but unused |
35 | /// columns always have all zero values. The reserved rows are just reserved |
36 | /// space in the underlying SmallVector's capacity. |
37 | /// This class only works for the types MPInt and Fraction, since the method |
38 | /// implementations are in the Matrix.cpp file. Only these two types have |
39 | /// been explicitly instantiated there. |
40 | template <typename T> |
41 | class Matrix { |
42 | static_assert(std::is_same_v<T, MPInt> || std::is_same_v<T, Fraction>, |
43 | "T must be MPInt or Fraction." ); |
44 | |
45 | public: |
46 | Matrix() = delete; |
47 | |
48 | /// Construct a matrix with the specified number of rows and columns. |
49 | /// The number of reserved rows and columns will be at least the number |
50 | /// specified, and will always be sufficient to accomodate the number of rows |
51 | /// and columns specified. |
52 | /// |
53 | /// Initially, the entries are initialized to ero. |
54 | Matrix(unsigned rows, unsigned columns, unsigned reservedRows = 0, |
55 | unsigned reservedColumns = 0); |
56 | |
57 | /// Return the identity matrix of the specified dimension. |
58 | static Matrix identity(unsigned dimension); |
59 | |
60 | /// Access the element at the specified row and column. |
61 | T &at(unsigned row, unsigned column) { |
62 | assert(row < nRows && "Row outside of range" ); |
63 | assert(column < nColumns && "Column outside of range" ); |
64 | return data[row * nReservedColumns + column]; |
65 | } |
66 | |
67 | T at(unsigned row, unsigned column) const { |
68 | assert(row < nRows && "Row outside of range" ); |
69 | assert(column < nColumns && "Column outside of range" ); |
70 | return data[row * nReservedColumns + column]; |
71 | } |
72 | |
73 | T &operator()(unsigned row, unsigned column) { return at(row, column); } |
74 | |
75 | T operator()(unsigned row, unsigned column) const { return at(row, column); } |
76 | |
77 | bool operator==(const Matrix<T> &m) const; |
78 | |
79 | /// Swap the given columns. |
80 | void swapColumns(unsigned column, unsigned otherColumn); |
81 | |
82 | /// Swap the given rows. |
83 | void swapRows(unsigned row, unsigned otherRow); |
84 | |
85 | unsigned getNumRows() const { return nRows; } |
86 | |
87 | unsigned getNumColumns() const { return nColumns; } |
88 | |
89 | /// Return the maximum number of rows/columns that can be added without |
90 | /// incurring a reallocation. |
91 | unsigned getNumReservedRows() const; |
92 | unsigned getNumReservedColumns() const { return nReservedColumns; } |
93 | |
94 | /// Reserve enough space to resize to the specified number of rows without |
95 | /// reallocations. |
96 | void reserveRows(unsigned rows); |
97 | |
98 | /// Get a [Mutable]ArrayRef corresponding to the specified row. |
99 | MutableArrayRef<T> getRow(unsigned row); |
100 | ArrayRef<T> getRow(unsigned row) const; |
101 | |
102 | /// Set the specified row to `elems`. |
103 | void setRow(unsigned row, ArrayRef<T> elems); |
104 | |
105 | /// Insert columns having positions pos, pos + 1, ... pos + count - 1. |
106 | /// Columns that were at positions 0 to pos - 1 will stay where they are; |
107 | /// columns that were at positions pos to nColumns - 1 will be pushed to the |
108 | /// right. pos should be at most nColumns. |
109 | void insertColumns(unsigned pos, unsigned count); |
110 | void insertColumn(unsigned pos); |
111 | |
112 | /// Insert rows having positions pos, pos + 1, ... pos + count - 1. |
113 | /// Rows that were at positions 0 to pos - 1 will stay where they are; |
114 | /// rows that were at positions pos to nColumns - 1 will be pushed to the |
115 | /// right. pos should be at most nRows. |
116 | void insertRows(unsigned pos, unsigned count); |
117 | void insertRow(unsigned pos); |
118 | |
119 | /// Remove the columns having positions pos, pos + 1, ... pos + count - 1. |
120 | /// Rows that were at positions 0 to pos - 1 will stay where they are; |
121 | /// columns that were at positions pos + count - 1 or later will be pushed to |
122 | /// the right. The columns to be deleted must be valid rows: pos + count - 1 |
123 | /// must be at most nColumns - 1. |
124 | void removeColumns(unsigned pos, unsigned count); |
125 | void removeColumn(unsigned pos); |
126 | |
127 | /// Remove the rows having positions pos, pos + 1, ... pos + count - 1. |
128 | /// Rows that were at positions 0 to pos - 1 will stay where they are; |
129 | /// rows that were at positions pos + count - 1 or later will be pushed to the |
130 | /// right. The rows to be deleted must be valid rows: pos + count - 1 must be |
131 | /// at most nRows - 1. |
132 | void removeRows(unsigned pos, unsigned count); |
133 | void removeRow(unsigned pos); |
134 | |
135 | void copyRow(unsigned sourceRow, unsigned targetRow); |
136 | |
137 | void fillRow(unsigned row, const T &value); |
138 | void fillRow(unsigned row, int64_t value) { fillRow(row, T(value)); } |
139 | |
140 | /// Add `scale` multiples of the source row to the target row. |
141 | void addToRow(unsigned sourceRow, unsigned targetRow, const T &scale); |
142 | void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) { |
143 | addToRow(sourceRow, targetRow, T(scale)); |
144 | } |
145 | /// Add `scale` multiples of the rowVec row to the specified row. |
146 | void addToRow(unsigned row, ArrayRef<T> rowVec, const T &scale); |
147 | |
148 | /// Multiply the specified row by a factor of `scale`. |
149 | void scaleRow(unsigned row, const T &scale); |
150 | |
151 | /// Add `scale` multiples of the source column to the target column. |
152 | void addToColumn(unsigned sourceColumn, unsigned targetColumn, |
153 | const T &scale); |
154 | void addToColumn(unsigned sourceColumn, unsigned targetColumn, |
155 | int64_t scale) { |
156 | addToColumn(sourceColumn, targetColumn, T(scale)); |
157 | } |
158 | |
159 | /// Negate the specified column. |
160 | void negateColumn(unsigned column); |
161 | |
162 | /// Negate the specified row. |
163 | void negateRow(unsigned row); |
164 | |
165 | /// Negate the entire matrix. |
166 | void negateMatrix(); |
167 | |
168 | /// The given vector is interpreted as a row vector v. Post-multiply v with |
169 | /// this matrix, say M, and return vM. |
170 | SmallVector<T, 8> preMultiplyWithRow(ArrayRef<T> rowVec) const; |
171 | |
172 | /// The given vector is interpreted as a column vector v. Pre-multiply v with |
173 | /// this matrix, say M, and return Mv. |
174 | SmallVector<T, 8> postMultiplyWithColumn(ArrayRef<T> colVec) const; |
175 | |
176 | /// Resize the matrix to the specified dimensions. If a dimension is smaller, |
177 | /// the values are truncated; if it is bigger, the new values are initialized |
178 | /// to zero. |
179 | /// |
180 | /// Due to the representation of the matrix, resizing vertically (adding rows) |
181 | /// is less expensive than increasing the number of columns beyond |
182 | /// nReservedColumns. |
183 | void resize(unsigned newNRows, unsigned newNColumns); |
184 | void resizeHorizontally(unsigned newNColumns); |
185 | void resizeVertically(unsigned newNRows); |
186 | |
187 | /// Add an extra row at the bottom of the matrix and return its position. |
188 | unsigned (); |
189 | /// Same as above, but copy the given elements into the row. The length of |
190 | /// `elems` must be equal to the number of columns. |
191 | unsigned (ArrayRef<T> elems); |
192 | |
193 | // Transpose the matrix without modifying it. |
194 | Matrix<T> transpose() const; |
195 | |
196 | // Copy the cells in the intersection of |
197 | // the rows between `fromRows` and `toRows` and |
198 | // the columns between `fromColumns` and `toColumns`, both inclusive. |
199 | Matrix<T> getSubMatrix(unsigned fromRow, unsigned toRow, unsigned fromColumn, |
200 | unsigned toColumn) const; |
201 | |
202 | /// Split the rows of a matrix into two matrices according to which bits are |
203 | /// 1 and which are 0 in a given bitset. |
204 | /// |
205 | /// The first matrix returned has the rows corresponding to 1 and the second |
206 | /// corresponding to 2. |
207 | std::pair<Matrix<T>, Matrix<T>> splitByBitset(ArrayRef<int> indicator); |
208 | |
209 | /// Print the matrix. |
210 | void print(raw_ostream &os) const; |
211 | void dump() const; |
212 | |
213 | /// Return whether the Matrix is in a consistent state with all its |
214 | /// invariants satisfied. |
215 | bool hasConsistentState() const; |
216 | |
217 | /// Move the columns in the source range [srcPos, srcPos + num) to the |
218 | /// specified destination [dstPos, dstPos + num), while moving the columns |
219 | /// adjacent to the source range to the left/right of the shifted columns. |
220 | /// |
221 | /// When moving the source columns right (i.e. dstPos > srcPos), columns that |
222 | /// were at positions [0, srcPos) and [dstPos + num, nCols) will stay where |
223 | /// they are; columns that were at positions [srcPos, srcPos + num) will be |
224 | /// moved to [dstPos, dstPos + num); and columns that were at positions |
225 | /// [srcPos + num, dstPos + num) will be moved to [srcPos, dstPos). |
226 | /// Equivalently, the columns [srcPos + num, dstPos + num) are interchanged |
227 | /// with [srcPos, srcPos + num). |
228 | /// For example, if m = |0 1 2 3 4 5| then: |
229 | /// m.moveColumns(1, 3, 2) will result in m = |0 4 1 2 3 5|; or |
230 | /// m.moveColumns(1, 2, 4) will result in m = |0 3 4 5 1 2|. |
231 | /// |
232 | /// The left shift operation (i.e. dstPos < srcPos) works in a similar way. |
233 | void moveColumns(unsigned srcPos, unsigned num, unsigned dstPos); |
234 | |
235 | protected: |
236 | /// The current number of rows, columns, and reserved columns. The underlying |
237 | /// data vector is viewed as an nRows x nReservedColumns matrix, of which the |
238 | /// first nColumns columns are currently in use, and the remaining are |
239 | /// reserved columns filled with zeros. |
240 | unsigned nRows, nColumns, nReservedColumns; |
241 | |
242 | /// Stores the data. data.size() is equal to nRows * nReservedColumns. |
243 | /// data.capacity() / nReservedColumns is the number of reserved rows. |
244 | SmallVector<T, 16> data; |
245 | }; |
246 | |
247 | extern template class Matrix<MPInt>; |
248 | extern template class Matrix<Fraction>; |
249 | |
250 | // An inherited class for integer matrices, with no new data attributes. |
251 | // This is only used for the matrix-related methods which apply only |
252 | // to integers (hermite normal form computation and row normalisation). |
253 | class IntMatrix : public Matrix<MPInt> { |
254 | public: |
255 | IntMatrix(unsigned rows, unsigned columns, unsigned reservedRows = 0, |
256 | unsigned reservedColumns = 0) |
257 | : Matrix<MPInt>(rows, columns, reservedRows, reservedColumns){}; |
258 | |
259 | IntMatrix(Matrix<MPInt> m) : Matrix<MPInt>(std::move(m)){}; |
260 | |
261 | /// Return the identity matrix of the specified dimension. |
262 | static IntMatrix identity(unsigned dimension); |
263 | |
264 | /// Given the current matrix M, returns the matrices H, U such that H is the |
265 | /// column hermite normal form of M, i.e. H = M * U, where U is unimodular and |
266 | /// the matrix H has the following restrictions: |
267 | /// - H is lower triangular. |
268 | /// - The leading coefficient (the first non-zero entry from the top, called |
269 | /// the pivot) of a non-zero column is always strictly below of the leading |
270 | /// coefficient of the column before it; moreover, it is positive. |
271 | /// - The elements to the right of the pivots are zero and the elements to |
272 | /// the left of the pivots are nonnegative and strictly smaller than the |
273 | /// pivot. |
274 | std::pair<IntMatrix, IntMatrix> computeHermiteNormalForm() const; |
275 | |
276 | /// Divide the first `nCols` of the specified row by their GCD. |
277 | /// Returns the GCD of the first `nCols` of the specified row. |
278 | MPInt normalizeRow(unsigned row, unsigned nCols); |
279 | /// Divide the columns of the specified row by their GCD. |
280 | /// Returns the GCD of the columns of the specified row. |
281 | MPInt normalizeRow(unsigned row); |
282 | |
283 | // Compute the determinant of the matrix (cubic time). |
284 | // Stores the integer inverse of the matrix in the pointer |
285 | // passed (if any). The pointer is unchanged if the inverse |
286 | // does not exist, which happens iff det = 0. |
287 | // For a matrix M, the integer inverse is the matrix M' such that |
288 | // M x M' = M' M = det(M) x I. |
289 | // Assert-fails if the matrix is not square. |
290 | MPInt determinant(IntMatrix *inverse = nullptr) const; |
291 | }; |
292 | |
293 | // An inherited class for rational matrices, with no new data attributes. |
294 | // This class is for functionality that only applies to matrices of fractions. |
295 | class FracMatrix : public Matrix<Fraction> { |
296 | public: |
297 | FracMatrix(unsigned rows, unsigned columns, unsigned reservedRows = 0, |
298 | unsigned reservedColumns = 0) |
299 | : Matrix<Fraction>(rows, columns, reservedRows, reservedColumns){}; |
300 | |
301 | FracMatrix(Matrix<Fraction> m) : Matrix<Fraction>(std::move(m)){}; |
302 | |
303 | explicit FracMatrix(IntMatrix m); |
304 | |
305 | /// Return the identity matrix of the specified dimension. |
306 | static FracMatrix identity(unsigned dimension); |
307 | |
308 | // Compute the determinant of the matrix (cubic time). |
309 | // Stores the inverse of the matrix in the pointer |
310 | // passed (if any). The pointer is unchanged if the inverse |
311 | // does not exist, which happens iff det = 0. |
312 | // Assert-fails if the matrix is not square. |
313 | Fraction determinant(FracMatrix *inverse = nullptr) const; |
314 | |
315 | // Computes the Gram-Schmidt orthogonalisation |
316 | // of the rows of matrix (cubic time). |
317 | // The rows of the matrix must be linearly independent. |
318 | FracMatrix gramSchmidt() const; |
319 | |
320 | // Run LLL basis reduction on the matrix, modifying it in-place. |
321 | // The parameter is what [the original |
322 | // paper](https://www.cs.cmu.edu/~avrim/451f11/lectures/lect1129_LLL.pdf) |
323 | // calls `y`, usually 3/4. |
324 | void LLL(Fraction delta); |
325 | |
326 | // Multiply each row of the matrix by the LCM of the denominators, thereby |
327 | // converting it to an integer matrix. |
328 | IntMatrix normalizeRows() const; |
329 | }; |
330 | |
331 | } // namespace presburger |
332 | } // namespace mlir |
333 | |
334 | #endif // MLIR_ANALYSIS_PRESBURGER_MATRIX_H |
335 | |