1 | //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" |
9 | |
10 | #include "mlir/Dialect/Polynomial/IR/Polynomial.h" |
11 | #include "mlir/Support/LLVM.h" |
12 | #include "mlir/Support/LogicalResult.h" |
13 | #include "llvm/ADT/StringExtras.h" |
14 | #include "llvm/ADT/StringRef.h" |
15 | #include "llvm/ADT/StringSet.h" |
16 | |
17 | namespace mlir { |
18 | namespace polynomial { |
19 | |
20 | void PolynomialAttr::print(AsmPrinter &p) const { |
21 | p << '<'; |
22 | p << getPolynomial(); |
23 | p << '>'; |
24 | } |
25 | |
26 | /// Try to parse a monomial. If successful, populate the fields of the outparam |
27 | /// `monomial` with the results, and the `variable` outparam with the parsed |
28 | /// variable name. Sets shouldParseMore to true if the monomial is followed by |
29 | /// a '+'. |
30 | ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, |
31 | llvm::StringRef &variable, bool &isConstantTerm, |
32 | bool &shouldParseMore) { |
33 | APInt parsedCoeff(apintBitWidth, 1); |
34 | auto parsedCoeffResult = parser.parseOptionalInteger(result&: parsedCoeff); |
35 | monomial.coefficient = parsedCoeff; |
36 | |
37 | isConstantTerm = false; |
38 | shouldParseMore = false; |
39 | |
40 | // A + indicates it's a constant term with more to go, as in `1 + x`. |
41 | if (succeeded(result: parser.parseOptionalPlus())) { |
42 | // If no coefficient was parsed, and there's a +, then it's effectively |
43 | // parsing an empty string. |
44 | if (!parsedCoeffResult.has_value()) { |
45 | return failure(); |
46 | } |
47 | monomial.exponent = APInt(apintBitWidth, 0); |
48 | isConstantTerm = true; |
49 | shouldParseMore = true; |
50 | return success(); |
51 | } |
52 | |
53 | // A monomial can be a trailing constant term, as in `x + 1`. |
54 | if (failed(result: parser.parseOptionalKeyword(keyword: &variable))) { |
55 | // If neither a coefficient nor a variable was found, then it's effectively |
56 | // parsing an empty string. |
57 | if (!parsedCoeffResult.has_value()) { |
58 | return failure(); |
59 | } |
60 | |
61 | monomial.exponent = APInt(apintBitWidth, 0); |
62 | isConstantTerm = true; |
63 | return success(); |
64 | } |
65 | |
66 | // Parse exponentiation symbol as `**`. We can't use caret because it's |
67 | // reserved for basic block identifiers If no star is present, it's treated |
68 | // as a polynomial with exponent 1. |
69 | if (succeeded(result: parser.parseOptionalStar())) { |
70 | // If there's one * there must be two. |
71 | if (failed(result: parser.parseStar())) { |
72 | return failure(); |
73 | } |
74 | |
75 | // If there's a **, then the integer exponent is required. |
76 | APInt parsedExponent(apintBitWidth, 0); |
77 | if (failed(result: parser.parseInteger(result&: parsedExponent))) { |
78 | parser.emitError(loc: parser.getCurrentLocation(), |
79 | message: "found invalid integer exponent" ); |
80 | return failure(); |
81 | } |
82 | |
83 | monomial.exponent = parsedExponent; |
84 | } else { |
85 | monomial.exponent = APInt(apintBitWidth, 1); |
86 | } |
87 | |
88 | if (succeeded(result: parser.parseOptionalPlus())) { |
89 | shouldParseMore = true; |
90 | } |
91 | return success(); |
92 | } |
93 | |
94 | Attribute PolynomialAttr::parse(AsmParser &parser, Type type) { |
95 | if (failed(result: parser.parseLess())) |
96 | return {}; |
97 | |
98 | llvm::SmallVector<Monomial> monomials; |
99 | llvm::StringSet<> variables; |
100 | |
101 | while (true) { |
102 | Monomial parsedMonomial; |
103 | llvm::StringRef parsedVariableRef; |
104 | bool isConstantTerm; |
105 | bool shouldParseMore; |
106 | if (failed(result: parseMonomial(parser, monomial&: parsedMonomial, variable&: parsedVariableRef, |
107 | isConstantTerm, shouldParseMore))) { |
108 | parser.emitError(loc: parser.getCurrentLocation(), message: "expected a monomial" ); |
109 | return {}; |
110 | } |
111 | |
112 | if (!isConstantTerm) { |
113 | std::string parsedVariable = parsedVariableRef.str(); |
114 | variables.insert(key: parsedVariable); |
115 | } |
116 | monomials.push_back(Elt: parsedMonomial); |
117 | |
118 | if (shouldParseMore) |
119 | continue; |
120 | |
121 | if (succeeded(result: parser.parseOptionalGreater())) { |
122 | break; |
123 | } |
124 | parser.emitError( |
125 | loc: parser.getCurrentLocation(), |
126 | message: "expected + and more monomials, or > to end polynomial attribute" ); |
127 | return {}; |
128 | } |
129 | |
130 | if (variables.size() > 1) { |
131 | std::string vars = llvm::join(R: variables.keys(), Separator: ", " ); |
132 | parser.emitError( |
133 | loc: parser.getCurrentLocation(), |
134 | message: "polynomials must have one indeterminate, but there were multiple: " + |
135 | vars); |
136 | } |
137 | |
138 | auto result = Polynomial::fromMonomials(monomials); |
139 | if (failed(result)) { |
140 | parser.emitError(loc: parser.getCurrentLocation()) |
141 | << "parsed polynomial must have unique exponents among monomials" ; |
142 | return {}; |
143 | } |
144 | return PolynomialAttr::get(parser.getContext(), result.value()); |
145 | } |
146 | |
147 | void RingAttr::print(AsmPrinter &p) const { |
148 | p << "#polynomial.ring<coefficientType=" << getCoefficientType() |
149 | << ", coefficientModulus=" << getCoefficientModulus() |
150 | << ", polynomialModulus=" << getPolynomialModulus() << '>'; |
151 | } |
152 | |
153 | Attribute RingAttr::parse(AsmParser &parser, Type type) { |
154 | if (failed(parser.parseLess())) |
155 | return {}; |
156 | |
157 | if (failed(parser.parseKeyword("coefficientType" ))) |
158 | return {}; |
159 | |
160 | if (failed(parser.parseEqual())) |
161 | return {}; |
162 | |
163 | Type ty; |
164 | if (failed(parser.parseType(ty))) |
165 | return {}; |
166 | |
167 | if (failed(parser.parseComma())) |
168 | return {}; |
169 | |
170 | IntegerAttr coefficientModulusAttr = nullptr; |
171 | if (succeeded(parser.parseKeyword("coefficientModulus" ))) { |
172 | if (failed(parser.parseEqual())) |
173 | return {}; |
174 | |
175 | IntegerType iType = ty.dyn_cast<IntegerType>(); |
176 | if (!iType) { |
177 | parser.emitError(parser.getCurrentLocation(), |
178 | "coefficientType must specify an integer type" ); |
179 | return {}; |
180 | } |
181 | APInt coefficientModulus(iType.getWidth(), 0); |
182 | auto result = parser.parseInteger(coefficientModulus); |
183 | if (failed(result)) { |
184 | parser.emitError(parser.getCurrentLocation(), |
185 | "invalid coefficient modulus" ); |
186 | return {}; |
187 | } |
188 | coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus); |
189 | |
190 | if (failed(parser.parseComma())) |
191 | return {}; |
192 | } |
193 | |
194 | PolynomialAttr polyAttr = nullptr; |
195 | if (succeeded(parser.parseKeyword("polynomialModulus" ))) { |
196 | if (failed(parser.parseEqual())) |
197 | return {}; |
198 | |
199 | PolynomialAttr attr; |
200 | if (failed(parser.parseAttribute<PolynomialAttr>(attr))) |
201 | return {}; |
202 | polyAttr = attr; |
203 | } |
204 | |
205 | if (failed(parser.parseGreater())) |
206 | return {}; |
207 | |
208 | return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr, |
209 | polyAttr); |
210 | } |
211 | |
212 | } // namespace polynomial |
213 | } // namespace mlir |
214 | |