1//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
9
10#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
11#include "mlir/Support/LLVM.h"
12#include "mlir/Support/LogicalResult.h"
13#include "llvm/ADT/StringExtras.h"
14#include "llvm/ADT/StringRef.h"
15#include "llvm/ADT/StringSet.h"
16
17namespace mlir {
18namespace polynomial {
19
20void PolynomialAttr::print(AsmPrinter &p) const {
21 p << '<';
22 p << getPolynomial();
23 p << '>';
24}
25
26/// Try to parse a monomial. If successful, populate the fields of the outparam
27/// `monomial` with the results, and the `variable` outparam with the parsed
28/// variable name. Sets shouldParseMore to true if the monomial is followed by
29/// a '+'.
30ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
31 llvm::StringRef &variable, bool &isConstantTerm,
32 bool &shouldParseMore) {
33 APInt parsedCoeff(apintBitWidth, 1);
34 auto parsedCoeffResult = parser.parseOptionalInteger(result&: parsedCoeff);
35 monomial.coefficient = parsedCoeff;
36
37 isConstantTerm = false;
38 shouldParseMore = false;
39
40 // A + indicates it's a constant term with more to go, as in `1 + x`.
41 if (succeeded(result: parser.parseOptionalPlus())) {
42 // If no coefficient was parsed, and there's a +, then it's effectively
43 // parsing an empty string.
44 if (!parsedCoeffResult.has_value()) {
45 return failure();
46 }
47 monomial.exponent = APInt(apintBitWidth, 0);
48 isConstantTerm = true;
49 shouldParseMore = true;
50 return success();
51 }
52
53 // A monomial can be a trailing constant term, as in `x + 1`.
54 if (failed(result: parser.parseOptionalKeyword(keyword: &variable))) {
55 // If neither a coefficient nor a variable was found, then it's effectively
56 // parsing an empty string.
57 if (!parsedCoeffResult.has_value()) {
58 return failure();
59 }
60
61 monomial.exponent = APInt(apintBitWidth, 0);
62 isConstantTerm = true;
63 return success();
64 }
65
66 // Parse exponentiation symbol as `**`. We can't use caret because it's
67 // reserved for basic block identifiers If no star is present, it's treated
68 // as a polynomial with exponent 1.
69 if (succeeded(result: parser.parseOptionalStar())) {
70 // If there's one * there must be two.
71 if (failed(result: parser.parseStar())) {
72 return failure();
73 }
74
75 // If there's a **, then the integer exponent is required.
76 APInt parsedExponent(apintBitWidth, 0);
77 if (failed(result: parser.parseInteger(result&: parsedExponent))) {
78 parser.emitError(loc: parser.getCurrentLocation(),
79 message: "found invalid integer exponent");
80 return failure();
81 }
82
83 monomial.exponent = parsedExponent;
84 } else {
85 monomial.exponent = APInt(apintBitWidth, 1);
86 }
87
88 if (succeeded(result: parser.parseOptionalPlus())) {
89 shouldParseMore = true;
90 }
91 return success();
92}
93
94Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
95 if (failed(result: parser.parseLess()))
96 return {};
97
98 llvm::SmallVector<Monomial> monomials;
99 llvm::StringSet<> variables;
100
101 while (true) {
102 Monomial parsedMonomial;
103 llvm::StringRef parsedVariableRef;
104 bool isConstantTerm;
105 bool shouldParseMore;
106 if (failed(result: parseMonomial(parser, monomial&: parsedMonomial, variable&: parsedVariableRef,
107 isConstantTerm, shouldParseMore))) {
108 parser.emitError(loc: parser.getCurrentLocation(), message: "expected a monomial");
109 return {};
110 }
111
112 if (!isConstantTerm) {
113 std::string parsedVariable = parsedVariableRef.str();
114 variables.insert(key: parsedVariable);
115 }
116 monomials.push_back(Elt: parsedMonomial);
117
118 if (shouldParseMore)
119 continue;
120
121 if (succeeded(result: parser.parseOptionalGreater())) {
122 break;
123 }
124 parser.emitError(
125 loc: parser.getCurrentLocation(),
126 message: "expected + and more monomials, or > to end polynomial attribute");
127 return {};
128 }
129
130 if (variables.size() > 1) {
131 std::string vars = llvm::join(R: variables.keys(), Separator: ", ");
132 parser.emitError(
133 loc: parser.getCurrentLocation(),
134 message: "polynomials must have one indeterminate, but there were multiple: " +
135 vars);
136 }
137
138 auto result = Polynomial::fromMonomials(monomials);
139 if (failed(result)) {
140 parser.emitError(loc: parser.getCurrentLocation())
141 << "parsed polynomial must have unique exponents among monomials";
142 return {};
143 }
144 return PolynomialAttr::get(parser.getContext(), result.value());
145}
146
147void RingAttr::print(AsmPrinter &p) const {
148 p << "#polynomial.ring<coefficientType=" << getCoefficientType()
149 << ", coefficientModulus=" << getCoefficientModulus()
150 << ", polynomialModulus=" << getPolynomialModulus() << '>';
151}
152
153Attribute RingAttr::parse(AsmParser &parser, Type type) {
154 if (failed(parser.parseLess()))
155 return {};
156
157 if (failed(parser.parseKeyword("coefficientType")))
158 return {};
159
160 if (failed(parser.parseEqual()))
161 return {};
162
163 Type ty;
164 if (failed(parser.parseType(ty)))
165 return {};
166
167 if (failed(parser.parseComma()))
168 return {};
169
170 IntegerAttr coefficientModulusAttr = nullptr;
171 if (succeeded(parser.parseKeyword("coefficientModulus"))) {
172 if (failed(parser.parseEqual()))
173 return {};
174
175 IntegerType iType = ty.dyn_cast<IntegerType>();
176 if (!iType) {
177 parser.emitError(parser.getCurrentLocation(),
178 "coefficientType must specify an integer type");
179 return {};
180 }
181 APInt coefficientModulus(iType.getWidth(), 0);
182 auto result = parser.parseInteger(coefficientModulus);
183 if (failed(result)) {
184 parser.emitError(parser.getCurrentLocation(),
185 "invalid coefficient modulus");
186 return {};
187 }
188 coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
189
190 if (failed(parser.parseComma()))
191 return {};
192 }
193
194 PolynomialAttr polyAttr = nullptr;
195 if (succeeded(parser.parseKeyword("polynomialModulus"))) {
196 if (failed(parser.parseEqual()))
197 return {};
198
199 PolynomialAttr attr;
200 if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
201 return {};
202 polyAttr = attr;
203 }
204
205 if (failed(parser.parseGreater()))
206 return {};
207
208 return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
209 polyAttr);
210}
211
212} // namespace polynomial
213} // namespace mlir
214

source code of mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp