1 | //===- Polynomial.cpp - MLIR storage type for static Polynomial -*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"

#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"
17 | |
18 | namespace mlir { |
19 | namespace polynomial { |
20 | |
21 | FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) { |
22 | // A polynomial's terms are canonically stored in order of increasing degree. |
23 | auto monomialsCopy = llvm::SmallVector<Monomial>(monomials); |
24 | std::sort(first: monomialsCopy.begin(), last: monomialsCopy.end()); |
25 | |
26 | // Ensure non-unique exponents are not present. Since we sorted the list by |
27 | // exponent, a linear scan of adjancent monomials suffices. |
28 | if (std::adjacent_find(first: monomialsCopy.begin(), last: monomialsCopy.end(), |
29 | binary_pred: [](const Monomial &lhs, const Monomial &rhs) { |
30 | return lhs.exponent == rhs.exponent; |
31 | }) != monomialsCopy.end()) { |
32 | return failure(); |
33 | } |
34 | |
35 | return Polynomial(monomialsCopy); |
36 | } |
37 | |
38 | Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) { |
39 | llvm::SmallVector<Monomial> monomials; |
40 | auto size = coeffs.size(); |
41 | monomials.reserve(N: size); |
42 | for (size_t i = 0; i < size; i++) { |
43 | monomials.emplace_back(Args: coeffs[i], Args&: i); |
44 | } |
45 | auto result = Polynomial::fromMonomials(monomials); |
46 | // Construction guarantees unique exponents, so the failure mode of |
47 | // fromMonomials can be bypassed. |
48 | assert(succeeded(result)); |
49 | return result.value(); |
50 | } |
51 | |
52 | void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator, |
53 | ::llvm::StringRef exponentiation) const { |
54 | bool first = true; |
55 | for (const Monomial &term : terms) { |
56 | if (first) { |
57 | first = false; |
58 | } else { |
59 | os << separator; |
60 | } |
61 | std::string coeffToPrint; |
62 | if (term.coefficient == 1 && term.exponent.uge(RHS: 1)) { |
63 | coeffToPrint = "" ; |
64 | } else { |
65 | llvm::SmallString<16> coeffString; |
66 | term.coefficient.toStringSigned(Str&: coeffString); |
67 | coeffToPrint = coeffString.str(); |
68 | } |
69 | |
70 | if (term.exponent == 0) { |
71 | os << coeffToPrint; |
72 | } else if (term.exponent == 1) { |
73 | os << coeffToPrint << "x" ; |
74 | } else { |
75 | llvm::SmallString<16> expString; |
76 | term.exponent.toStringSigned(Str&: expString); |
77 | os << coeffToPrint << "x" << exponentiation << expString; |
78 | } |
79 | } |
80 | } |
81 | |
82 | void Polynomial::print(raw_ostream &os) const { print(os, separator: " + " , exponentiation: "**" ); } |
83 | |
84 | std::string Polynomial::toIdentifier() const { |
85 | std::string result; |
86 | llvm::raw_string_ostream os(result); |
87 | print(os, separator: "_" , exponentiation: "" ); |
88 | return os.str(); |
89 | } |
90 | |
91 | unsigned Polynomial::getDegree() const { |
92 | return terms.back().exponent.getZExtValue(); |
93 | } |
94 | |
95 | } // namespace polynomial |
96 | } // namespace mlir |
97 | |