1//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpDefinitionsGen uses the description of operations to generate C++
10// definitions for ops.
11//
12//===----------------------------------------------------------------------===//
13
#include "CppGenUtilities.h"
#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <type_traits>
42
43#define DEBUG_TYPE "mlir-tblgen-opdefgen"
44
45using namespace llvm;
46using namespace mlir;
47using namespace mlir::tblgen;
48
// Fixed spellings of prefixes and parameter/variable names used by the
// generator; their uses live elsewhere in this file.
static constexpr const char *tblgenNamePrefix = "tblgen_";
static constexpr const char *generatedArgName = "odsArg";
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
static constexpr const char *propertyStorage = "propStorage";
static constexpr const char *propertyValue = "propValue";
static constexpr const char *propertyAttr = "propAttr";
static constexpr const char *propertyDiag = "emitError";

/// Names of the implicit attributes holding the sizes of the variadic operand
/// and result segments.
static constexpr const char *operandSegmentAttrName = "operandSegmentSizes";
static constexpr const char *resultSegmentAttrName = "resultSegmentSizes";
62
/// Code for an Op to look up an attribute through cached identifiers, only
/// searching the sorted subrange that can contain it.
///
/// {0}: Code snippet producing the attribute's name or identifier.
/// {1}: Number of entries skipped at the front of the sorted range.
/// {2}: Number of entries skipped at the back of the sorted range.
/// {3}: Code snippet producing the array of named attributes.
/// {4}: "Named" to fetch the named attribute, empty for the bare attribute.
static constexpr const char *subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";

/// Value-range computation for a declared operand/result of an op whose
/// variadic operands/results all share the same number of values. Not for
/// general use: it relies on that equal-size assumption.
///
/// {0}: Comma-separated flags telling whether each declared value is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static constexpr const char *sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";
102
/// Value-range computation for a declared operand/result of an op with
/// variadic operands/results. This logic assumes the op has an attribute
/// giving the size of every operand/result segment (variadic or not), bound to
/// a local named `sizeAttr`.
static constexpr const char *attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttr[i];
  return {start, sizeAttr[index]};
)";

/// Adaptor-side snippet declaring `sizeAttr` for the value range calculation,
/// asserting that the segment size attribute is present.
///
/// {0}: The code to get the attribute.
static constexpr const char *adapterSegmentSizeAttrInitCode = R"(
  assert({0} && "missing segment size attribute for op");
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";

/// Adaptor-side variant declaring `sizeAttr` as a plain `ArrayRef<int32_t>`
/// (used when the sizes come from properties).
///
/// {0}: The code to get the sizes.
static constexpr const char *adapterSegmentSizeAttrInitCodeProperties = R"(
  ::llvm::ArrayRef<int32_t> sizeAttr = {0};
)";

/// Op-side snippet declaring `sizeAttr` for the value range calculation.
///
/// {0}: The code to get the attribute.
static constexpr const char *opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
130
/// Adaptor-side computation of the operand groups of a variadic-of-variadic
/// operand: the flat operand segment is split front-to-back according to the
/// per-group sizes.
///
/// {0}: The name of the segment attribute (called as a getter).
/// {1}: The index of the main operand.
/// {2}: The range type of the adaptor.
static constexpr const char *variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizes = {0}();

  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
  for (int i = 0, e = sizes.size(); i < e; ++i) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
  }
  return tblgenTmpOperandGroups;
)";

/// Builds a range of operand or result values from a (start, length) pair.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call producing the start and length of the value range.
static constexpr const char *valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
158
/// Parser body for the operand/result segment-size property: a square-bracketed,
/// comma-separated list of exactly {0} integers stored into `$_storage`.
///
/// The text is passed through llvm::formatv (hence `{0}`) and then through
/// tgfmt (hence the `$_` placeholders).
///
/// {0}: Number of elements in the segment array.
///
/// The generated code is pasted into the op's namespace, which may not be
/// `mlir`. Argument-less `failure()`/`success()` cannot be found through ADL
/// there, so every MLIR helper is fully qualified. The delimiter enum is
/// `AsmParser::Delimiter` (the previous "Delimeter" spelling did not compile).
static const char *const parseTextualSegmentSizeFormat = R"(
  size_t i = 0;
  auto parseElem = [&]() -> ::mlir::ParseResult {
    if (i >= {0})
      return $_parser.emitError($_parser.getCurrentLocation(),
        "expected `]` after {0} segment sizes");
    if (::mlir::failed($_parser.parseInteger($_storage[i])))
      return ::mlir::failure();
    i += 1;
    return ::mlir::success();
  };
  if (::mlir::failed($_parser.parseCommaSeparatedList(
      ::mlir::AsmParser::Delimiter::Square, parseElem)))
    return ::mlir::failure();
  if (i < {0})
    return $_parser.emitError($_parser.getCurrentLocation(),
      "expected {0} segment sizes, found only ") << i;
  return ::mlir::success();
)";
180
/// Printer expression for operand/result segment sizes: the stored sizes as a
/// square-bracketed, comma-separated list.
static constexpr const char *printTextualSegmentSize = R"(
  [&]() {
    $_printer << '[';
    ::llvm::interleaveComma($_storage, $_printer);
    $_printer << ']';
  }()
)";

/// Reads operand/result segment sizes from bytecode version 6 or newer, which
/// encodes them natively as a sparse array.
static constexpr const char *readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Reads operand/result segment sizes from bytecode older than version 6,
/// which stores them as a DenseI32ArrayAttr. Arrays with more elements than
/// the storage can hold are rejected with an error.
static constexpr const char *readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";

/// Writes operand/result segment sizes to bytecode version 6 or newer using
/// the native sparse-array encoding.
static constexpr const char *writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Writes operand/result segment sizes to bytecode older than version 6 as a
/// DenseI32ArrayAttr (the legacy encoding).
static constexpr const char *writeBytecodeSegmentSizeLegacy = R"(
if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
  auto &$_storage = prop.$_propName;
  $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
}
)";
221
/// Banner comment emitted ahead of each generated code section.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static constexpr const char *opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
232
233//===----------------------------------------------------------------------===//
234// Utility structs and functions
235//===----------------------------------------------------------------------===//
236
/// Replaces every non-overlapping occurrence of `match` in `str` with
/// `substitute`, scanning left to right. Text produced by a substitution is
/// never rescanned, so a `substitute` that contains `match` cannot cause
/// runaway replacement.
///
/// \param str The input string (taken by value; the result is built in it).
/// \param match The pattern to replace. An empty pattern would match at every
///        position and previously looped forever; it now leaves `str`
///        unchanged.
/// \param substitute The replacement text.
/// \returns The rewritten string.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;
  for (std::string::size_type matchLoc = str.find(match);
       matchLoc != std::string::npos;
       matchLoc = str.find(match, matchLoc + substitute.size()))
    str.replace(matchLoc, match.size(), substitute);
  return str;
}
247
248// Returns whether the record has a value of the given name that can be returned
249// via getValueAsString.
250static inline bool hasStringAttribute(const Record &record,
251 StringRef fieldName) {
252 auto *valueInit = record.getValueInit(FieldName: fieldName);
253 return isa<StringInit>(Val: valueInit);
254}
255
256static std::string getArgumentName(const Operator &op, int index) {
257 const auto &operand = op.getOperand(index);
258 if (!operand.name.empty())
259 return std::string(operand.name);
260 return std::string(formatv(Fmt: "{0}_{1}", Vals: generatedArgName, Vals&: index));
261}
262
263// Returns true if we can use unwrapped value for the given `attr` in builders.
264static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
265 return attr.getReturnType() != attr.getStorageType() &&
266 // We need to wrap the raw value into an attribute in the builder impl
267 // so we need to make sure that the attribute specifies how to do that.
268 !attr.getConstBuilderTemplate().empty();
269}
270
271/// Build an attribute from a parameter value using the constant builder.
272static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
273 FmtContext &fctx,
274 StringRef paramName) {
275 std::string builderTemplate = attr.getConstBuilderTemplate().str();
276
277 // For StringAttr, its constant builder call will wrap the input in
278 // quotes, which is correct for normal string literals, but incorrect
279 // here given we use function arguments. So we need to strip the
280 // wrapping quotes.
281 if (StringRef(builderTemplate).contains(Other: "\"$0\""))
282 builderTemplate = replaceAllSubstrs(str: builderTemplate, match: "\"$0\"", substitute: "$0");
283
284 return tgfmt(fmt: builderTemplate, ctx: &fctx, vals&: paramName).str();
285}
286
287namespace {
288/// Metadata on a registered attribute. Given that attributes are stored in
289/// sorted order on operations, we can use information from ODS to deduce the
290/// number of required attributes less and and greater than each attribute,
291/// allowing us to search only a subrange of the attributes in ODS-generated
292/// getters.
293struct AttributeMetadata {
294 /// The attribute name.
295 StringRef attrName;
296 /// Whether the attribute is required.
297 bool isRequired;
298 /// The ODS attribute constraint. Not present for implicit attributes.
299 std::optional<Attribute> constraint;
300 /// The number of required attributes less than this attribute.
301 unsigned lowerBound = 0;
302 /// The number of required attributes greater than this attribute.
303 unsigned upperBound = 0;
304};
305
306/// Helper class to select between OpAdaptor and Op code templates.
307class OpOrAdaptorHelper {
308public:
309 OpOrAdaptorHelper(const Operator &op, bool emitForOp)
310 : op(op), emitForOp(emitForOp) {
311 computeAttrMetadata();
312 }
313
314 /// Object that wraps a functor in a stream operator for interop with
315 /// llvm::formatv.
316 class Formatter {
317 public:
318 template <typename Functor>
319 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
320
321 std::string str() const {
322 std::string result;
323 llvm::raw_string_ostream os(result);
324 os << *this;
325 return os.str();
326 }
327
328 private:
329 std::function<raw_ostream &(raw_ostream &)> func;
330
331 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
332 return fmt.func(os);
333 }
334 };
335
336 // Generate code for getting an attribute.
337 Formatter getAttr(StringRef attrName, bool isNamed = false) const {
338 assert(attrMetadata.count(attrName) && "expected attribute metadata");
339 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
340 const AttributeMetadata &attr = attrMetadata.find(Key: attrName)->second;
341 if (hasProperties()) {
342 assert(!isNamed);
343 return os << "getProperties()." << attrName;
344 }
345 return os << formatv(Fmt: subrangeGetAttr, Vals: getAttrName(attrName),
346 Vals: attr.lowerBound, Vals: attr.upperBound, Vals: getAttrRange(),
347 Vals: isNamed ? "Named" : "");
348 };
349 }
350
351 // Generate code for getting the name of an attribute.
352 Formatter getAttrName(StringRef attrName) const {
353 return [this, attrName](raw_ostream &os) -> raw_ostream & {
354 if (emitForOp)
355 return os << op.getGetterName(name: attrName) << "AttrName()";
356 return os << formatv(Fmt: "{0}::{1}AttrName(*odsOpName)", Vals: op.getCppClassName(),
357 Vals: op.getGetterName(name: attrName));
358 };
359 }
360
361 // Get the code snippet for getting the named attribute range.
362 StringRef getAttrRange() const {
363 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
364 }
365
366 // Get the prefix code for emitting an error.
367 Formatter emitErrorPrefix() const {
368 return [this](raw_ostream &os) -> raw_ostream & {
369 if (emitForOp)
370 return os << "emitOpError(";
371 return os << formatv(Fmt: "emitError(loc, \"'{0}' op \"",
372 Vals: op.getOperationName());
373 };
374 }
375
376 // Get the call to get an operand or segment of operands.
377 Formatter getOperand(unsigned index) const {
378 return [this, index](raw_ostream &os) -> raw_ostream & {
379 return os << formatv(Fmt: op.getOperand(index).isVariadic()
380 ? "this->getODSOperands({0})"
381 : "(*this->getODSOperands({0}).begin())",
382 Vals: index);
383 };
384 }
385
386 // Get the call to get a result of segment of results.
387 Formatter getResult(unsigned index) const {
388 return [this, index](raw_ostream &os) -> raw_ostream & {
389 if (!emitForOp)
390 return os << "<no results should be generated>";
391 return os << formatv(Fmt: op.getResult(index).isVariadic()
392 ? "this->getODSResults({0})"
393 : "(*this->getODSResults({0}).begin())",
394 Vals: index);
395 };
396 }
397
398 // Return whether an op instance is available.
399 bool isEmittingForOp() const { return emitForOp; }
400
401 // Return the ODS operation wrapper.
402 const Operator &getOp() const { return op; }
403
404 // Get the attribute metadata sorted by name.
405 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
406 return attrMetadata;
407 }
408
409 /// Returns whether to emit a `Properties` struct for this operation or not.
410 bool hasProperties() const {
411 if (!op.getProperties().empty())
412 return true;
413 if (!op.getDialect().usePropertiesForAttributes())
414 return false;
415 return true;
416 }
417
418 /// Returns whether the operation will have a non-empty `Properties` struct.
419 bool hasNonEmptyPropertiesStruct() const {
420 if (!op.getProperties().empty())
421 return true;
422 if (!hasProperties())
423 return false;
424 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments") ||
425 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments"))
426 return true;
427 return llvm::any_of(Range: getAttrMetadata(),
428 P: [](const std::pair<StringRef, AttributeMetadata> &it) {
429 return !it.second.constraint ||
430 !it.second.constraint->isDerivedAttr();
431 });
432 }
433
434 std::optional<NamedProperty> &getOperandSegmentsSize() {
435 return operandSegmentsSize;
436 }
437
438 std::optional<NamedProperty> &getResultSegmentsSize() {
439 return resultSegmentsSize;
440 }
441
442 uint32_t getOperandSegmentSizesLegacyIndex() {
443 return operandSegmentSizesLegacyIndex;
444 }
445
446 uint32_t getResultSegmentSizesLegacyIndex() {
447 return resultSegmentSizesLegacyIndex;
448 }
449
450private:
451 // Compute the attribute metadata.
452 void computeAttrMetadata();
453
454 // The operation ODS wrapper.
455 const Operator &op;
456 // True if code is being generate for an op. False for an adaptor.
457 const bool emitForOp;
458
459 // The attribute metadata, mapped by name.
460 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
461
462 // Property
463 std::optional<NamedProperty> operandSegmentsSize;
464 std::string operandSegmentsSizeStorage;
465 std::string operandSegmentsSizeParser;
466 std::optional<NamedProperty> resultSegmentsSize;
467 std::string resultSegmentsSizeStorage;
468 std::string resultSegmentsSizeParser;
469
470 // Indices to store the position in the emission order of the operand/result
471 // segment sizes attribute if emitted as part of the properties for legacy
472 // bytecode encodings, i.e. versions less than 6.
473 uint32_t operandSegmentSizesLegacyIndex = 0;
474 uint32_t resultSegmentSizesLegacyIndex = 0;
475
476 // The number of required attributes.
477 unsigned numRequired;
478};
479
480} // namespace
481
482void OpOrAdaptorHelper::computeAttrMetadata() {
483 // Enumerate the attribute names of this op, ensuring the attribute names are
484 // unique in case implicit attributes are explicitly registered.
485 for (const NamedAttribute &namedAttr : op.getAttributes()) {
486 Attribute attr = namedAttr.attr;
487 bool isOptional =
488 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
489 attrMetadata.insert(
490 KV: {namedAttr.name, AttributeMetadata{.attrName: namedAttr.name, .isRequired: !isOptional, .constraint: attr}});
491 }
492
493 auto makeProperty = [&](StringRef storageType, StringRef parserCall) {
494 return Property(/*maybeDef=*/nullptr,
495 /*summary=*/"",
496 /*description=*/"",
497 /*storageType=*/storageType,
498 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
499 /*convertFromStorageCall=*/"$_storage",
500 /*assignToStorageCall=*/
501 "::llvm::copy($_value, $_storage.begin())",
502 /*convertToAttributeCall=*/
503 "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);",
504 /*convertFromAttributeCall=*/
505 "return convertFromAttribute($_storage, $_attr, $_diag);",
506 /*parserCall=*/parserCall,
507 /*optionalParserCall=*/"",
508 /*printerCall=*/printTextualSegmentSize,
509 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative,
510 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative,
511 /*hashPropertyCall=*/
512 "::llvm::hash_combine_range(std::begin($_storage), "
513 "std::end($_storage));",
514 /*StringRef defaultValue=*/"",
515 /*storageTypeValueOverride=*/"");
516 };
517 // Include key attributes from several traits as implicitly registered.
518 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
519 if (op.getDialect().usePropertiesForAttributes()) {
520 operandSegmentsSizeStorage =
521 llvm::formatv(Fmt: "std::array<int32_t, {0}>", Vals: op.getNumOperands());
522 operandSegmentsSizeParser =
523 llvm::formatv(Fmt: parseTextualSegmentSizeFormat, Vals: op.getNumOperands());
524 operandSegmentsSize = {
525 .name: "operandSegmentSizes",
526 .prop: makeProperty(operandSegmentsSizeStorage, operandSegmentsSizeParser)};
527 } else {
528 attrMetadata.insert(
529 KV: {operandSegmentAttrName, AttributeMetadata{.attrName: operandSegmentAttrName,
530 /*isRequired=*/true,
531 /*attr=*/.constraint: std::nullopt}});
532 }
533 }
534 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
535 if (op.getDialect().usePropertiesForAttributes()) {
536 resultSegmentsSizeStorage =
537 llvm::formatv(Fmt: "std::array<int32_t, {0}>", Vals: op.getNumResults());
538 resultSegmentsSizeParser =
539 llvm::formatv(Fmt: parseTextualSegmentSizeFormat, Vals: op.getNumResults());
540 resultSegmentsSize = {
541 .name: "resultSegmentSizes",
542 .prop: makeProperty(resultSegmentsSizeStorage, resultSegmentsSizeParser)};
543 } else {
544 attrMetadata.insert(
545 KV: {resultSegmentAttrName,
546 AttributeMetadata{.attrName: resultSegmentAttrName, /*isRequired=*/true,
547 /*attr=*/.constraint: std::nullopt}});
548 }
549 }
550
551 // Store the metadata in sorted order.
552 SmallVector<AttributeMetadata> sortedAttrMetadata =
553 llvm::to_vector(Range: llvm::make_second_range(c: attrMetadata.takeVector()));
554 llvm::sort(C&: sortedAttrMetadata,
555 Comp: [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
556 return lhs.attrName < rhs.attrName;
557 });
558
559 // Store the position of the legacy operand_segment_sizes /
560 // result_segment_sizes so we can emit a backward compatible property readers
561 // and writers.
562 StringRef legacyOperandSegmentSizeName =
563 StringLiteral("operand_segment_sizes");
564 StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes");
565 operandSegmentSizesLegacyIndex = 0;
566 resultSegmentSizesLegacyIndex = 0;
567 for (auto item : sortedAttrMetadata) {
568 if (item.attrName < legacyOperandSegmentSizeName)
569 ++operandSegmentSizesLegacyIndex;
570 if (item.attrName < legacyResultSegmentSizeName)
571 ++resultSegmentSizesLegacyIndex;
572 }
573
574 // Compute the subrange bounds for each attribute.
575 numRequired = 0;
576 for (AttributeMetadata &attr : sortedAttrMetadata) {
577 attr.lowerBound = numRequired;
578 numRequired += attr.isRequired;
579 };
580 for (AttributeMetadata &attr : sortedAttrMetadata)
581 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
582
583 // Store the results back into the map.
584 for (const AttributeMetadata &attr : sortedAttrMetadata)
585 attrMetadata.insert(KV: {attr.attrName, attr});
586}
587
588//===----------------------------------------------------------------------===//
589// Op emitter
590//===----------------------------------------------------------------------===//
591
592namespace {
593// Helper class to emit a record into the given output stream.
594class OpEmitter {
595 using ConstArgument =
596 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
597
598public:
599 static void
600 emitDecl(const Operator &op, raw_ostream &os,
601 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
602 static void
603 emitDef(const Operator &op, raw_ostream &os,
604 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
605
606private:
607 OpEmitter(const Operator &op,
608 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
609
610 void emitDecl(raw_ostream &os);
611 void emitDef(raw_ostream &os);
612
613 // Generate methods for accessing the attribute names of this operation.
614 void genAttrNameGetters();
615
616 // Generates the OpAsmOpInterface for this operation if possible.
617 void genOpAsmInterface();
618
619 // Generates the `getOperationName` method for this op.
620 void genOpNameGetter();
621
622 // Generates code to manage the properties, if any!
623 void genPropertiesSupport();
624
625 // Generates code to manage the encoding of properties to bytecode.
626 void
627 genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties);
628
629 // Generates getters for the properties.
630 void genPropGetters();
631
632 // Generates seters for the properties.
633 void genPropSetters();
634
635 // Generates getters for the attributes.
636 void genAttrGetters();
637
638 // Generates setter for the attributes.
639 void genAttrSetters();
640
641 // Generates removers for optional attributes.
642 void genOptionalAttrRemovers();
643
644 // Generates getters for named operands.
645 void genNamedOperandGetters();
646
647 // Generates setters for named operands.
648 void genNamedOperandSetters();
649
650 // Generates getters for named results.
651 void genNamedResultGetters();
652
653 // Generates getters for named regions.
654 void genNamedRegionGetters();
655
656 // Generates getters for named successors.
657 void genNamedSuccessorGetters();
658
659 // Generates the method to populate default attributes.
660 void genPopulateDefaultAttributes();
661
662 // Generates builder methods for the operation.
663 void genBuilder();
664
665 // Generates the build() method that takes each operand/attribute
666 // as a stand-alone parameter.
667 void genSeparateArgParamBuilder();
668
669 // Generates the build() method that takes each operand/attribute as a
670 // stand-alone parameter. The generated build() method uses first operand's
671 // type as all results' types.
672 void genUseOperandAsResultTypeSeparateParamBuilder();
673
674 // The kind of collective builder to generate
675 enum class CollectiveBuilderKind {
676 PropStruct, // Inherent attributes/properties are passed by `const
677 // Properties&`
678 AttrDict, // Inherent attributes/properties are passed by attribute
679 // dictionary
680 };
681
682 // Generates the build() method that takes all operands/attributes
683 // collectively as one parameter. The generated build() method uses first
684 // operand's type as all results' types.
685 void
686 genUseOperandAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
687
688 // Generates the build() method that takes aggregate operands/attributes
689 // parameters. This build() method uses inferred types as result types.
690 // Requires: The type needs to be inferable via InferTypeOpInterface.
691 void genInferredTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
692
693 // Generates the build() method that takesaggregate operands/attributes as
694 // parameters. The generated build() method uses first attribute's
695 // type as all result's types.
696 void genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
697
698 // Generates the build() method that takes all result types collectively as
699 // one parameter. Similarly for operands and attributes.
700 void genCollectiveParamBuilder(CollectiveBuilderKind kind);
701
702 // The kind of parameter to generate for result types in builders.
703 enum class TypeParamKind {
704 None, // No result type in parameter list.
705 Separate, // A separate parameter for each result type.
706 Collective, // An ArrayRef<Type> for all result types.
707 };
708
709 // The kind of parameter to generate for attributes in builders.
710 enum class AttrParamKind {
711 WrappedAttr, // A wrapped MLIR Attribute instance.
712 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
713 };
714
715 // Builds the parameter list for build() method of this op. This method writes
716 // to `paramList` the comma-separated parameter list and updates
717 // `resultTypeNames` with the names for parameters for specifying result
718 // types. `inferredAttributes` is populated with any attributes that are
719 // elided from the build list. The given `typeParamKind` and `attrParamKind`
720 // controls how result types and attributes are placed in the parameter list.
721 void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
722 llvm::StringSet<> &inferredAttributes,
723 SmallVectorImpl<std::string> &resultTypeNames,
724 TypeParamKind typeParamKind,
725 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
726
727 // Adds op arguments and regions into operation state for build() methods.
728 void
729 genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
730 llvm::StringSet<> &inferredAttributes,
731 bool isRawValueAttr = false);
732
733 // Generates canonicalizer declaration for the operation.
734 void genCanonicalizerDecls();
735
736 // Generates the folder declaration for the operation.
737 void genFolderDecls();
738
739 // Generates the parser for the operation.
740 void genParser();
741
742 // Generates the printer for the operation.
743 void genPrinter();
744
745 // Generates verify method for the operation.
746 void genVerifier();
747
748 // Generates custom verify methods for the operation.
749 void genCustomVerifier();
750
751 // Generates verify statements for operands and results in the operation.
752 // The generated code will be attached to `body`.
753 void genOperandResultVerifier(MethodBody &body,
754 Operator::const_value_range values,
755 StringRef valueKind);
756
757 // Generates verify statements for regions in the operation.
758 // The generated code will be attached to `body`.
759 void genRegionVerifier(MethodBody &body);
760
761 // Generates verify statements for successors in the operation.
762 // The generated code will be attached to `body`.
763 void genSuccessorVerifier(MethodBody &body);
764
765 // Generates the traits used by the object.
766 void genTraits();
767
768 // Generate the OpInterface methods for all interfaces.
769 void genOpInterfaceMethods();
770
771 // Generate op interface methods for the given interface.
772 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
773
774 // Generate op interface method for the given interface method. If
775 // 'declaration' is true, generates a declaration, else a definition.
776 Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
777 bool declaration = true);
778
779 // Generate the side effect interface methods.
780 void genSideEffectInterfaceMethods();
781
782 // Generate the type inference interface methods.
783 void genTypeInterfaceMethods();
784
785private:
786 // The TableGen record for this op.
787 // TODO: OpEmitter should not have a Record directly,
788 // it should rather go through the Operator for better abstraction.
789 const Record &def;
790
791 // The wrapper operator class for querying information from this op.
792 const Operator &op;
793
794 // The C++ code builder for this op
795 OpClass opClass;
796
797 // The format context for verification code generation.
798 FmtContext verifyCtx;
799
800 // The emitter containing all of the locally emitted verification functions.
801 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
802
803 // Helper for emitting op code.
804 OpOrAdaptorHelper emitHelper;
805};
806
807} // namespace
808
809// Populate the format context `ctx` with substitutions of attributes, operands
810// and results.
811static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
812 FmtContext &ctx) {
813 // Populate substitutions for attributes.
814 auto &op = emitHelper.getOp();
815 for (const auto &namedAttr : op.getAttributes())
816 ctx.addSubst(placeholder: namedAttr.name,
817 subst: emitHelper.getOp().getGetterName(name: namedAttr.name) + "()");
818
819 // Populate substitutions for named operands.
820 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
821 auto &value = op.getOperand(index: i);
822 if (!value.name.empty())
823 ctx.addSubst(placeholder: value.name, subst: emitHelper.getOperand(index: i).str());
824 }
825
826 // Populate substitutions for results.
827 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
828 auto &value = op.getResult(index: i);
829 if (!value.name.empty())
830 ctx.addSubst(placeholder: value.name, subst: emitHelper.getResult(index: i).str());
831 }
832}
833
834/// Generate verification on native traits requiring attributes.
835static void genNativeTraitAttrVerifier(MethodBody &body,
836 const OpOrAdaptorHelper &emitHelper) {
837 // Check that the variadic segment sizes attribute exists and contains the
838 // expected number of elements.
839 //
840 // {0}: Attribute name.
841 // {1}: Expected number of elements.
842 // {2}: "operand" or "result".
843 // {3}: Emit error prefix.
844 const char *const checkAttrSizedValueSegmentsCode = R"(
845 {
846 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
847 auto numElements = sizeAttr.asArrayRef().size();
848 if (numElements != {1})
849 return {3}"'{0}' attribute for specifying {2} segments must have {1} "
850 "elements, but got ") << numElements;
851 }
852 )";
853
854 // Verify a few traits first so that we can use getODSOperands() and
855 // getODSResults() in the rest of the verifier.
856 auto &op = emitHelper.getOp();
857 if (!op.getDialect().usePropertiesForAttributes()) {
858 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
859 body << formatv(Fmt: checkAttrSizedValueSegmentsCode, Vals: operandSegmentAttrName,
860 Vals: op.getNumOperands(), Vals: "operand",
861 Vals: emitHelper.emitErrorPrefix());
862 }
863 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
864 body << formatv(Fmt: checkAttrSizedValueSegmentsCode, Vals: resultSegmentAttrName,
865 Vals: op.getNumResults(), Vals: "result",
866 Vals: emitHelper.emitErrorPrefix());
867 }
868 }
869}
870
871// Return true if a verifier can be emitted for the attribute: it is not a
872// derived attribute, it has a predicate, its condition is not empty, and, for
873// adaptors, the condition does not reference the op.
874static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
875 if (attr.isDerivedAttr())
876 return false;
877 Pred pred = attr.getPredicate();
878 if (pred.isNull())
879 return false;
880 std::string condition = pred.getCondition();
881 return !condition.empty() &&
882 (!StringRef(condition).contains(Other: "$_op") || isEmittingForOp);
883}
884
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// Attribute verification is performed as follows:
//
// 1. Verify that all required attributes are present in sorted order. This
// ensures that we can use subrange lookup even with potentially missing
// attributes.
// 2. Verify native trait attributes so that other attributes may call methods
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
//
// Every generated local holding an attribute is named `tblgen_<attrName>`.
// With `useProperties`, step 1 loads each non-derived attribute from
// `getProperties()` and only checks that required ones are non-null.
// Otherwise step 1 scans the range returned by `emitHelper.getAttrRange()`
// with a single iterator. That scan assumes the range is sorted in the same
// order as `emitHelper.getAttrMetadata()`.
//
// Parameters:
//   emitHelper: the op/adaptor being emitted, its attribute metadata and the
//     code snippets used to access attributes.
//   ctx: format context used to expand inline attribute conditions.
//   body: method body the verifier code is appended to.
//   staticVerifierEmitter: source of uniqued attribute constraint functions.
//   useProperties: read attributes from the properties struct instead of
//     scanning the attribute range.
static void
genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
                     MethodBody &body,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
                     bool useProperties) {
  // Without attribute metadata no checks are emitted at all, including the
  // native trait (segment size) checks further below.
  if (emitHelper.getAttrMetadata().empty())
    return;

  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
 if ({0} && !({1}))
 return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
 if (::mlir::failed({0}(*this, {1}, "{2}")))
 return ::mlir::failure();
)";

  // Traverse the array until the required attribute is found. Return an error
  // if the traversal reached the end.
  // The loop body is left open: `checkOptionalAttr` branches may be appended,
  // and the caller closes it with the iterator increment.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The emit error prefix.
  // {2}: The name of the attribute.
  const char *const findRequiredAttr = R"(
while (true) {{
 if (namedAttrIt == namedAttrRange.end())
 return {1}"requires attribute '{2}'");
 if (namedAttrIt->getName() == {0}) {{
 tblgen_{2} = namedAttrIt->getValue();
 break;
 })";

  // Emit a check to see if the iteration has encountered an optional attribute.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The name of the attribute.
  const char *const checkOptionalAttr = R"(
 else if (namedAttrIt->getName() == {0}) {{
 tblgen_{1} = namedAttrIt->getValue();
 })";

  // Emit the start of the loop for checking trailing attributes.
  // Like `findRequiredAttr`, the loop body is closed by the caller.
  const char *const checkTrailingAttrs = R"(while (true) {
 if (namedAttrIt == namedAttrRange.end()) {
 break;
 })";

  // Emit the verifier for the attribute.
  const auto emitVerifier = [&](Attribute attr, StringRef attrName,
                                StringRef varName) {
    std::string condition = attr.getPredicate().getCondition();

    std::optional<StringRef> constraintFn;
    // The uniqued form passes `*this`, so it is only used when emitting for an
    // op and a uniqued function exists; otherwise the condition is inlined.
    if (emitHelper.isEmittingForOp() &&
        (constraintFn = staticVerifierEmitter.getAttrConstraintFn(constraint: attr))) {
      body << formatv(Fmt: verifyAttrUnique, Vals&: *constraintFn, Vals&: varName, Vals&: attrName);
    } else {
      body << formatv(Fmt: verifyAttrInline, Vals&: varName,
                      Vals: tgfmt(fmt: condition, ctx: &ctx.withSelf(subst: varName)),
                      Vals: emitHelper.emitErrorPrefix(), Vals&: attrName,
                      Vals: escapeString(value: attr.getSummary()));
    }
  };

  // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
  const auto getVarName = [&](StringRef attrName) {
    return (tblgenNamePrefix + attrName).str();
  };

  // Step 1: load attributes into `tblgen_*` locals and check required ones.
  body.indent();
  if (useProperties) {
    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      const AttributeMetadata &metadata = it.second;
      // Derived attributes get no local and no presence check.
      if (metadata.constraint && metadata.constraint->isDerivedAttr())
        continue;
      body << formatv(
          Fmt: "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
          Vals: it.first);
      if (metadata.isRequired)
        body << formatv(
            Fmt: "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
            Vals: it.first, Vals: emitHelper.emitErrorPrefix());
    }
  } else {
    body << formatv(Fmt: "auto namedAttrRange = {0};\n", Vals: emitHelper.getAttrRange());
    body << "auto namedAttrIt = namedAttrRange.begin();\n";

    // Iterate over the attributes in sorted order. Keep track of the optional
    // attributes that may be encountered along the way.
    SmallVector<const AttributeMetadata *> optionalAttrs;

    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      const AttributeMetadata &metadata = it.second;
      // Optional attributes are only queued here; they are matched inside the
      // search loop of the next required attribute (or the trailing loop).
      if (!metadata.isRequired) {
        optionalAttrs.push_back(Elt: &metadata);
        continue;
      }

      // Declare the required attribute's local plus locals for every optional
      // attribute queued since the previous required one.
      body << formatv(Fmt: "::mlir::Attribute {0};\n", Vals: getVarName(it.first));
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(Fmt: "::mlir::Attribute {0};\n",
                        Vals: getVarName(optional->attrName));
      }
      // Advance until the required attribute is found, capturing any queued
      // optional attribute passed on the way. Hitting the end of the range
      // reports the required attribute as missing.
      body << formatv(Fmt: findRequiredAttr, Vals: emitHelper.getAttrName(attrName: it.first),
                      Vals: emitHelper.emitErrorPrefix(), Vals: it.first);
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(Fmt: checkOptionalAttr,
                        Vals: emitHelper.getAttrName(attrName: optional->attrName),
                        Vals: optional->attrName);
      }
      // Close the search loop opened by `findRequiredAttr`.
      body << "\n ++namedAttrIt;\n}\n";
      optionalAttrs.clear();
    }
    // Get trailing optional attributes.
    if (!optionalAttrs.empty()) {
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(Fmt: "::mlir::Attribute {0};\n",
                        Vals: getVarName(optional->attrName));
      }
      // Scan the rest of the range, stopping at its end.
      body << checkTrailingAttrs;
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(Fmt: checkOptionalAttr,
                        Vals: emitHelper.getAttrName(attrName: optional->attrName),
                        Vals: optional->attrName);
      }
      // Close the loop opened by `checkTrailingAttrs`.
      body << "\n ++namedAttrIt;\n}\n";
    }
  }
  body.unindent();

  // Step 2: emit the checks for segment attributes first so that the other
  // constraints can call operand and result getters.
  genNativeTraitAttrVerifier(body, emitHelper);

  // Step 3: verify the constraint of every attribute that admits a verifier.
  bool isEmittingForOp = emitHelper.isEmittingForOp();
  for (const auto &namedAttr : emitHelper.getOp().getAttributes())
    if (canEmitAttrVerifier(attr: namedAttr.attr, isEmittingForOp))
      emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
1054
1055static void genPropertyVerifier(
1056 const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
1057 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1058
1059 // Code to get a reference to a property into a variable to avoid multiple
1060 // evaluations while verifying a property.
1061 // {0}: Property variable name.
1062 // {1}: Property name, with the first letter capitalized, to find the getter.
1063 // {2}: Property interface type.
1064 const char *const fetchProperty = R"(
1065 [[maybe_unused]] {2} {0} = this->get{1}();
1066)";
1067
1068 // Code to verify that the predicate of a property holds. Embeds the
1069 // condition inline.
1070 // {0}: Property condition code, with tgfmt() applied.
1071 // {1}: Emit error prefix.
1072 // {2}: Property name.
1073 // {3}: Property description.
1074 const char *const verifyPropertyInline = R"(
1075 if (!({0}))
1076 return {1}"property '{2}' failed to satisfy constraint: {3}");
1077)";
1078
1079 // Verify the property using a uniqued constraint. Can only be used
1080 // within the context of an op.
1081 //
1082 // {0}: Unique constraint name.
1083 // {1}: Property variable name in interface type.
1084 // {2}: Property name.
1085 const char *const verifyPropertyUniqued = R"(
1086 if (::mlir::failed({0}(*this, {1}, "{2}")))
1087 return ::mlir::failure();
1088)";
1089
1090 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
1091 const auto getVarName = [&](const NamedProperty &prop) {
1092 std::string varName =
1093 convertToCamelFromSnakeCase(input: prop.name, /*capitalizeFirst=*/false);
1094 return (tblgenNamePrefix + Twine(varName)).str();
1095 };
1096
1097 for (const NamedProperty &prop : emitHelper.getOp().getProperties()) {
1098 Pred predicate = prop.prop.getPredicate();
1099 // Null predicate, nothing to verify.
1100 if (predicate == Pred())
1101 continue;
1102
1103 std::string rawCondition = predicate.getCondition();
1104 if (rawCondition == "true")
1105 continue;
1106 bool needsOp = StringRef(rawCondition).contains(Other: "$_op");
1107 if (needsOp && !emitHelper.isEmittingForOp())
1108 continue;
1109
1110 auto scope = body.scope(open: "{\n", close: "}\n", /*indent=*/true);
1111 std::string varName = getVarName(prop);
1112 std::string getterName =
1113 convertToCamelFromSnakeCase(input: prop.name, /*capitalizeFirst=*/true);
1114 body << formatv(Fmt: fetchProperty, Vals&: varName, Vals&: getterName,
1115 Vals: prop.prop.getInterfaceType());
1116 auto uniquedFn = staticVerifierEmitter.getPropConstraintFn(constraint: prop.prop);
1117 if (uniquedFn.has_value())
1118 body << formatv(Fmt: verifyPropertyUniqued, Vals&: *uniquedFn, Vals&: varName, Vals: prop.name);
1119 else
1120 body << formatv(
1121 Fmt: verifyPropertyInline, Vals: tgfmt(fmt: rawCondition, ctx: &ctx.withSelf(subst: varName)),
1122 Vals: emitHelper.emitErrorPrefix(), Vals: prop.name, Vals: prop.prop.getSummary());
1123 }
1124}
1125
1126/// Include declarations specified on NativeTrait
1127static std::string formatExtraDeclarations(const Operator &op) {
1128 SmallVector<StringRef> extraDeclarations;
1129 // Include extra class declarations from NativeTrait
1130 for (const auto &trait : op.getTraits()) {
1131 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
1132 StringRef value = opTrait->getExtraConcreteClassDeclaration();
1133 if (value.empty())
1134 continue;
1135 extraDeclarations.push_back(Elt: value);
1136 }
1137 }
1138 extraDeclarations.push_back(Elt: op.getExtraClassDeclaration());
1139 return llvm::join(R&: extraDeclarations, Separator: "\n");
1140}
1141
1142/// Op extra class definitions have a `$cppClass` substitution that is to be
1143/// replaced by the C++ class name.
1144/// Include declarations specified on NativeTrait
1145static std::string formatExtraDefinitions(const Operator &op) {
1146 SmallVector<StringRef> extraDefinitions;
1147 // Include extra class definitions from NativeTrait
1148 for (const auto &trait : op.getTraits()) {
1149 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
1150 StringRef value = opTrait->getExtraConcreteClassDefinition();
1151 if (value.empty())
1152 continue;
1153 extraDefinitions.push_back(Elt: value);
1154 }
1155 }
1156 extraDefinitions.push_back(Elt: op.getExtraClassDefinition());
1157 FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass", subst: op.getCppClassName());
1158 return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n"), ctx: &ctx).str();
1159}
1160
// Build the complete C++ class for `op` into `opClass`.
//
// The class is created with the op's C++ class name and its formatted extra
// declarations/definitions. The constructor then configures the verifier
// format context and runs every generation step. The emitted declarations and
// definitions are written later via emitDecl/emitDef.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), formatExtraDeclarations(op),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // In verifier code, `$_op` expands to the operation and `$_ctxt` to its
  // MLIRContext.
  verifyCtx.addSubst(placeholder: "_op", subst: "(*this->getOperation())");
  verifyCtx.addSubst(placeholder: "_ctxt", subst: "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genPropertiesSupport();
  genPropGetters();
  genPropSetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genPopulateDefaultAttributes();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  // Methods for the declarative assembly format, if the op has one.
  generateOpFormat(constOp: op, opClass, hasProperties: emitHelper.hasProperties());
  genSideEffectInterfaceMethods();
}
1202void OpEmitter::emitDecl(
1203 const Operator &op, raw_ostream &os,
1204 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1205 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
1206}
1207
1208void OpEmitter::emitDef(
1209 const Operator &op, raw_ostream &os,
1210 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1211 OpEmitter(op, staticVerifierEmitter).emitDef(os);
1212}
1213
1214void OpEmitter::emitDecl(raw_ostream &os) {
1215 opClass.finalize();
1216 opClass.writeDeclTo(rawOs&: os);
1217}
1218
1219void OpEmitter::emitDef(raw_ostream &os) {
1220 opClass.finalize();
1221 opClass.writeDefTo(rawOs&: os);
1222}
1223
1224static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
1225 const Operator &op) {
1226 if (m)
1227 return;
1228 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "Unexpected overlap when generating `" +
1229 methodName + "` for " +
1230 op.getOperationName() + " (from line " +
1231 Twine(line) + ")");
1232}
1233
// Call errorIfPruned with the line number of the call site. This must be a
// macro so that `__LINE__` expands at each use rather than at a single
// helper definition.
#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
1235
// Generate the attribute-name accessors on the op class:
//  - static `getAttributeNames()`, listing the names from the attribute
//    metadata. When the dialect uses properties and the op has
//    AttrSizedOperandSegments, the operand segment sizes name is appended last;
//  - private `getAttributeNameForIndex(index)` and its static
//    `(OperationName, index)` overload;
//  - for each metadata attribute, a non-static and a static `<getter>AttrName`;
//  - the same pair for the operand segment sizes attribute, when applicable.
// If there is no name to list, only `getAttributeNames()` is generated and the
// function returns early.
void OpEmitter::genAttrNameGetters() {
  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
      emitHelper.getAttrMetadata();
  // Whether the operand segment sizes name is appended after the metadata
  // attribute names.
  bool hasOperandSegmentsSize =
      op.getDialect().usePropertiesForAttributes() &&
      op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        retType: "::llvm::ArrayRef<::llvm::StringRef>", name: "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (!hasOperandSegmentsSize && attributes.empty()) {
      body << " return {};";
      // Nothing else to do if there are no registered attributes. Exit early.
      return;
    }
    body << " static ::llvm::StringRef attrNames[] = {";
    llvm::interleaveComma(c: llvm::make_first_range(c: attributes), os&: body,
                          each_fn: [&](StringRef attrName) {
                            body << "::llvm::StringRef(\"" << attrName << "\")";
                          });
    if (hasOperandSegmentsSize) {
      // Separate from the preceding names, if any.
      if (!attributes.empty())
        body << ", ";
      body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")";
    }
    body << "};\n return ::llvm::ArrayRef(attrNames);";
  }

  // Emit the getAttributeNameForIndex methods.
  {
    auto *method = opClass.addInlineMethod<Method::Private>(
        retType: "::mlir::StringAttr", name: "getAttributeNameForIndex",
        args: MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    // Forward to the static overload using the op's own OperationName.
    method->body()
        << " return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        retType: "::mlir::StringAttr", name: "getAttributeNameForIndex",
        args: MethodParameter("::mlir::OperationName", "name"),
        args: MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    if (attributes.empty()) {
      method->body() << " return {};";
    } else {
      // The bound assert covers only the metadata attributes; the operand
      // segment sizes name is reached via `.back()` in its own accessors.
      const char *const getAttrName = R"(
 assert(index < {0} && "invalid attribute index");
 assert(name.getStringRef() == getOperationName() && "invalid operation name");
 assert(name.isRegistered() && "Operation isn't registered, missing a "
 "dependent dialect loading?");
 return name.getAttributeNames()[index];
)";
      method->body() << formatv(Fmt: getAttrName, Vals: attributes.size());
    }
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users.
  // Each attribute is addressed by its position in the metadata.
  const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
  for (auto [index, attr] :
       llvm::enumerate(First: llvm::make_first_range(c: attributes))) {
    std::string name = op.getGetterName(name: attr);
    std::string methodName = name + "AttrName";

    // Generate the non-static variant.
    {
      auto *method = opClass.addInlineMethod(retType: "::mlir::StringAttr", name&: methodName);
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << llvm::formatv(Fmt: attrNameMethodBody, Vals&: index);
    }

    // Generate the static variant.
    {
      auto *method = opClass.addStaticInlineMethod(
          retType: "::mlir::StringAttr", name&: methodName,
          args: MethodParameter("::mlir::OperationName", "name"));
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << llvm::formatv(Fmt: attrNameMethodBody,
                                      Vals: "name, " + Twine(index));
    }
  }
  // The operand segment sizes name was appended last in getAttributeNames(),
  // so its accessors read the final entry of the attribute-name list.
  if (hasOperandSegmentsSize) {
    std::string name = op.getGetterName(name: operandSegmentAttrName);
    std::string methodName = name + "AttrName";
    // Generate the non-static variant.
    {
      auto *method = opClass.addInlineMethod(retType: "::mlir::StringAttr", name&: methodName);
      ERROR_IF_PRUNED(method, methodName, op);
      method->body()
          << " return (*this)->getName().getAttributeNames().back();";
    }

    // Generate the static variant.
    {
      auto *method = opClass.addStaticInlineMethod(
          retType: "::mlir::StringAttr", name&: methodName,
          args: MethodParameter("::mlir::OperationName", "name"));
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << " return name.getAttributeNames().back();";
    }
  }
}
1342
1343// Emit the getter for a named property.
1344// It is templated to be shared between the Op and the adaptor class.
1345template <typename OpClassOrAdaptor>
1346static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op,
1347 StringRef name, const Property &prop) {
1348 auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name);
1349 ERROR_IF_PRUNED(method, name, op);
1350 method->body() << formatv(Fmt: " return getProperties().{0}();", Vals&: name);
1351}
1352
1353// Emit the getter for an attribute with the return type specified.
1354// It is templated to be shared between the Op and the adaptor class.
1355template <typename OpClassOrAdaptor>
1356static void emitAttrGetterWithReturnType(FmtContext &fctx,
1357 OpClassOrAdaptor &opClass,
1358 const Operator &op, StringRef name,
1359 Attribute attr) {
1360 auto *method = opClass.addMethod(attr.getReturnType(), name);
1361 ERROR_IF_PRUNED(method, name, op);
1362 auto &body = method->body();
1363 body << " auto attr = " << name << "Attr();\n";
1364 if (attr.hasDefaultValue() && attr.isOptional()) {
1365 // Returns the default value if not set.
1366 // TODO: this is inefficient, we are recreating the attribute for every
1367 // call. This should be set instead.
1368 if (!attr.isConstBuildable()) {
1369 PrintFatalError(Msg: "DefaultValuedAttr of type " + attr.getAttrDefName() +
1370 " must have a constBuilder");
1371 }
1372 std::string defaultValue =
1373 std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
1374 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
1375 body << " if (!attr)\n return "
1376 << tgfmt(fmt: attr.getConvertFromStorageCall(),
1377 ctx: &fctx.withSelf(subst: defaultValue))
1378 << ";\n";
1379 }
1380 body << " return "
1381 << tgfmt(fmt: attr.getConvertFromStorageCall(), ctx: &fctx.withSelf(subst: "attr"))
1382 << ";\n";
1383}
1384
1385void OpEmitter::genPropertiesSupport() {
1386 if (!emitHelper.hasProperties())
1387 return;
1388
1389 SmallVector<ConstArgument> attrOrProperties;
1390 for (const std::pair<StringRef, AttributeMetadata> &it :
1391 emitHelper.getAttrMetadata()) {
1392 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
1393 attrOrProperties.push_back(Elt: &it.second);
1394 }
1395 for (const NamedProperty &prop : op.getProperties())
1396 attrOrProperties.push_back(Elt: &prop);
1397 if (emitHelper.getOperandSegmentsSize())
1398 attrOrProperties.push_back(Elt: &emitHelper.getOperandSegmentsSize().value());
1399 if (emitHelper.getResultSegmentsSize())
1400 attrOrProperties.push_back(Elt: &emitHelper.getResultSegmentsSize().value());
1401 auto &setPropMethod =
1402 opClass
1403 .addStaticMethod(
1404 retType: "::llvm::LogicalResult", name: "setPropertiesFromAttr",
1405 args: MethodParameter("Properties &", "prop"),
1406 args: MethodParameter("::mlir::Attribute", "attr"),
1407 args: MethodParameter(
1408 "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1409 "emitError"))
1410 ->body();
1411 auto &getPropMethod =
1412 opClass
1413 .addStaticMethod(retType: "::mlir::Attribute", name: "getPropertiesAsAttr",
1414 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1415 args: MethodParameter("const Properties &", "prop"))
1416 ->body();
1417 auto &hashMethod =
1418 opClass
1419 .addStaticMethod(retType: "llvm::hash_code", name: "computePropertiesHash",
1420 args: MethodParameter("const Properties &", "prop"))
1421 ->body();
1422 auto &getInherentAttrMethod =
1423 opClass
1424 .addStaticMethod(retType: "std::optional<mlir::Attribute>", name: "getInherentAttr",
1425 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1426 args: MethodParameter("const Properties &", "prop"),
1427 args: MethodParameter("llvm::StringRef", "name"))
1428 ->body();
1429 auto &setInherentAttrMethod =
1430 opClass
1431 .addStaticMethod(retType: "void", name: "setInherentAttr",
1432 args: MethodParameter("Properties &", "prop"),
1433 args: MethodParameter("llvm::StringRef", "name"),
1434 args: MethodParameter("mlir::Attribute", "value"))
1435 ->body();
1436 auto &populateInherentAttrsMethod =
1437 opClass
1438 .addStaticMethod(retType: "void", name: "populateInherentAttrs",
1439 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1440 args: MethodParameter("const Properties &", "prop"),
1441 args: MethodParameter("::mlir::NamedAttrList &", "attrs"))
1442 ->body();
1443 auto &verifyInherentAttrsMethod =
1444 opClass
1445 .addStaticMethod(
1446 retType: "::llvm::LogicalResult", name: "verifyInherentAttrs",
1447 args: MethodParameter("::mlir::OperationName", "opName"),
1448 args: MethodParameter("::mlir::NamedAttrList &", "attrs"),
1449 args: MethodParameter(
1450 "llvm::function_ref<::mlir::InFlightDiagnostic()>",
1451 "emitError"))
1452 ->body();
1453
1454 opClass.declare<UsingDeclaration>(args: "Properties", args: "FoldAdaptor::Properties");
1455
1456 // Convert the property to the attribute form.
1457
1458 setPropMethod << R"decl(
1459 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1460 if (!dict) {
1461 emitError() << "expected DictionaryAttr to set properties";
1462 return ::mlir::failure();
1463 }
1464 )decl";
1465 const char *propFromAttrFmt = R"decl(
1466 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1467 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{
1468 {0}
1469 };
1470 {1};
1471)decl";
1472 const char *attrGetNoDefaultFmt = R"decl(;
1473 if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1474 return ::mlir::failure();
1475)decl";
1476 const char *attrGetDefaultFmt = R"decl(;
1477 if (attr) {{
1478 if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1479 return ::mlir::failure();
1480 } else {{
1481 prop.{0} = {1};
1482 }
1483)decl";
1484
1485 for (const auto &attrOrProp : attrOrProperties) {
1486 if (const auto *namedProperty =
1487 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1488 StringRef name = namedProperty->name;
1489 auto &prop = namedProperty->prop;
1490 FmtContext fctx;
1491
1492 std::string getAttr;
1493 llvm::raw_string_ostream os(getAttr);
1494 os << " auto attr = dict.get(\"" << name << "\");";
1495 if (name == operandSegmentAttrName) {
1496 // Backward compat for now, TODO: Remove at some point.
1497 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1498 }
1499 if (name == resultSegmentAttrName) {
1500 // Backward compat for now, TODO: Remove at some point.
1501 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1502 }
1503
1504 fctx.withBuilder(subst: odsBuilder);
1505 setPropMethod << "{\n"
1506 << formatv(Fmt: propFromAttrFmt,
1507 Vals: tgfmt(fmt: prop.getConvertFromAttributeCall(),
1508 ctx: &fctx.addSubst(placeholder: "_attr", subst: propertyAttr)
1509 .addSubst(placeholder: "_storage", subst: propertyStorage)
1510 .addSubst(placeholder: "_diag", subst: propertyDiag)),
1511 Vals&: getAttr);
1512 if (prop.hasStorageTypeValueOverride()) {
1513 setPropMethod << formatv(Fmt: attrGetDefaultFmt, Vals&: name,
1514 Vals: prop.getStorageTypeValueOverride());
1515 } else if (prop.hasDefaultValue()) {
1516 setPropMethod << formatv(Fmt: attrGetDefaultFmt, Vals&: name,
1517 Vals: tgfmt(fmt: prop.getDefaultValue(), ctx: &fctx));
1518 } else {
1519 setPropMethod << formatv(Fmt: attrGetNoDefaultFmt, Vals&: name);
1520 }
1521 setPropMethod << " }\n";
1522 } else {
1523 const auto *namedAttr =
1524 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1525 StringRef name = namedAttr->attrName;
1526 std::string getAttr;
1527 llvm::raw_string_ostream os(getAttr);
1528 os << " auto attr = dict.get(\"" << name << "\");";
1529 if (name == operandSegmentAttrName) {
1530 // Backward compat for now
1531 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1532 }
1533 if (name == resultSegmentAttrName) {
1534 // Backward compat for now
1535 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1536 }
1537
1538 setPropMethod << formatv(Fmt: R"decl(
1539 {{
1540 auto &propStorage = prop.{0};
1541 {1}
1542 if (attr) {{
1543 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1544 if (convertedAttr) {{
1545 propStorage = convertedAttr;
1546 } else {{
1547 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1548 return ::mlir::failure();
1549 }
1550 }
1551 }
1552)decl",
1553 Vals&: name, Vals&: getAttr);
1554 }
1555 }
1556 setPropMethod << " return ::mlir::success();\n";
1557
1558 // Convert the attribute form to the property.
1559
1560 getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
1561 << " ::mlir::Builder odsBuilder{ctx};\n";
1562 const char *propToAttrFmt = R"decl(
1563 {
1564 const auto &propStorage = prop.{0};
1565 auto attr = [&]() -> ::mlir::Attribute {{
1566 {1}
1567 }();
1568 attrs.push_back(odsBuilder.getNamedAttr("{0}", attr));
1569 }
1570)decl";
1571 for (const auto &attrOrProp : attrOrProperties) {
1572 if (const auto *namedProperty =
1573 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1574 StringRef name = namedProperty->name;
1575 auto &prop = namedProperty->prop;
1576 FmtContext fctx;
1577 getPropMethod << formatv(
1578 Fmt: propToAttrFmt, Vals&: name,
1579 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(),
1580 ctx: &fctx.addSubst(placeholder: "_ctxt", subst: "ctx")
1581 .addSubst(placeholder: "_storage", subst: propertyStorage)));
1582 continue;
1583 }
1584 const auto *namedAttr =
1585 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1586 StringRef name = namedAttr->attrName;
1587 getPropMethod << formatv(Fmt: R"decl(
1588 {{
1589 const auto &propStorage = prop.{0};
1590 if (propStorage)
1591 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1592 propStorage));
1593 }
1594)decl",
1595 Vals&: name);
1596 }
1597 getPropMethod << R"decl(
1598 if (!attrs.empty())
1599 return odsBuilder.getDictionaryAttr(attrs);
1600 return {};
1601)decl";
1602
1603 // Hashing for the property
1604
1605 const char *propHashFmt = R"decl(
1606 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
1607 using ::llvm::hash_value;
1608 return {1};
1609 };
1610)decl";
1611 for (const auto &attrOrProp : attrOrProperties) {
1612 if (const auto *namedProperty =
1613 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1614 StringRef name = namedProperty->name;
1615 auto &prop = namedProperty->prop;
1616 FmtContext fctx;
1617 if (!prop.getHashPropertyCall().empty()) {
1618 hashMethod << formatv(
1619 Fmt: propHashFmt, Vals&: name,
1620 Vals: tgfmt(fmt: prop.getHashPropertyCall(),
1621 ctx: &fctx.addSubst(placeholder: "_storage", subst: propertyStorage)));
1622 }
1623 }
1624 }
1625 hashMethod << " using llvm::hash_value;\n";
1626 hashMethod << " return llvm::hash_combine(";
1627 llvm::interleaveComma(
1628 c: attrOrProperties, os&: hashMethod, each_fn: [&](const ConstArgument &attrOrProp) {
1629 if (const auto *namedProperty =
1630 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1631 if (!namedProperty->prop.getHashPropertyCall().empty()) {
1632 hashMethod << "\n hash_" << namedProperty->name << "(prop."
1633 << namedProperty->name << ")";
1634 } else {
1635 hashMethod << "\n hash_value(prop." << namedProperty->name
1636 << ")";
1637 }
1638 return;
1639 }
1640 const auto *namedAttr =
1641 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1642 StringRef name = namedAttr->attrName;
1643 hashMethod << "\n llvm::hash_value(prop." << name
1644 << ".getAsOpaquePointer())";
1645 });
1646 hashMethod << ");\n";
1647
1648 const char *getInherentAttrMethodFmt = R"decl(
1649 if (name == "{0}")
1650 return prop.{0};
1651)decl";
1652 const char *setInherentAttrMethodFmt = R"decl(
1653 if (name == "{0}") {{
1654 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
1655 return;
1656 }
1657)decl";
1658 const char *populateInherentAttrsMethodFmt = R"decl(
1659 if (prop.{0}) attrs.append("{0}", prop.{0});
1660)decl";
1661 for (const auto &attrOrProp : attrOrProperties) {
1662 if (const auto *namedAttr =
1663 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp)) {
1664 StringRef name = namedAttr->attrName;
1665 getInherentAttrMethod << formatv(Fmt: getInherentAttrMethodFmt, Vals&: name);
1666 setInherentAttrMethod << formatv(Fmt: setInherentAttrMethodFmt, Vals&: name);
1667 populateInherentAttrsMethod
1668 << formatv(Fmt: populateInherentAttrsMethodFmt, Vals&: name);
1669 continue;
1670 }
1671 // The ODS segment size property is "special": we expose it as an attribute
1672 // even though it is a native property.
1673 const auto *namedProperty = cast<const NamedProperty *>(Val: attrOrProp);
1674 StringRef name = namedProperty->name;
1675 if (name != operandSegmentAttrName && name != resultSegmentAttrName)
1676 continue;
1677 auto &prop = namedProperty->prop;
1678 FmtContext fctx;
1679 fctx.addSubst(placeholder: "_ctxt", subst: "ctx");
1680 fctx.addSubst(placeholder: "_storage", subst: Twine("prop.") + name);
1681 if (name == operandSegmentAttrName) {
1682 getInherentAttrMethod
1683 << formatv(Fmt: " if (name == \"operand_segment_sizes\" || name == "
1684 "\"{0}\") return ",
1685 Vals: operandSegmentAttrName);
1686 } else {
1687 getInherentAttrMethod
1688 << formatv(Fmt: " if (name == \"result_segment_sizes\" || name == "
1689 "\"{0}\") return ",
1690 Vals: resultSegmentAttrName);
1691 }
1692 getInherentAttrMethod << "[&]() -> ::mlir::Attribute { "
1693 << tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx)
1694 << " }();\n";
1695
1696 if (name == operandSegmentAttrName) {
1697 setInherentAttrMethod
1698 << formatv(Fmt: " if (name == \"operand_segment_sizes\" || name == "
1699 "\"{0}\") {{",
1700 Vals: operandSegmentAttrName);
1701 } else {
1702 setInherentAttrMethod
1703 << formatv(Fmt: " if (name == \"result_segment_sizes\" || name == "
1704 "\"{0}\") {{",
1705 Vals: resultSegmentAttrName);
1706 }
1707 setInherentAttrMethod << formatv(Fmt: R"decl(
1708 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
1709 if (!arrAttr) return;
1710 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
1711 return;
1712 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
1713 return;
1714 }
1715)decl",
1716 Vals&: name);
1717 if (name == operandSegmentAttrName) {
1718 populateInherentAttrsMethod << formatv(
1719 Fmt: " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
1720 Vals: operandSegmentAttrName,
1721 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx));
1722 } else {
1723 populateInherentAttrsMethod << formatv(
1724 Fmt: " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
1725 Vals: resultSegmentAttrName,
1726 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx));
1727 }
1728 }
1729 getInherentAttrMethod << " return std::nullopt;\n";
1730
1731 // Emit the verifiers method for backward compatibility with the generic
1732 // syntax. This method verifies the constraint on the properties attributes
1733 // before they are set, since dyn_cast<> will silently omit failures.
1734 for (const auto &attrOrProp : attrOrProperties) {
1735 const auto *namedAttr =
1736 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1737 if (!namedAttr || !namedAttr->constraint)
1738 continue;
1739 Attribute attr = *namedAttr->constraint;
1740 std::optional<StringRef> constraintFn =
1741 staticVerifierEmitter.getAttrConstraintFn(constraint: attr);
1742 if (!constraintFn)
1743 continue;
1744 if (canEmitAttrVerifier(attr,
1745 /*isEmittingForOp=*/false)) {
1746 std::string name = op.getGetterName(name: namedAttr->attrName);
1747 verifyInherentAttrsMethod
1748 << formatv(Fmt: R"(
1749 {{
1750 ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
1751 if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
1752 return ::mlir::failure();
1753 }
1754)",
1755 Vals&: name, Vals&: constraintFn, Vals: namedAttr->attrName);
1756 }
1757 }
1758 verifyInherentAttrsMethod << " return ::mlir::success();";
1759
1760 // Generate methods to interact with bytecode.
1761 genPropertiesSupportForBytecode(attrOrProperties);
1762}
1763
/// Generates the bytecode (de)serialization hooks for the op's properties: a
/// static `readProperties(::mlir::DialectBytecodeReader &,
/// ::mlir::OperationState &)` and a member
/// `writeProperties(::mlir::DialectBytecodeWriter &)`.
///
/// Each entry of `attrOrProperties` is either a native property
/// (NamedProperty) or an inherent attribute (AttributeMetadata). Entries are
/// read and written in list order. Nothing is generated when the list is
/// empty. When the op requests a custom properties encoding, both hooks are
/// only declared, and their definitions must be provided elsewhere.
void OpEmitter::genPropertiesSupportForBytecode(
    ArrayRef<ConstArgument> attrOrProperties) {
  if (attrOrProperties.empty())
    return;

  // Custom encoding: declare the hooks, but do not define them.
  if (op.useCustomPropertiesEncoding()) {
    opClass.declareStaticMethod(
        retType: "::llvm::LogicalResult", name: "readProperties",
        args: MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
        args: MethodParameter("::mlir::OperationState &", "state"));
    opClass.declareMethod(
        retType: "void", name: "writeProperties",
        args: MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
    return;
  }

  auto &readPropertiesMethod =
      opClass
          .addStaticMethod(
              retType: "::llvm::LogicalResult", name: "readProperties",
              args: MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
              args: MethodParameter("::mlir::OperationState &", "state"))
          ->body();

  auto &writePropertiesMethod =
      opClass
          .addMethod(
              retType: "void", name: "writeProperties",
              args: MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
          ->body();

  // Populate bytecode serialization logic.
  // Both generated bodies first bind `prop` to the op's Properties struct.
  // `(void)prop;` silences unused-variable warnings in the generated code.
  readPropertiesMethod
      << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
  writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n";
  for (const auto &item : llvm::enumerate(First&: attrOrProperties)) {
    auto &attrOrProp = item.value();
    // Placeholder substitutions for the property-provided reader/writer
    // templates expanded with tgfmt below.
    // NOTE(review): readProperties is emitted as a *static* method above, yet
    // `$_ctxt` expands to `this->getContext()`. A reader template using
    // `$_ctxt` would therefore produce invalid code. Confirm that no template
    // relies on it.
    FmtContext fctx;
    fctx.addSubst(placeholder: "_reader", subst: "reader")
        .addSubst(placeholder: "_writer", subst: "writer")
        .addSubst(placeholder: "_storage", subst: propertyStorage)
        .addSubst(placeholder: "_ctxt", subst: "this->getContext()");
    // If the op emits operand/result segment sizes as a property, emit the
    // legacy reader/writer in the appropriate order to allow backward
    // compatibility and back deployment.
    if (emitHelper.getOperandSegmentsSize().has_value() &&
        item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) {
      FmtContext fmtCtxt(fctx);
      fmtCtxt.addSubst(placeholder: "_propName", subst: operandSegmentAttrName);
      readPropertiesMethod << tgfmt(fmt: readBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
      writePropertiesMethod << tgfmt(fmt: writeBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
    }
    if (emitHelper.getResultSegmentsSize().has_value() &&
        item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) {
      FmtContext fmtCtxt(fctx);
      fmtCtxt.addSubst(placeholder: "_propName", subst: resultSegmentAttrName);
      readPropertiesMethod << tgfmt(fmt: readBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
      writePropertiesMethod << tgfmt(fmt: writeBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
    }
    // Native property: wrap the property's own bytecode reader/writer code in
    // a scope where `propStorage` aliases the field. The reader is wrapped in
    // a lambda so that the template can `return` a failure from it.
    if (const auto *namedProperty =
            dyn_cast<const NamedProperty *>(Val: attrOrProp)) {
      StringRef name = namedProperty->name;
      readPropertiesMethod << formatv(
          Fmt: R"(
 {{
 auto &propStorage = prop.{0};
 auto readProp = [&]() {
 {1};
 return ::mlir::success();
 };
 if (::mlir::failed(readProp()))
 return ::mlir::failure();
 }
)",
          Vals&: name,
          Vals: tgfmt(fmt: namedProperty->prop.getReadFromMlirBytecodeCall(), ctx: &fctx));
      writePropertiesMethod << formatv(
          Fmt: R"(
 {{
 auto &propStorage = prop.{0};
 {1};
 }
)",
          Vals&: name, Vals: tgfmt(fmt: namedProperty->prop.getWriteToMlirBytecodeCall(), ctx: &fctx));
      continue;
    }
    // Any non-property entry is treated as an inherent attribute. Note that
    // `namedAttr` is dereferenced without a null check.
    const auto *namedAttr = dyn_cast<const AttributeMetadata *>(Val: attrOrProp);
    StringRef name = namedAttr->attrName;
    // Required attributes use readAttribute/writeAttribute. Optional ones use
    // the readOptionalAttribute/writeOptionalAttribute variants.
    if (namedAttr->isRequired) {
      readPropertiesMethod << formatv(Fmt: R"(
 if (::mlir::failed(reader.readAttribute(prop.{0})))
 return ::mlir::failure();
)",
                                      Vals&: name);
      writePropertiesMethod
          << formatv(Fmt: " writer.writeAttribute(prop.{0});\n", Vals&: name);
    } else {
      readPropertiesMethod << formatv(Fmt: R"(
 if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
 return ::mlir::failure();
)",
                                      Vals&: name);
      writePropertiesMethod << formatv(Fmt: R"(
 writer.writeOptionalAttribute(prop.{0});
)",
                                       Vals&: name);
    }
  }
  readPropertiesMethod << " return ::mlir::success();";
}
1874
1875void OpEmitter::genPropGetters() {
1876 for (const NamedProperty &prop : op.getProperties()) {
1877 std::string name = op.getGetterName(name: prop.name);
1878 emitPropGetter(opClass, op, name, prop: prop.prop);
1879 }
1880}
1881
1882void OpEmitter::genPropSetters() {
1883 for (const NamedProperty &prop : op.getProperties()) {
1884 std::string name = op.getSetterName(name: prop.name);
1885 std::string argName = "new" + convertToCamelFromSnakeCase(
1886 input: prop.name, /*capitalizeFirst=*/true);
1887 auto *method = opClass.addInlineMethod(
1888 retType: "void", name, args: MethodParameter(prop.prop.getInterfaceType(), argName));
1889 if (!method)
1890 return;
1891 method->body() << formatv(Fmt: " getProperties().{0}({1});", Vals&: name, Vals&: argName);
1892 }
1893}
1894
/// Emits the attribute getters of the op:
///  - for derived attributes, a method whose body is the attribute's derived
///    code body;
///  - for all other attributes, an inline `<getter>Attr()` returning the
///    storage type, plus the value-typed getter produced by
///    `emitAttrGetterWithReturnType`.
/// If the op has any derived attribute, it also adds the
/// DerivedAttributeOpInterface trait, a static `isDerivedAttribute(name)`
/// query, and `materializeDerivedAttributes()`.
void OpEmitter::genAttrGetters() {
  FmtContext fctx;
  fctx.withBuilder(subst: "::mlir::Builder((*this)->getContext())");

  // Emit the derived attribute body.
  auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
    if (auto *method = opClass.addMethod(retType: attr.getReturnType(), name))
      method->body() << " " << attr.getDerivedCodeBody() << "\n";
  };

  // Generate named accessor with Attribute return type. This is a wrapper
  // class that allows referring to the attributes via accessors instead of
  // having to use the string interface for better compile time verification.
  // Optional attributes and attributes with a default value use
  // `dyn_cast_or_null`, because they may be absent. All others use `cast`.
  auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
                                     Attribute attr) {
    // The method body for this getter is trivial. Emit it inline.
    auto *method =
        opClass.addInlineMethod(retType: attr.getStorageType(), name: name + "Attr");
    if (!method)
      return;
    method->body() << formatv(
        Fmt: " return ::llvm::{1}<{2}>({0});", Vals: emitHelper.getAttr(attrName),
        Vals: attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
                                                    : "cast",
        Vals: attr.getStorageType());
  };

  for (const NamedAttribute &namedAttr : op.getAttributes()) {
    std::string name = op.getGetterName(name: namedAttr.name);
    if (namedAttr.attr.isDerivedAttr()) {
      emitDerivedAttr(name, namedAttr.attr);
    } else {
      emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
      emitAttrGetterWithReturnType(fctx, opClass, op, name, attr: namedAttr.attr);
    }
  }

  auto derivedAttrs = make_filter_range(Range: op.getAttributes(),
                                        Pred: [](const NamedAttribute &namedAttr) {
                                          return namedAttr.attr.isDerivedAttr();
                                        });
  // The remaining helpers exist only for ops with derived attributes.
  if (derivedAttrs.empty())
    return;

  opClass.addTrait(trait: "::mlir::DerivedAttributeOpInterface::Trait");
  // Generate helper method to query whether a named attribute is a derived
  // attribute. This enables, for example, avoiding adding an attribute that
  // overlaps with a derived attribute.
  {
    auto *method =
        opClass.addStaticMethod(retType: "bool", name: "isDerivedAttribute",
                                args: MethodParameter("::llvm::StringRef", "name"));
    ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
    auto &body = method->body();
    for (auto namedAttr : derivedAttrs)
      body << " if (name == \"" << namedAttr.name << "\") return true;\n";
    body << " return false;";
  }
  // Generate method to materialize derived attributes as a DictionaryAttr.
  {
    auto *method = opClass.addMethod(retType: "::mlir::DictionaryAttr",
                                     name: "materializeDerivedAttributes");
    ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
    auto &body = method->body();

    // A derived attribute without a convert-from-storage call cannot be
    // turned back into an Attribute.
    auto nonMaterializable =
        make_filter_range(Range&: derivedAttrs, Pred: [](const NamedAttribute &namedAttr) {
          return namedAttr.attr.getConvertFromStorageCall().empty();
        });
    if (!nonMaterializable.empty()) {
      // Warn at tablegen time. The generated method reports an op error at
      // runtime and returns a null dictionary.
      std::string attrs;
      llvm::raw_string_ostream os(attrs);
      interleaveComma(c: nonMaterializable, os, each_fn: [&](const NamedAttribute &attr) {
        os << op.getGetterName(name: attr.name);
      });
      PrintWarning(
          WarningLoc: op.getLoc(),
          Msg: formatv(
              Fmt: "op has non-materializable derived attributes '{0}', skipping",
              Vals&: os.str()));
      body << formatv(Fmt: " emitOpError(\"op has non-materializable derived "
                      "attributes '{0}'\");\n",
                      Vals&: attrs);
      body << " return nullptr;";
      return;
    }

    // Build `DictionaryAttr::get(ctx, {{<name>AttrName(), <converted>}, ...})`.
    body << " ::mlir::MLIRContext* ctx = getContext();\n";
    body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
    body << " return ::mlir::DictionaryAttr::get(";
    body << " ctx, {\n";
    // NOTE(review): `$_self` is bound to the attribute's getter call, but
    // `$_storage` is bound to `ctx` rather than to a storage value. Confirm
    // that this is intended for convert-from-storage templates.
    interleave(
        c: derivedAttrs, os&: body,
        each_fn: [&](const NamedAttribute &namedAttr) {
          auto tmpl = namedAttr.attr.getConvertFromStorageCall();
          std::string name = op.getGetterName(name: namedAttr.name);
          body << " {" << name << "AttrName(),\n"
               << tgfmt(fmt: tmpl, ctx: &fctx.withSelf(subst: name + "()")
                                   .withBuilder(subst: "odsBuilder")
                                   .addSubst(placeholder: "_ctxt", subst: "ctx")
                                   .addSubst(placeholder: "_storage", subst: "ctx"))
               << "}";
        },
        separator: ",\n");
    body << "});";
  }
}
2002
/// Emits the setters for every non-derived attribute of the op:
///  - `<setter>Attr(<StorageType> attr)`, which takes the attribute itself;
///  - `<setter>(<value>)`, which takes the underlying C++ value (only when
///    the base attribute can be used as an unwrapped raw value).
/// Depending on whether the dialect stores attributes as properties, the
/// generated code writes either the Properties struct field or the op's
/// attribute dictionary.
void OpEmitter::genAttrSetters() {
  bool useProperties = op.getDialect().usePropertiesForAttributes();

  // Generate the code to set an attribute.
  // `attrVar` is the C++ expression that produces the attribute value.
  auto emitSetAttr = [&](Method *method, StringRef getterName,
                         StringRef attrName, StringRef attrVar) {
    if (useProperties) {
      method->body() << formatv(Fmt: " getProperties().{0} = {1};", Vals&: attrName,
                                Vals&: attrVar);
    } else {
      method->body() << formatv(Fmt: " (*this)->setAttr({0}AttrName(), {1});",
                                Vals&: getterName, Vals&: attrVar);
    }
  };

  // Generate raw named setter type. This is a wrapper class that allows setting
  // to the attributes via setters instead of having to use the string interface
  // for better compile time verification.
  auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
                                     StringRef attrName, Attribute attr) {
    // This method body is trivial, so emit it inline.
    auto *method =
        opClass.addInlineMethod(retType: "void", name: setterName + "Attr",
                                args: MethodParameter(attr.getStorageType(), "attr"));
    if (method)
      emitSetAttr(method, getterName, attrName, "attr");
  };

  // Generate a setter that accepts the underlying C++ type as opposed to the
  // attribute type.
  auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
                                    StringRef attrName, Attribute attr) {
    Attribute baseAttr = attr.getBaseAttr();
    if (!canUseUnwrappedRawValue(attr: baseAttr))
      return;
    FmtContext fctx;
    fctx.withBuilder(subst: "::mlir::Builder((*this)->getContext())");
    bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
    bool isOptional = attr.isOptional();

    auto createMethod = [&](const Twine &paramType) {
      return opClass.addMethod(retType: "void", name&: setterName,
                               args: MethodParameter(paramType.str(), "attrValue"));
    };

    // Build the method using the correct parameter type depending on
    // optionality: `bool` for UnitAttr, `std::optional<BaseReturnType>` for
    // optional attributes, and the attribute's return type otherwise.
    Method *method = nullptr;
    if (isUnitAttr)
      method = createMethod("bool");
    else if (isOptional)
      method =
          createMethod("::std::optional<" + baseAttr.getReturnType() + ">");
    else
      method = createMethod(attr.getReturnType());
    // A pruned (conflicting) method is skipped silently.
    if (!method)
      return;

    // If the value isn't optional, just set it directly.
    if (!isOptional) {
      emitSetAttr(method, getterName, attrName,
                  constBuildAttrFromParam(attr, fctx, paramName: "attrValue"));
      return;
    }

    // Otherwise, we only set if the provided value is valid. If it isn't, we
    // remove the attribute.

    // TODO: Handle unit attr parameters specially, given that it is treated as
    // optional but not in the same way as the others (i.e. it uses bool over
    // std::optional<>).
    // The bool parameter of a UnitAttr setter is used as-is; an optional is
    // dereferenced.
    StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
    if (!useProperties) {
      // Attribute-dictionary storage: set when present, otherwise remove.
      const char *optionalCodeBody = R"(
 if (attrValue)
 return (*this)->setAttr({0}AttrName(), {1});
 (*this)->removeAttr({0}AttrName());)";
      method->body() << formatv(
          Fmt: optionalCodeBody, Vals&: getterName,
          Vals: constBuildAttrFromParam(attr: baseAttr, fctx, paramName: paramStr));
    } else {
      // Properties storage: assign when present, otherwise reset to null.
      const char *optionalCodeBody = R"(
 auto &odsProp = getProperties().{0};
 if (attrValue)
 odsProp = {1};
 else
 odsProp = nullptr;)";
      method->body() << formatv(
          Fmt: optionalCodeBody, Vals&: attrName,
          Vals: constBuildAttrFromParam(attr: baseAttr, fctx, paramName: paramStr));
    }
  };

  // Derived attributes are computed, so they get no setters.
  for (const NamedAttribute &namedAttr : op.getAttributes()) {
    if (namedAttr.attr.isDerivedAttr())
      continue;
    std::string setterName = op.getSetterName(name: namedAttr.name);
    std::string getterName = op.getGetterName(name: namedAttr.name);
    emitAttrWithStorageType(setterName, getterName, namedAttr.name,
                            namedAttr.attr);
    emitAttrWithReturnType(setterName, getterName, namedAttr.name,
                           namedAttr.attr);
  }
}
2107
2108void OpEmitter::genOptionalAttrRemovers() {
2109 // Generate methods for removing optional attributes, instead of having to
2110 // use the string interface. Enables better compile time verification.
2111 auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
2112 auto *method = opClass.addInlineMethod(retType: "::mlir::Attribute",
2113 name: op.getRemoverName(name) + "Attr");
2114 if (!method)
2115 return;
2116 if (useProperties) {
2117 method->body() << formatv(Fmt: R"(
2118 auto attr = getProperties().{0};
2119 getProperties().{0} = {{};
2120 return attr;
2121)",
2122 Vals&: name);
2123 return;
2124 }
2125 method->body() << formatv(Fmt: "return (*this)->removeAttr({0}AttrName());",
2126 Vals: op.getGetterName(name));
2127 };
2128
2129 for (const NamedAttribute &namedAttr : op.getAttributes())
2130 if (namedAttr.attr.isOptional())
2131 emitRemoveAttr(namedAttr.name,
2132 op.getDialect().usePropertiesForAttributes());
2133}
2134
2135// Generates the code to compute the start and end index of an operand or result
2136// range.
2137template <typename RangeT>
2138static void generateValueRangeStartAndEnd(
2139 Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
2140 int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
2141 bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
2142
2143 SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
2144 if (isGenericAdaptorBase) {
2145 parameters.emplace_back(Args: "unsigned", Args: "odsOperandsSize");
2146 // The range size is passed per parameter for generic adaptor bases as
2147 // using the rangeSizeCall would require the operands, which are not
2148 // accessible in the base class.
2149 rangeSizeCall = "odsOperandsSize";
2150 }
2151
2152 // The method is trivial if the operation does not have any variadic operands.
2153 // In that case, make sure to generate it in-line.
2154 auto *method = opClass.addMethod(retType: "std::pair<unsigned, unsigned>", name&: methodName,
2155 properties: numVariadic == 0 ? Method::Properties::Inline
2156 : Method::Properties::None,
2157 args&: parameters);
2158 if (!method)
2159 return;
2160 auto &body = method->body();
2161 if (numVariadic == 0) {
2162 body << " return {index, 1};\n";
2163 } else if (hasAttrSegmentSize) {
2164 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
2165 } else {
2166 // Because the op can have arbitrarily interleaved variadic and non-variadic
2167 // operands, we need to embed a list in the "sink" getter method for
2168 // calculation at run-time.
2169 SmallVector<StringRef, 4> isVariadic;
2170 isVariadic.reserve(N: llvm::size(odsValues));
2171 for (auto &it : odsValues)
2172 isVariadic.push_back(Elt: it.isVariableLength() ? "true" : "false");
2173 std::string isVariadicList = llvm::join(R&: isVariadic, Separator: ", ");
2174 body << formatv(Fmt: sameVariadicSizeValueRangeCalcCode, Vals&: isVariadicList,
2175 Vals&: numNonVariadic, Vals&: numVariadic, Vals&: rangeSizeCall, Vals: "operand");
2176 }
2177}
2178
2179static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
2180 return llvm::formatv(Fmt: "::mlir::TypedValue<{0}>", Vals: value.constraint.getCppType())
2181 .str();
2182}
2183
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`.
//
// - `genericAdaptorBase`: if non-null, getters are generated for a generic
//   adaptor. The index/length "sink" method is placed in this non-templated
//   base class and receives the operand count as an extra argument. `opClass`
//   gets a trampoline that forwards to it.
// - `sizeAttrInit`: code that initializes the operand segment sizes. It is
//   used when the op has the AttrSizedOperandSegments trait.
// - `rangeType`: return type of getters that return a range of operands.
// - `rangeElementType`: return type of single and optional operand getters in
//   the generic adaptor case. Op getters instead return the type built by
//   generateTypeForGetter.
// - `rangeBeginCall`: expression yielding an iterator to the first operand.
// - `rangeSizeCall`: expression yielding the number of operands.
// - `getOperandCallPattern`: code that fetches a single operand, with its
//   position substituted for the "{0}" marker. The pattern should work for
//   any kind of op, including one-operand ops that may not have the
//   `getOperand(unsigned)` method.
//   NOTE(review): this parameter is not referenced in the body below.
//
// Emits a fatal error if the variadic operand configuration is ambiguous or
// if the segment-size traits are used inconsistently.
static void
generateNamedOperandGetters(const Operator &op, Class &opClass,
                            Class *genericAdaptorBase, StringRef sizeAttrInit,
                            StringRef rangeType, StringRef rangeElementType,
                            StringRef rangeBeginCall, StringRef rangeSizeCall,
                            StringRef getOperandCallPattern) {
  const int numOperands = op.getNumOperands();
  const int numVariadicOperands = op.getNumVariableLengthOperands();
  const int numNormalOperands = numOperands - numVariadicOperands;

  const auto *sameVariadicSize =
      op.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize");
  const auto *attrSizedOperands =
      op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");

  // Multiple variadic operands need a trait that says how to split them.
  if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
    PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op has multiple variadic operands but no "
                                 "specification over their sizes");
  }

  if (numVariadicOperands < 2 && attrSizedOperands) {
    PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op must have at least two variadic operands "
                                 "to use 'AttrSizedOperandSegments' trait");
  }

  if (attrSizedOperands && sameVariadicSize) {
    PrintFatalError(ErrorLoc: op.getLoc(),
                    Msg: "op cannot have both 'AttrSizedOperandSegments' and "
                    "'SameVariadicOperandSize' traits");
  }

  // First emit a few "sink" getter methods upon which we layer all nicer named
  // getter methods.
  // If generating for an adaptor, the method is put into the non-templated
  // generic base class, to not require being defined in the header.
  // Since the operand size can't be determined from the base class however,
  // it has to be passed as an additional argument. The trampoline below
  // generates the function with the same signature as the Op in the generic
  // adaptor.
  // NOTE(review): the const_cast suggests that Operator::getOperands() is not
  // const-qualified. Confirm.
  bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
  generateValueRangeStartAndEnd(
      /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
      isGenericAdaptorBase,
      /*methodName=*/"getODSOperandIndexAndLength", numVariadic: numVariadicOperands,
      numNonVariadic: numNormalOperands, rangeSizeCall, hasAttrSegmentSize: attrSizedOperands, sizeAttrInit,
      odsValues: const_cast<Operator &>(op).getOperands());
  if (isGenericAdaptorBase) {
    // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
    // the index. This just calls the implementation in the base class but
    // passes the operand size as parameter.
    Method *method = opClass.addInlineMethod(
        retType: "std::pair<unsigned, unsigned>", name: "getODSOperandIndexAndLength",
        args: MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
    MethodBody &body = method->body();
    body.indent() << formatv(
        Fmt: "return Base::getODSOperandIndexAndLength(index, {0});", Vals&: rangeSizeCall);
  }

  // The implementation of this method is trivial and it is very load-bearing.
  // Generate it inline.
  auto *m = opClass.addInlineMethod(retType&: rangeType, name: "getODSOperands",
                                    args: MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSOperands", op);
  auto &body = m->body();
  body << formatv(Fmt: valueRangeReturnCode, Vals&: rangeBeginCall,
                  Vals: "getODSOperandIndexAndLength(index)");

  // Then we emit nicer named getter methods by redirecting to the "sink" getter
  // method. Unnamed operands get no named getter.
  for (int i = 0; i != numOperands; ++i) {
    const auto &operand = op.getOperand(index: i);
    if (operand.name.empty())
      continue;
    std::string name = op.getGetterName(name: operand.name);
    if (operand.isOptional()) {
      // Optional operand: default-constructed value when absent. Otherwise
      // the first (only) element, cast to the typed value on ops.
      m = opClass.addInlineMethod(retType: isGenericAdaptorBase
                                      ? rangeElementType
                                      : generateTypeForGetter(value: operand),
                                  name);
      ERROR_IF_PRUNED(m, name, op);
      m->body().indent() << formatv(Fmt: "auto operands = getODSOperands({0});\n"
                                    "return operands.empty() ? {1}{{} : ",
                                    Vals&: i, Vals: m->getReturnType());
      if (!isGenericAdaptorBase)
        m->body() << llvm::formatv(Fmt: "::llvm::cast<{0}>", Vals: m->getReturnType());
      m->body() << "(*operands.begin());";
    } else if (operand.isVariadicOfVariadic()) {
      // Variadic of variadic: split the group by its segment-size attribute.
      std::string segmentAttr = op.getGetterName(
          name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
      if (genericAdaptorBase) {
        m = opClass.addMethod(retType: "::llvm::SmallVector<" + rangeType + ">", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << llvm::formatv(Fmt: variadicOfVariadicAdaptorCalcCode,
                                   Vals&: segmentAttr, Vals&: i, Vals&: rangeType);
        continue;
      }

      m = opClass.addInlineMethod(retType: "::mlir::OperandRangeRange", name);
      ERROR_IF_PRUNED(m, name, op);
      m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
                << "Attr());";
    } else if (operand.isVariadic()) {
      // Plain variadic operand: return the whole group as a range.
      m = opClass.addInlineMethod(retType&: rangeType, name);
      ERROR_IF_PRUNED(m, name, op);
      m->body() << " return getODSOperands(" << i << ");";
    } else {
      // Single operand: the first element of its group.
      m = opClass.addInlineMethod(retType: isGenericAdaptorBase
                                      ? rangeElementType
                                      : generateTypeForGetter(value: operand),
                                  name);
      ERROR_IF_PRUNED(m, name, op);
      m->body().indent() << "return ";
      if (!isGenericAdaptorBase)
        m->body() << llvm::formatv(Fmt: "::llvm::cast<{0}>", Vals: m->getReturnType());
      m->body() << llvm::formatv(Fmt: "(*getODSOperands({0}).begin());", Vals&: i);
    }
  }
}
2314
2315void OpEmitter::genNamedOperandGetters() {
2316 // Build the code snippet used for initializing the operand_segment_size)s
2317 // array.
2318 std::string attrSizeInitCode;
2319 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
2320 if (op.getDialect().usePropertiesForAttributes())
2321 attrSizeInitCode = formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
2322 Vals: "getProperties().operandSegmentSizes");
2323
2324 else
2325 attrSizeInitCode = formatv(Fmt: opSegmentSizeAttrInitCode,
2326 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName));
2327 }
2328
2329 generateNamedOperandGetters(
2330 op, opClass,
2331 /*genericAdaptorBase=*/nullptr,
2332 /*sizeAttrInit=*/attrSizeInitCode,
2333 /*rangeType=*/"::mlir::Operation::operand_range",
2334 /*rangeElementType=*/"::mlir::Value",
2335 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2336 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2337 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2338}
2339
2340void OpEmitter::genNamedOperandSetters() {
2341 auto *attrSizedOperands =
2342 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
2343 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
2344 const auto &operand = op.getOperand(index: i);
2345 if (operand.name.empty())
2346 continue;
2347 std::string name = op.getGetterName(name: operand.name);
2348
2349 StringRef returnType;
2350 if (operand.isVariadicOfVariadic()) {
2351 returnType = "::mlir::MutableOperandRangeRange";
2352 } else if (operand.isVariableLength()) {
2353 returnType = "::mlir::MutableOperandRange";
2354 } else {
2355 returnType = "::mlir::OpOperand &";
2356 }
2357 bool isVariadicOperand =
2358 operand.isVariadicOfVariadic() || operand.isVariableLength();
2359 auto *m = opClass.addMethod(retType&: returnType, name: name + "Mutable",
2360 properties: isVariadicOperand ? Method::Properties::None
2361 : Method::Properties::Inline);
2362 ERROR_IF_PRUNED(m, name, op);
2363 auto &body = m->body();
2364 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
2365
2366 if (!isVariadicOperand) {
2367 // In case of a single operand, return a single OpOperand.
2368 body << " return getOperation()->getOpOperand(range.first);\n";
2369 continue;
2370 }
2371
2372 body << " auto mutableRange = "
2373 "::mlir::MutableOperandRange(getOperation(), "
2374 "range.first, range.second";
2375 if (attrSizedOperands) {
2376 if (emitHelper.hasProperties())
2377 body << formatv(Fmt: ", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2378 "{{getOperandSegmentSizesAttrName(), "
2379 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2380 "getProperties().operandSegmentSizes)})",
2381 Vals&: i);
2382 else
2383 body << formatv(
2384 Fmt: ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", Vals&: i,
2385 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName, /*isNamed=*/true));
2386 }
2387 body << ");\n";
2388
2389 // If this operand is a nested variadic, we split the range into a
2390 // MutableOperandRangeRange that provides a range over all of the
2391 // sub-ranges.
2392 if (operand.isVariadicOfVariadic()) {
2393 body << " return "
2394 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2395 << op.getGetterName(
2396 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2397 << "AttrName()));\n";
2398 } else {
2399 // Otherwise, we use the full range directly.
2400 body << " return mutableRange;\n";
2401 }
2402 }
2403}
2404
2405void OpEmitter::genNamedResultGetters() {
2406 const int numResults = op.getNumResults();
2407 const int numVariadicResults = op.getNumVariableLengthResults();
2408 const int numNormalResults = numResults - numVariadicResults;
2409
2410 // If we have more than one variadic results, we need more complicated logic
2411 // to calculate the value range for each result.
2412
2413 const auto *sameVariadicSize =
2414 op.getTrait(trait: "::mlir::OpTrait::SameVariadicResultSize");
2415 const auto *attrSizedResults =
2416 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments");
2417
2418 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
2419 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op has multiple variadic results but no "
2420 "specification over their sizes");
2421 }
2422
2423 if (numVariadicResults < 2 && attrSizedResults) {
2424 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op must have at least two variadic results "
2425 "to use 'AttrSizedResultSegments' trait");
2426 }
2427
2428 if (attrSizedResults && sameVariadicSize) {
2429 PrintFatalError(ErrorLoc: op.getLoc(),
2430 Msg: "op cannot have both 'AttrSizedResultSegments' and "
2431 "'SameVariadicResultSize' traits");
2432 }
2433
2434 // Build the initializer string for the result segment size attribute.
2435 std::string attrSizeInitCode;
2436 if (attrSizedResults) {
2437 if (op.getDialect().usePropertiesForAttributes())
2438 attrSizeInitCode = formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
2439 Vals: "getProperties().resultSegmentSizes");
2440
2441 else
2442 attrSizeInitCode = formatv(Fmt: opSegmentSizeAttrInitCode,
2443 Vals: emitHelper.getAttr(attrName: resultSegmentAttrName));
2444 }
2445
2446 generateValueRangeStartAndEnd(
2447 opClass, /*isGenericAdaptorBase=*/false, methodName: "getODSResultIndexAndLength",
2448 numVariadic: numVariadicResults, numNonVariadic: numNormalResults, rangeSizeCall: "getOperation()->getNumResults()",
2449 hasAttrSegmentSize: attrSizedResults, sizeAttrInit: attrSizeInitCode, odsValues: op.getResults());
2450
2451 // The implementation of this method is trivial and it is very load-bearing.
2452 // Generate it inline.
2453 auto *m = opClass.addInlineMethod(retType: "::mlir::Operation::result_range",
2454 name: "getODSResults",
2455 args: MethodParameter("unsigned", "index"));
2456 ERROR_IF_PRUNED(m, "getODSResults", op);
2457 m->body() << formatv(Fmt: valueRangeReturnCode, Vals: "getOperation()->result_begin()",
2458 Vals: "getODSResultIndexAndLength(index)");
2459
2460 for (int i = 0; i != numResults; ++i) {
2461 const auto &result = op.getResult(index: i);
2462 if (result.name.empty())
2463 continue;
2464 std::string name = op.getGetterName(name: result.name);
2465 if (result.isOptional()) {
2466 m = opClass.addInlineMethod(retType: generateTypeForGetter(value: result), name);
2467 ERROR_IF_PRUNED(m, name, op);
2468 m->body() << " auto results = getODSResults(" << i << ");\n"
2469 << llvm::formatv(Fmt: " return results.empty()"
2470 " ? {0}()"
2471 " : ::llvm::cast<{0}>(*results.begin());",
2472 Vals: m->getReturnType());
2473 } else if (result.isVariadic()) {
2474 m = opClass.addInlineMethod(retType: "::mlir::Operation::result_range", name);
2475 ERROR_IF_PRUNED(m, name, op);
2476 m->body() << " return getODSResults(" << i << ");";
2477 } else {
2478 m = opClass.addInlineMethod(retType: generateTypeForGetter(value: result), name);
2479 ERROR_IF_PRUNED(m, name, op);
2480 m->body() << llvm::formatv(
2481 Fmt: " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2482 Vals: m->getReturnType(), Vals&: i);
2483 }
2484 }
2485}
2486
2487void OpEmitter::genNamedRegionGetters() {
2488 unsigned numRegions = op.getNumRegions();
2489 for (unsigned i = 0; i < numRegions; ++i) {
2490 const auto &region = op.getRegion(index: i);
2491 if (region.name.empty())
2492 continue;
2493 std::string name = op.getGetterName(name: region.name);
2494
2495 // Generate the accessors for a variadic region.
2496 if (region.isVariadic()) {
2497 auto *m = opClass.addInlineMethod(
2498 retType: "::mlir::MutableArrayRef<::mlir::Region>", name);
2499 ERROR_IF_PRUNED(m, name, op);
2500 m->body() << formatv(Fmt: " return (*this)->getRegions().drop_front({0});",
2501 Vals&: i);
2502 continue;
2503 }
2504
2505 auto *m = opClass.addInlineMethod(retType: "::mlir::Region &", name);
2506 ERROR_IF_PRUNED(m, name, op);
2507 m->body() << formatv(Fmt: " return (*this)->getRegion({0});", Vals&: i);
2508 }
2509}
2510
2511void OpEmitter::genNamedSuccessorGetters() {
2512 unsigned numSuccessors = op.getNumSuccessors();
2513 for (unsigned i = 0; i < numSuccessors; ++i) {
2514 const NamedSuccessor &successor = op.getSuccessor(index: i);
2515 if (successor.name.empty())
2516 continue;
2517 std::string name = op.getGetterName(name: successor.name);
2518 // Generate the accessors for a variadic successor list.
2519 if (successor.isVariadic()) {
2520 auto *m = opClass.addInlineMethod(retType: "::mlir::SuccessorRange", name);
2521 ERROR_IF_PRUNED(m, name, op);
2522 m->body() << formatv(
2523 Fmt: " return {std::next((*this)->successor_begin(), {0}), "
2524 "(*this)->successor_end()};",
2525 Vals&: i);
2526 continue;
2527 }
2528
2529 auto *m = opClass.addInlineMethod(retType: "::mlir::Block *", name);
2530 ERROR_IF_PRUNED(m, name, op);
2531 m->body() << formatv(Fmt: " return (*this)->getSuccessor({0});", Vals&: i);
2532 }
2533}
2534
2535static bool canGenerateUnwrappedBuilder(const Operator &op) {
2536 // If this op does not have native attributes at all, return directly to avoid
2537 // redefining builders.
2538 if (op.getNumNativeAttributes() == 0)
2539 return false;
2540
2541 bool canGenerate = false;
2542 // We are generating builders that take raw values for attributes. We need to
2543 // make sure the native attributes have a meaningful "unwrapped" value type
2544 // different from the wrapped mlir::Attribute type to avoid redefining
2545 // builders. This checks for the op has at least one such native attribute.
2546 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
2547 const NamedAttribute &namedAttr = op.getAttribute(index: i);
2548 if (canUseUnwrappedRawValue(attr: namedAttr.attr)) {
2549 canGenerate = true;
2550 break;
2551 }
2552 }
2553 return canGenerate;
2554}
2555
2556static bool canInferType(const Operator &op) {
2557 return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait");
2558}
2559
/// Emits the "separate parameter" `build` methods, which take one parameter
/// per operand and attribute. Result types are handled in one of three ways:
///   - one parameter per result (TypeParamKind::Separate),
///   - a single `resultTypes` TypeRange (TypeParamKind::Collective),
///   - no result-type parameter at all, when the op implements
///     InferTypeOpInterface; the generated body then calls `inferReturnTypes`.
/// Every variant is emitted with wrapped-attribute parameters. When
/// `canGenerateUnwrappedBuilder` allows it, each is emitted again with
/// raw-value parameters.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  //   attrType  - whether attribute parameters are wrapped or raw values.
  //   paramKind - how result types are passed to the builder.
  //   inferType - when true, emit an `inferReturnTypes` call instead of
  //               consuming result-type parameters.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    SmallVector<MethodParameter> paramList;
    SmallVector<std::string, 4> resultNames;
    llvm::StringSet<> inferredAttributes;
    buildParamList(paramList, inferredAttributes, resultNames, paramKind,
                   attrType);

    auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                           /*isRawValueAttr=*/attrType ==
                                               AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // The emitted code passes the state's operands, attributes, raw
      // properties and regions to `inferReturnTypes`. On failure it calls
      // `reportFatalInferReturnTypesError`.
      body << formatv(R"(
        ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
        if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
                      {1}.location, {1}.operands,
                      {1}.attributes.getDictionary({1}.getContext()),
                      {1}.getRawProperties(),
                      {1}.regions, inferredReturnTypes)))
          {1}.addTypes(inferredReturnTypes);
        else
          ::mlir::detail::reportFatalInferReturnTypesError({1});
        )",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      return;
    case TypeParamKind::Separate:
      // Optional results are only added when the caller passed a non-null
      // type.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << "  if (" << resultNames[i] << ")\n  ";
        body << "  " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }

      // Automatically create the 'resultSegmentSizes' attribute using
      // the length of the type ranges.
      // Depending on the dialect, the sizes are copied into the properties
      // struct or stored as a DenseI32ArrayAttr. Each entry is 1 for a fixed
      // result, 0/1 for an optional one, and the TypeRange size for a
      // variadic one.
      if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
        if (op.getDialect().usePropertiesForAttributes()) {
          body << "  ::llvm::copy(::llvm::ArrayRef<int32_t>({";
        } else {
          std::string getterName = op.getGetterName(resultSegmentAttrName);
          body << " " << builderOpState << ".addAttribute(" << getterName
               << "AttrName(" << builderOpState << ".name), "
               << "odsBuilder.getDenseI32ArrayAttr({";
        }
        interleaveComma(
            llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
              const NamedTypeConstraint &result = op.getResult(i);
              if (!result.isVariableLength()) {
                body << "1";
              } else if (result.isOptional()) {
                body << "(" << resultNames[i] << " ? 1 : 0)";
              } else {
                // VariadicOfVariadic of results are currently unsupported in
                // MLIR, hence it can only be a simple variadic.
                // TODO: Add implementation for VariadicOfVariadic results here
                //       once supported.
                assert(result.isVariadic());
                body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
              }
            });
        if (op.getDialect().usePropertiesForAttributes()) {
          body << "}), " << builderOpState
               << ".getOrAddProperties<Properties>()."
                  "resultSegmentSizes.begin());\n";
        } else {
          body << "}));\n";
        }
      }

      return;
    case TypeParamKind::Collective: {
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!hasVariadicResult || numNonVariadicResults != 0)
        body << "  "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << "  " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op))
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
2682
2683void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
2684 CollectiveBuilderKind kind) {
2685 int numResults = op.getNumResults();
2686
2687 // Signature
2688 SmallVector<MethodParameter> paramList;
2689 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2690 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2691 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2692 if (kind == CollectiveBuilderKind::PropStruct)
2693 paramList.emplace_back(Args: "const Properties &", Args: "properties");
2694 // Provide default value for `attributes` when its the last parameter
2695 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2696 StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
2697 ? "discardableAttributes"
2698 : "attributes";
2699 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2700 Args&: attributesName, Args&: attributesDefaultValue);
2701 if (op.getNumVariadicRegions())
2702 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
2703
2704 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2705 // If the builder is redundant, skip generating the method
2706 if (!m)
2707 return;
2708 auto &body = m->body();
2709
2710 // Operands
2711 body << " " << builderOpState << ".addOperands(operands);\n";
2712
2713 if (kind == CollectiveBuilderKind::PropStruct)
2714 body << " " << builderOpState
2715 << ".useProperties(const_cast<Properties&>(properties));\n";
2716 // Attributes
2717 body << " " << builderOpState << ".addAttributes(" << attributesName
2718 << ");\n";
2719
2720 // Create the correct number of regions
2721 if (int numRegions = op.getNumRegions()) {
2722 body << llvm::formatv(
2723 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
2724 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2725 body << " (void)" << builderOpState << ".addRegion();\n";
2726 }
2727
2728 // Result types
2729 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
2730 body << " " << builderOpState << ".addTypes({"
2731 << llvm::join(R&: resultTypes, Separator: ", ") << "});\n\n";
2732}
2733
2734void OpEmitter::genPopulateDefaultAttributes() {
2735 // All done if no attributes, except optional ones, have default values.
2736 if (llvm::all_of(Range: op.getAttributes(), P: [](const NamedAttribute &named) {
2737 return !named.attr.hasDefaultValue() || named.attr.isOptional();
2738 }))
2739 return;
2740
2741 if (emitHelper.hasProperties()) {
2742 SmallVector<MethodParameter> paramList;
2743 paramList.emplace_back(Args: "::mlir::OperationName", Args: "opName");
2744 paramList.emplace_back(Args: "Properties &", Args: "properties");
2745 auto *m =
2746 opClass.addStaticMethod(retType: "void", name: "populateDefaultProperties", args&: paramList);
2747 ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
2748 auto &body = m->body();
2749 body.indent();
2750 body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
2751 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2752 auto &attr = namedAttr.attr;
2753 if (!attr.hasDefaultValue() || attr.isOptional())
2754 continue;
2755 StringRef name = namedAttr.name;
2756 FmtContext fctx;
2757 fctx.withBuilder(subst: odsBuilder);
2758 body << "if (!properties." << name << ")\n"
2759 << " properties." << name << " = "
2760 << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
2761 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)))
2762 << ";\n";
2763 }
2764 return;
2765 }
2766
2767 SmallVector<MethodParameter> paramList;
2768 paramList.emplace_back(Args: "const ::mlir::OperationName &", Args: "opName");
2769 paramList.emplace_back(Args: "::mlir::NamedAttrList &", Args: "attributes");
2770 auto *m = opClass.addStaticMethod(retType: "void", name: "populateDefaultAttrs", args&: paramList);
2771 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
2772 auto &body = m->body();
2773 body.indent();
2774
2775 // Set default attributes that are unset.
2776 body << "auto attrNames = opName.getAttributeNames();\n";
2777 body << "::mlir::Builder " << odsBuilder
2778 << "(attrNames.front().getContext());\n";
2779 StringMap<int> attrIndex;
2780 for (const auto &it : llvm::enumerate(First: emitHelper.getAttrMetadata())) {
2781 attrIndex[it.value().first] = it.index();
2782 }
2783 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2784 auto &attr = namedAttr.attr;
2785 if (!attr.hasDefaultValue() || attr.isOptional())
2786 continue;
2787 auto index = attrIndex[namedAttr.name];
2788 body << "if (!attributes.get(attrNames[" << index << "])) {\n";
2789 FmtContext fctx;
2790 fctx.withBuilder(subst: odsBuilder);
2791
2792 std::string defaultValue =
2793 std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
2794 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
2795 body.indent() << formatv(Fmt: "attributes.append(attrNames[{0}], {1});\n", Vals&: index,
2796 Vals&: defaultValue);
2797 body.unindent() << "}\n";
2798 }
2799}
2800
/// Emits a collective `build(builder, state, operands, [properties,]
/// attributes[, numRegions])` that takes no result types. The generated body
/// computes them with the op's `inferReturnTypes` and aborts via
/// `report_fatal_error` if inference fails.
void OpEmitter::genInferredTypeCollectiveParamBuilder(
    CollectiveBuilderKind kind) {
  SmallVector<MethodParameter> paramList;
  paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
  paramList.emplace_back("::mlir::OperationState &", builderOpState);
  paramList.emplace_back("::mlir::ValueRange", "operands");
  if (kind == CollectiveBuilderKind::PropStruct)
    paramList.emplace_back("const Properties &", "properties");
  // `attributes` only gets an empty default when it is the last parameter.
  StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
  StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
                                 ? "discardableAttributes"
                                 : "attributes";
  paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
                         attributesName, attributesDefaultValue);
  if (op.getNumVariadicRegions())
    paramList.emplace_back("unsigned", "numRegions");

  auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
  // If the builder is redundant, skip generating the method
  if (!m)
    return;
  auto &body = m->body();

  int numResults = op.getNumResults();
  int numVariadicResults = op.getNumVariableLengthResults();
  int numNonVariadicResults = numResults - numVariadicResults;

  int numOperands = op.getNumOperands();
  int numVariadicOperands = op.getNumVariableLengthOperands();
  int numNonVariadicOperands = numOperands - numVariadicOperands;

  // Operands
  // The size assert is skipped when it would degenerate to `>= 0u`.
  if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
    body << "  assert(operands.size()"
         << (numVariadicOperands != 0 ? " >= " : " == ")
         << numNonVariadicOperands
         << "u && \"mismatched number of parameters\");\n";
  body << "  " << builderOpState << ".addOperands(operands);\n";
  if (kind == CollectiveBuilderKind::PropStruct)
    body << "  " << builderOpState
         << ".useProperties(const_cast<Properties &>(properties));\n";
  body << "  " << builderOpState << ".addAttributes(" << attributesName
       << ");\n";

  // Create the correct number of regions
  if (int numRegions = op.getNumRegions()) {
    body << llvm::formatv(
        "  for (unsigned i = 0; i != {0}; ++i)\n",
        (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
    body << "    (void)" << builderOpState << ".addRegion();\n";
  }

  // Result types
  if (emitHelper.hasNonEmptyPropertiesStruct() &&
      kind == CollectiveBuilderKind::AttrDict) {
    // Initialize the properties from Attributes before invoking the infer
    // function.
    // In the AttrDict form the caller only supplied attributes, so the
    // emitted code converts them into the properties struct
    // (setOpPropertiesFromAttribute) so that `getRawProperties()` below sees
    // them.
    body << formatv(R"(
  if (!attributes.empty()) {
    ::mlir::OpaqueProperties properties =
      &{1}.getOrAddProperties<{0}::Properties>();
    std::optional<::mlir::RegisteredOperationName> info =
      {1}.name.getRegisteredInfo();
    if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
        {1}.attributes.getDictionary({1}.getContext()), nullptr)))
      ::llvm::report_fatal_error("Property conversion failed.");
  })",
                    opClass.getClassName(), builderOpState);
  }
  // Emit the inference call; the trailing `{{` opens the success branch,
  // which is closed by the `} else {` emitted at the end.
  body << formatv(R"(
  ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
  if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
          {1}.location, operands,
          {1}.attributes.getDictionary({1}.getContext()),
          {1}.getRawProperties(),
          {1}.regions, inferredReturnTypes))) {{)",
                  opClass.getClassName(), builderOpState);
  if (numVariadicResults == 0 || numNonVariadicResults != 0)
    body << "\n    assert(inferredReturnTypes.size()"
         << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
         << "u && \"mismatched number of return types\");";
  body << "\n    " << builderOpState << ".addTypes(inferredReturnTypes);";

  body << R"(
  } else {
    ::llvm::report_fatal_error("Failed to infer result type(s).");
  })";
}
2889
2890void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2891 auto emit = [&](AttrParamKind attrType) {
2892 SmallVector<MethodParameter> paramList;
2893 SmallVector<std::string, 4> resultNames;
2894 llvm::StringSet<> inferredAttributes;
2895 buildParamList(paramList, inferredAttributes, resultTypeNames&: resultNames,
2896 typeParamKind: TypeParamKind::None, attrParamKind: attrType);
2897
2898 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2899 // If the builder is redundant, skip generating the method
2900 if (!m)
2901 return;
2902 auto &body = m->body();
2903 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2904 /*isRawValueAttr=*/attrType ==
2905 AttrParamKind::UnwrappedValue);
2906
2907 auto numResults = op.getNumResults();
2908 if (numResults == 0)
2909 return;
2910
2911 // Push all result types to the operation state
2912 const char *index = op.getOperand(index: 0).isVariadic() ? ".front()" : "";
2913 std::string resultType =
2914 formatv(Fmt: "{0}{1}.getType()", Vals: getArgumentName(op, index: 0), Vals&: index).str();
2915 body << " " << builderOpState << ".addTypes({" << resultType;
2916 for (int i = 1; i != numResults; ++i)
2917 body << ", " << resultType;
2918 body << "});\n\n";
2919 };
2920
2921 emit(AttrParamKind::WrappedAttr);
2922 // Generate additional builder(s) if any attributes can be "unwrapped"
2923 if (canGenerateUnwrappedBuilder(op))
2924 emit(AttrParamKind::UnwrappedValue);
2925}
2926
2927void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
2928 CollectiveBuilderKind kind) {
2929 SmallVector<MethodParameter> paramList;
2930 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2931 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2932 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2933 if (kind == CollectiveBuilderKind::PropStruct)
2934 paramList.emplace_back(Args: "const Properties &", Args: "properties");
2935 StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
2936 ? "discardableAttributes"
2937 : "attributes";
2938 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2939 Args&: attributesName, Args: "{}");
2940 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2941 // If the builder is redundant, skip generating the method
2942 if (!m)
2943 return;
2944
2945 auto &body = m->body();
2946
2947 // Push all result types to the operation state
2948 std::string resultType;
2949 const auto &namedAttr = op.getAttribute(index: 0);
2950
2951 if (namedAttr.attr.isTypeAttr()) {
2952 resultType = "::llvm::cast<::mlir::TypeAttr>(typeAttr).getValue()";
2953 } else {
2954 resultType = "::llvm::cast<::mlir::TypedAttr>(typeAttr).getType()";
2955 }
2956
2957 if (kind == CollectiveBuilderKind::PropStruct) {
2958 body << " ::mlir::Attribute typeAttr = properties."
2959 << op.getGetterName(name: namedAttr.name) << "();\n";
2960 } else {
2961 body << " ::mlir::Attribute typeAttr;\n"
2962 << " auto attrName = " << op.getGetterName(name: namedAttr.name)
2963 << "AttrName(" << builderOpState
2964 << ".name);\n"
2965 " for (auto attr : attributes) {\n"
2966 " if (attr.getName() == attrName) {\n"
2967 " typeAttr = attr.getValue();\n"
2968 " break;\n"
2969 " }\n"
2970 " }\n";
2971 }
2972
2973 // Operands
2974 body << " " << builderOpState << ".addOperands(operands);\n";
2975
2976 // Properties
2977 if (kind == CollectiveBuilderKind::PropStruct)
2978 body << " " << builderOpState
2979 << ".useProperties(const_cast<Properties&>(properties));\n";
2980
2981 // Attributes
2982 body << " " << builderOpState << ".addAttributes(" << attributesName
2983 << ");\n";
2984
2985 // Result types
2986 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
2987 body << " " << builderOpState << ".addTypes({"
2988 << llvm::join(R&: resultTypes, Separator: ", ") << "});\n";
2989}
2990
2991/// Returns a signature of the builder. Updates the context `fctx` to enable
2992/// replacement of $_builder and $_state in the body.
2993static SmallVector<MethodParameter>
2994getBuilderSignature(const Builder &builder) {
2995 ArrayRef<Builder::Parameter> params(builder.getParameters());
2996
2997 // Inject builder and state arguments.
2998 SmallVector<MethodParameter> arguments;
2999 arguments.reserve(N: params.size() + 2);
3000 arguments.emplace_back(Args: "::mlir::OpBuilder &", Args: odsBuilder);
3001 arguments.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
3002
3003 FmtContext fctx;
3004 fctx.withBuilder(subst: odsBuilder);
3005
3006 for (unsigned i = 0, e = params.size(); i < e; ++i) {
3007 // If no name is provided, generate one.
3008 std::optional<StringRef> paramName = params[i].getName();
3009 std::string name =
3010 paramName ? paramName->str() : "odsArg" + std::to_string(val: i);
3011
3012 StringRef defaultValue;
3013 if (std::optional<StringRef> defaultParamValue =
3014 params[i].getDefaultValue())
3015 defaultValue = *defaultParamValue;
3016
3017 arguments.emplace_back(Args: params[i].getCppType(), Args: std::move(name),
3018 Args: tgfmt(fmt: defaultValue, ctx: &fctx));
3019 }
3020
3021 return arguments;
3022}
3023
3024void OpEmitter::genBuilder() {
3025 // Handle custom builders if provided.
3026 for (const Builder &builder : op.getBuilders()) {
3027 SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
3028
3029 std::optional<StringRef> body = builder.getBody();
3030 auto properties = body ? Method::Static : Method::StaticDeclaration;
3031 auto *method =
3032 opClass.addMethod(retType: "void", name: "build", properties, args: std::move(arguments));
3033 if (body)
3034 ERROR_IF_PRUNED(method, "build", op);
3035
3036 if (method)
3037 method->setDeprecated(builder.getDeprecatedMessage());
3038
3039 FmtContext fctx;
3040 fctx.withBuilder(subst: odsBuilder);
3041 fctx.addSubst(placeholder: "_state", subst: builderOpState);
3042 if (body)
3043 method->body() << tgfmt(fmt: *body, ctx: &fctx);
3044 }
3045
3046 // Generate default builders that requires all result type, operands, and
3047 // attributes as parameters.
3048 if (op.skipDefaultBuilders())
3049 return;
3050
3051 // We generate three classes of builders here:
3052 // 1. one having a stand-alone parameter for each operand / attribute, and
3053 genSeparateArgParamBuilder();
3054 // 2. one having an aggregated parameter for all result types / operands /
3055 // [properties / discardable] attributes, and
3056 genCollectiveParamBuilder(kind: CollectiveBuilderKind::AttrDict);
3057 if (emitHelper.hasProperties())
3058 genCollectiveParamBuilder(kind: CollectiveBuilderKind::PropStruct);
3059 // 3. one having a stand-alone parameter for each operand and attribute,
3060 // use the first operand or attribute's type as all result types
3061 // to facilitate different call patterns.
3062 if (op.getNumVariableLengthResults() == 0) {
3063 if (op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) {
3064 genUseOperandAsResultTypeSeparateParamBuilder();
3065 genUseOperandAsResultTypeCollectiveParamBuilder(
3066 kind: CollectiveBuilderKind::AttrDict);
3067 if (emitHelper.hasProperties())
3068 genUseOperandAsResultTypeCollectiveParamBuilder(
3069 kind: CollectiveBuilderKind::PropStruct);
3070 }
3071 if (op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType")) {
3072 genUseAttrAsResultTypeCollectiveParamBuilder(
3073 kind: CollectiveBuilderKind::AttrDict);
3074 genUseAttrAsResultTypeCollectiveParamBuilder(
3075 kind: CollectiveBuilderKind::PropStruct);
3076 }
3077 }
3078}
3079
/// Emits the collective `build(builder, state, resultTypes, operands,
/// [properties,] attributes[, numRegions])`. It asserts on operand and
/// result-type counts (unless the check would be vacuous) and fills the
/// operation state. If the op can infer its result types and has no
/// successors, the result-type-inferring variant is emitted as well.
void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
  int numResults = op.getNumResults();
  int numVariadicResults = op.getNumVariableLengthResults();
  int numNonVariadicResults = numResults - numVariadicResults;

  int numOperands = op.getNumOperands();
  int numVariadicOperands = op.getNumVariableLengthOperands();
  int numNonVariadicOperands = numOperands - numVariadicOperands;

  SmallVector<MethodParameter> paramList;
  // The builder parameter is left unnamed; the body emitted below does not
  // reference it.
  paramList.emplace_back("::mlir::OpBuilder &", "");
  paramList.emplace_back("::mlir::OperationState &", builderOpState);
  paramList.emplace_back("::mlir::TypeRange", "resultTypes");
  paramList.emplace_back("::mlir::ValueRange", "operands");
  if (kind == CollectiveBuilderKind::PropStruct)
    paramList.emplace_back("const Properties &", "properties");
  // Provide default value for `attributes` when its the last parameter
  StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
  StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
                                 ? "discardableAttributes"
                                 : "attributes";
  paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
                         attributesName, attributesDefaultValue);
  if (op.getNumVariadicRegions())
    paramList.emplace_back("unsigned", "numRegions");

  auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
  // If the builder is redundant, skip generating the method
  if (!m)
    return;
  auto &body = m->body();

  // Operands
  // The size assert is skipped when it would degenerate to `>= 0u`.
  if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
    body << "  assert(operands.size()"
         << (numVariadicOperands != 0 ? " >= " : " == ")
         << numNonVariadicOperands
         << "u && \"mismatched number of parameters\");\n";
  body << "  " << builderOpState << ".addOperands(operands);\n";

  // Properties
  if (kind == CollectiveBuilderKind::PropStruct)
    body << "  " << builderOpState
         << ".useProperties(const_cast<Properties&>(properties));\n";

  // Attributes
  body << "  " << builderOpState << ".addAttributes(" << attributesName
       << ");\n";

  // Create the correct number of regions
  if (int numRegions = op.getNumRegions()) {
    body << llvm::formatv(
        "  for (unsigned i = 0; i != {0}; ++i)\n",
        (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
    body << "    (void)" << builderOpState << ".addRegion();\n";
  }

  // Result types
  if (numVariadicResults == 0 || numNonVariadicResults != 0)
    body << "  assert(resultTypes.size()"
         << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
         << "u && \"mismatched number of return types\");\n";
  body << "  " << builderOpState << ".addTypes(resultTypes);\n";

  if (emitHelper.hasNonEmptyPropertiesStruct() &&
      kind == CollectiveBuilderKind::AttrDict) {
    // Initialize the properties from Attributes before invoking the infer
    // function.
    // In the AttrDict form the caller passed only attributes, so the emitted
    // code converts them into the properties struct via
    // setOpPropertiesFromAttribute.
    body << formatv(R"(
  if (!attributes.empty()) {
    ::mlir::OpaqueProperties properties =
      &{1}.getOrAddProperties<{0}::Properties>();
    std::optional<::mlir::RegisteredOperationName> info =
      {1}.name.getRegisteredInfo();
    if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
        {1}.attributes.getDictionary({1}.getContext()), nullptr)))
      ::llvm::report_fatal_error("Property conversion failed.");
  })",
                    opClass.getClassName(), builderOpState);
  }

  // Generate builder that infers type too.
  // TODO: Expand to handle successors.
  if (canInferType(op) && op.getNumSuccessors() == 0)
    genInferredTypeCollectiveParamBuilder(kind);
}
3166
3167void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
3168 llvm::StringSet<> &inferredAttributes,
3169 SmallVectorImpl<std::string> &resultTypeNames,
3170 TypeParamKind typeParamKind,
3171 AttrParamKind attrParamKind) {
3172 resultTypeNames.clear();
3173 auto numResults = op.getNumResults();
3174 resultTypeNames.reserve(N: numResults);
3175
3176 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: odsBuilder);
3177 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
3178
3179 switch (typeParamKind) {
3180 case TypeParamKind::None:
3181 break;
3182 case TypeParamKind::Separate: {
3183 // Add parameters for all return types
3184 for (int i = 0; i < numResults; ++i) {
3185 const auto &result = op.getResult(index: i);
3186 std::string resultName = std::string(result.name);
3187 if (resultName.empty())
3188 resultName = std::string(formatv(Fmt: "resultType{0}", Vals&: i));
3189
3190 StringRef type =
3191 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
3192
3193 paramList.emplace_back(Args&: type, Args&: resultName, Args: result.isOptional());
3194 resultTypeNames.emplace_back(Args: std::move(resultName));
3195 }
3196 } break;
3197 case TypeParamKind::Collective: {
3198 paramList.emplace_back(Args: "::mlir::TypeRange", Args: "resultTypes");
3199 resultTypeNames.push_back(Elt: "resultTypes");
3200 } break;
3201 }
3202
3203 // Add parameters for all arguments (operands and attributes).
3204 // Track "attr-like" (property and attribute) optional values separate from
3205 // attributes themselves so that the disambiguation code can look at the first
3206 // attribute specifically when determining where to trim the optional-value
3207 // list to avoid ambiguity while preserving the ability of all-property ops to
3208 // use default parameters.
3209 int defaultValuedAttrLikeStartIndex = op.getNumArgs();
3210 int defaultValuedAttrStartIndex = op.getNumArgs();
3211 // Successors and variadic regions go at the end of the parameter list, so no
3212 // default arguments are possible.
3213 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
3214 if (!hasTrailingParams) {
3215 // Calculate the start index from which we can attach default values in the
3216 // builder declaration.
3217 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
3218 auto *namedAttr =
3219 llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(Val: op.getArg(index: i));
3220 auto *namedProperty =
3221 llvm::dyn_cast_if_present<tblgen::NamedProperty *>(Val: op.getArg(index: i));
3222 if (namedProperty) {
3223 Property prop = namedProperty->prop;
3224 if (!prop.hasDefaultValue())
3225 break;
3226 defaultValuedAttrLikeStartIndex = i;
3227 continue;
3228 }
3229 if (!namedAttr)
3230 break;
3231
3232 Attribute attr = namedAttr->attr;
3233 // TODO: Currently we can't differentiate between optional meaning do not
3234 // verify/not always error if missing or optional meaning need not be
3235 // specified in builder. Expand isOptional once we can differentiate.
3236 if (!attr.hasDefaultValue() && !attr.isDerivedAttr())
3237 break;
3238
3239 // Creating an APInt requires us to provide bitwidth, value, and
3240 // signedness, which is complicated compared to others. Similarly
3241 // for APFloat.
3242 // TODO: Adjust the 'returnType' field of such attributes
3243 // to support them.
3244 StringRef retType = namedAttr->attr.getReturnType();
3245 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
3246 break;
3247
3248 defaultValuedAttrLikeStartIndex = i;
3249 defaultValuedAttrStartIndex = i;
3250 }
3251 }
3252
3253 // Check if parameters besides default valued one are enough to distinguish
3254 // between builders with wrapped and unwrapped arguments.
3255 bool hasBuilderAmbiguity = true;
3256 for (const auto &arg : op.getArgs()) {
3257 auto *namedAttr = dyn_cast<NamedAttribute *>(Val: arg);
3258 if (!namedAttr)
3259 continue;
3260 Attribute attr = namedAttr->attr;
3261 if (attr.hasDefaultValue() || attr.isDerivedAttr())
3262 continue;
3263
3264 if (attrParamKind != AttrParamKind::WrappedAttr ||
3265 !canUseUnwrappedRawValue(attr))
3266 continue;
3267
3268 hasBuilderAmbiguity = false;
3269 break;
3270 }
3271
3272 // Avoid generating build methods that are ambiguous due to default values by
3273 // requiring at least one attribute.
3274 if (defaultValuedAttrStartIndex < op.getNumArgs()) {
3275 // TODO: This should have been possible as a cast<NamedAttribute> but
3276 // required template instantiations is not yet defined for the tblgen helper
3277 // classes.
3278 auto *namedAttr =
3279 cast<NamedAttribute *>(Val: op.getArg(index: defaultValuedAttrStartIndex));
3280 Attribute attr = namedAttr->attr;
3281 if ((attrParamKind == AttrParamKind::WrappedAttr &&
3282 canUseUnwrappedRawValue(attr) && hasBuilderAmbiguity) ||
3283 (attrParamKind == AttrParamKind::UnwrappedValue &&
3284 !canUseUnwrappedRawValue(attr) && hasBuilderAmbiguity)) {
3285 ++defaultValuedAttrStartIndex;
3286 defaultValuedAttrLikeStartIndex = defaultValuedAttrStartIndex;
3287 }
3288 }
3289
3290 /// Collect any inferred attributes.
3291 for (const NamedTypeConstraint &operand : op.getOperands()) {
3292 if (operand.isVariadicOfVariadic()) {
3293 inferredAttributes.insert(
3294 key: operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
3295 }
3296 }
3297
3298 FmtContext fctx;
3299 fctx.withBuilder(subst: odsBuilder);
3300
3301 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
3302 Argument arg = op.getArg(index: i);
3303 if (const auto *operand =
3304 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: arg)) {
3305 StringRef type;
3306 if (operand->isVariadicOfVariadic())
3307 type = "::llvm::ArrayRef<::mlir::ValueRange>";
3308 else if (operand->isVariadic())
3309 type = "::mlir::ValueRange";
3310 else
3311 type = "::mlir::Value";
3312
3313 paramList.emplace_back(Args&: type, Args: getArgumentName(op, index: numOperands++),
3314 Args: operand->isOptional());
3315 continue;
3316 }
3317 if (auto *propArg = llvm::dyn_cast_if_present<NamedProperty *>(Val&: arg)) {
3318 const Property &prop = propArg->prop;
3319 StringRef type = prop.getInterfaceType();
3320 std::string defaultValue;
3321 if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) {
3322 defaultValue = tgfmt(fmt: prop.getDefaultValue(), ctx: &fctx);
3323 }
3324 bool isOptional = prop.hasDefaultValue();
3325 paramList.emplace_back(Args&: type, Args&: propArg->name, Args: StringRef(defaultValue),
3326 Args&: isOptional);
3327 continue;
3328 }
3329 const NamedAttribute &namedAttr = *cast<NamedAttribute *>(Val&: arg);
3330 const Attribute &attr = namedAttr.attr;
3331
3332 // Inferred attributes don't need to be added to the param list.
3333 if (inferredAttributes.contains(key: namedAttr.name))
3334 continue;
3335
3336 StringRef type;
3337 switch (attrParamKind) {
3338 case AttrParamKind::WrappedAttr:
3339 type = attr.getStorageType();
3340 break;
3341 case AttrParamKind::UnwrappedValue:
3342 if (canUseUnwrappedRawValue(attr))
3343 type = attr.getReturnType();
3344 else
3345 type = attr.getStorageType();
3346 break;
3347 }
3348
3349 // Attach default value if requested and possible.
3350 std::string defaultValue;
3351 if (i >= defaultValuedAttrStartIndex) {
3352 if (attrParamKind == AttrParamKind::UnwrappedValue &&
3353 canUseUnwrappedRawValue(attr))
3354 defaultValue += tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx);
3355 else
3356 defaultValue += "nullptr";
3357 }
3358 paramList.emplace_back(Args&: type, Args: namedAttr.name, Args: StringRef(defaultValue),
3359 Args: attr.isOptional());
3360 }
3361
3362 /// Insert parameters for each successor.
3363 for (const NamedSuccessor &succ : op.getSuccessors()) {
3364 StringRef type =
3365 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
3366 paramList.emplace_back(Args&: type, Args: succ.name);
3367 }
3368
3369 /// Insert parameters for variadic regions.
3370 for (const NamedRegion &region : op.getRegions())
3371 if (region.isVariadic())
3372 paramList.emplace_back(Args: "unsigned",
3373 Args: llvm::formatv(Fmt: "{0}Count", Vals: region.name).str());
3374}
3375
3376void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
3377 MethodBody &body, llvm::StringSet<> &inferredAttributes,
3378 bool isRawValueAttr) {
3379 // Push all operands to the result.
3380 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
3381 std::string argName = getArgumentName(op, index: i);
3382 const NamedTypeConstraint &operand = op.getOperand(index: i);
3383 if (operand.constraint.isVariadicOfVariadic()) {
3384 body << " for (::mlir::ValueRange range : " << argName << ")\n "
3385 << builderOpState << ".addOperands(range);\n";
3386
3387 // Add the segment attribute.
3388 body << " {\n"
3389 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
3390 << " for (::mlir::ValueRange range : " << argName << ")\n"
3391 << " rangeSegments.push_back(range.size());\n"
3392 << " auto rangeAttr = " << odsBuilder
3393 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3394 if (op.getDialect().usePropertiesForAttributes()) {
3395 body << " " << builderOpState << ".getOrAddProperties<Properties>()."
3396 << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
3397 << " = rangeAttr;";
3398 } else {
3399 body << " " << builderOpState << ".addAttribute("
3400 << op.getGetterName(
3401 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
3402 << "AttrName(" << builderOpState << ".name), rangeAttr);";
3403 }
3404 body << " }\n";
3405 continue;
3406 }
3407
3408 if (operand.isOptional())
3409 body << " if (" << argName << ")\n ";
3410 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
3411 }
3412
3413 // If the operation has the operand segment size attribute, add it here.
3414 auto emitSegment = [&]() {
3415 interleaveComma(c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) {
3416 const NamedTypeConstraint &operand = op.getOperand(index: i);
3417 if (!operand.isVariableLength()) {
3418 body << "1";
3419 return;
3420 }
3421
3422 std::string operandName = getArgumentName(op, index: i);
3423 if (operand.isOptional()) {
3424 body << "(" << operandName << " ? 1 : 0)";
3425 } else if (operand.isVariadicOfVariadic()) {
3426 body << llvm::formatv(
3427 Fmt: "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3428 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3429 "static_cast<int32_t>(range.size()); }))",
3430 Vals&: operandName);
3431 } else {
3432 body << "static_cast<int32_t>(" << getArgumentName(op, index: i) << ".size())";
3433 }
3434 });
3435 };
3436 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
3437 std::string sizes = op.getGetterName(name: operandSegmentAttrName);
3438 if (op.getDialect().usePropertiesForAttributes()) {
3439 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3440 emitSegment();
3441 body << "}), " << builderOpState
3442 << ".getOrAddProperties<Properties>()."
3443 "operandSegmentSizes.begin());\n";
3444 } else {
3445 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
3446 << builderOpState << ".name), "
3447 << "odsBuilder.getDenseI32ArrayAttr({";
3448 emitSegment();
3449 body << "}));\n";
3450 }
3451 }
3452
3453 // Push all properties to the result.
3454 for (const auto &namedProp : op.getProperties()) {
3455 // Use the setter from the Properties struct since the conversion from the
3456 // interface type (used in the builder argument) to the storage type (used
3457 // in the state) is not necessarily trivial.
3458 std::string setterName = op.getSetterName(name: namedProp.name);
3459 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1}({2});\n",
3460 Vals: builderOpState, Vals&: setterName, Vals: namedProp.name);
3461 }
3462 // Push all attributes to the result.
3463 for (const auto &namedAttr : op.getAttributes()) {
3464 auto &attr = namedAttr.attr;
3465 if (attr.isDerivedAttr() || inferredAttributes.contains(key: namedAttr.name))
3466 continue;
3467
3468 // TODO: The wrapping of optional is different for default or not, so don't
3469 // unwrap for default ones that would fail below.
3470 bool emitNotNullCheck =
3471 (attr.isOptional() && !attr.hasDefaultValue()) ||
3472 (attr.hasDefaultValue() && !isRawValueAttr) ||
3473 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3474 // the constant materialization is only for true case.
3475 (isRawValueAttr && attr.getAttrDefName() == "UnitAttr");
3476 if (emitNotNullCheck)
3477 body.indent() << formatv(Fmt: "if ({0}) ", Vals: namedAttr.name) << "{\n";
3478
3479 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
3480 // If this is a raw value, then we need to wrap it in an Attribute
3481 // instance.
3482 FmtContext fctx;
3483 fctx.withBuilder(subst: "odsBuilder");
3484 if (op.getDialect().usePropertiesForAttributes()) {
3485 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3486 Vals: builderOpState, Vals: namedAttr.name,
3487 Vals: constBuildAttrFromParam(attr, fctx, paramName: namedAttr.name));
3488 } else {
3489 body << formatv(Fmt: " {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3490 Vals: builderOpState, Vals: op.getGetterName(name: namedAttr.name),
3491 Vals: constBuildAttrFromParam(attr, fctx, paramName: namedAttr.name));
3492 }
3493 } else {
3494 if (op.getDialect().usePropertiesForAttributes()) {
3495 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3496 Vals: builderOpState, Vals: namedAttr.name);
3497 } else {
3498 body << formatv(Fmt: " {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3499 Vals: builderOpState, Vals: op.getGetterName(name: namedAttr.name),
3500 Vals: namedAttr.name);
3501 }
3502 }
3503 if (emitNotNullCheck)
3504 body.unindent() << " }\n";
3505 }
3506
3507 // Create the correct number of regions.
3508 for (const NamedRegion &region : op.getRegions()) {
3509 if (region.isVariadic())
3510 body << formatv(Fmt: " for (unsigned i = 0; i < {0}Count; ++i)\n ",
3511 Vals: region.name);
3512
3513 body << " (void)" << builderOpState << ".addRegion();\n";
3514 }
3515
3516 // Push all successors to the result.
3517 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
3518 body << formatv(Fmt: " {0}.addSuccessors({1});\n", Vals: builderOpState,
3519 Vals: namedSuccessor.name);
3520 }
3521}
3522
3523void OpEmitter::genCanonicalizerDecls() {
3524 bool hasCanonicalizeMethod = def.getValueAsBit(FieldName: "hasCanonicalizeMethod");
3525 if (hasCanonicalizeMethod) {
3526 // static LogicResult FooOp::
3527 // canonicalize(FooOp op, PatternRewriter &rewriter);
3528 SmallVector<MethodParameter> paramList;
3529 paramList.emplace_back(Args: op.getCppClassName(), Args: "op");
3530 paramList.emplace_back(Args: "::mlir::PatternRewriter &", Args: "rewriter");
3531 auto *m = opClass.declareStaticMethod(retType: "::llvm::LogicalResult",
3532 name: "canonicalize", args: std::move(paramList));
3533 ERROR_IF_PRUNED(m, "canonicalize", op);
3534 }
3535
3536 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3537 // or if using a 'canonicalize' method.
3538 bool hasCanonicalizer = def.getValueAsBit(FieldName: "hasCanonicalizer");
3539 if (!hasCanonicalizeMethod && !hasCanonicalizer)
3540 return;
3541
3542 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3543 // method, but not implementing 'getCanonicalizationPatterns' manually.
3544 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
3545
3546 // Add a signature for getCanonicalizationPatterns if implemented by the
3547 // dialect or if synthesized to call 'canonicalize'.
3548 SmallVector<MethodParameter> paramList;
3549 paramList.emplace_back(Args: "::mlir::RewritePatternSet &", Args: "results");
3550 paramList.emplace_back(Args: "::mlir::MLIRContext *", Args: "context");
3551 auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
3552 auto *method = opClass.addMethod(retType: "void", name: "getCanonicalizationPatterns", properties: kind,
3553 args: std::move(paramList));
3554
3555 // If synthesizing the method, fill it.
3556 if (hasBody) {
3557 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
3558 method->body() << " results.add(canonicalize);\n";
3559 }
3560}
3561
3562void OpEmitter::genFolderDecls() {
3563 if (!op.hasFolder())
3564 return;
3565
3566 SmallVector<MethodParameter> paramList;
3567 paramList.emplace_back(Args: "FoldAdaptor", Args: "adaptor");
3568
3569 StringRef retType;
3570 bool hasSingleResult =
3571 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
3572 if (hasSingleResult) {
3573 retType = "::mlir::OpFoldResult";
3574 } else {
3575 paramList.emplace_back(Args: "::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3576 Args: "results");
3577 retType = "::llvm::LogicalResult";
3578 }
3579
3580 auto *m = opClass.declareMethod(retType, name: "fold", args: std::move(paramList));
3581 ERROR_IF_PRUNED(m, "fold", op);
3582}
3583
3584void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
3585 Interface interface = opTrait->getInterface();
3586
3587 // Get the set of methods that should always be declared.
3588 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
3589 llvm::StringSet<> alwaysDeclaredMethods;
3590 alwaysDeclaredMethods.insert_range(R&: alwaysDeclaredMethodsVec);
3591
3592 for (const InterfaceMethod &method : interface.getMethods()) {
3593 // Don't declare if the method has a body.
3594 if (method.getBody())
3595 continue;
3596 // Don't declare if the method has a default implementation and the op
3597 // didn't request that it always be declared.
3598 if (method.getDefaultImplementation() &&
3599 !alwaysDeclaredMethods.count(Key: method.getName()))
3600 continue;
3601 // Interface methods are allowed to overlap with existing methods, so don't
3602 // check if pruned.
3603 (void)genOpInterfaceMethod(method);
3604 }
3605}
3606
3607Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
3608 bool declaration) {
3609 SmallVector<MethodParameter> paramList;
3610 for (const InterfaceMethod::Argument &arg : method.getArguments())
3611 paramList.emplace_back(Args: arg.type, Args: arg.name);
3612
3613 auto props = (method.isStatic() ? Method::Static : Method::None) |
3614 (declaration ? Method::Declaration : Method::None);
3615 return opClass.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props,
3616 args: std::move(paramList));
3617}
3618
3619void OpEmitter::genOpInterfaceMethods() {
3620 for (const auto &trait : op.getTraits()) {
3621 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(Val: &trait))
3622 if (opTrait->shouldDeclareMethods())
3623 genOpInterfaceMethods(opTrait);
3624 }
3625}
3626
3627void OpEmitter::genSideEffectInterfaceMethods() {
3628 enum EffectKind { Operand, Result, Symbol, Static };
3629 struct EffectLocation {
3630 /// The effect applied.
3631 SideEffect effect;
3632
3633 /// The index if the kind is not static.
3634 unsigned index;
3635
3636 /// The kind of the location.
3637 unsigned kind;
3638 };
3639
3640 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
3641 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
3642 unsigned index, unsigned kind) {
3643 for (auto decorator : decorators)
3644 if (SideEffect *effect = dyn_cast<SideEffect>(Val: &decorator)) {
3645 opClass.addTrait(trait: effect->getInterfaceTrait());
3646 interfaceEffects[effect->getBaseEffectName()].push_back(
3647 Elt: EffectLocation{.effect: *effect, .index: index, .kind: kind});
3648 }
3649 };
3650
3651 // Collect effects that were specified via:
3652 /// Traits.
3653 for (const auto &trait : op.getTraits()) {
3654 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(Val: &trait);
3655 if (!opTrait)
3656 continue;
3657 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
3658 for (auto decorator : opTrait->getEffects())
3659 effects.push_back(Elt: EffectLocation{.effect: cast<SideEffect>(Val&: decorator),
3660 /*index=*/0, .kind: EffectKind::Static});
3661 }
3662 /// Attributes and Operands.
3663 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
3664 Argument arg = op.getArg(index: i);
3665 if (isa<NamedTypeConstraint *>(Val: arg)) {
3666 resolveDecorators(op.getArgDecorators(index: i), operandIt, EffectKind::Operand);
3667 ++operandIt;
3668 continue;
3669 }
3670 if (isa<NamedProperty *>(Val: arg))
3671 continue;
3672 const NamedAttribute *attr = cast<NamedAttribute *>(Val&: arg);
3673 if (attr->attr.getBaseAttr().isSymbolRefAttr())
3674 resolveDecorators(op.getArgDecorators(index: i), i, EffectKind::Symbol);
3675 }
3676 /// Results.
3677 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3678 resolveDecorators(op.getResultDecorators(index: i), i, EffectKind::Result);
3679
3680 // The code used to add an effect instance.
3681 // {0}: The effect class.
3682 // {1}: Optional value or symbol reference.
3683 // {2}: The side effect stage.
3684 // {3}: Does this side effect act on every single value of resource.
3685 // {4}: The resource class.
3686 const char *addEffectCode =
3687 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3688
3689 for (auto &it : interfaceEffects) {
3690 // Generate the 'getEffects' method.
3691 std::string type = llvm::formatv(Fmt: "::llvm::SmallVectorImpl<::mlir::"
3692 "SideEffects::EffectInstance<{0}>> &",
3693 Vals: it.first())
3694 .str();
3695 auto *getEffects = opClass.addMethod(retType: "void", name: "getEffects",
3696 args: MethodParameter(type, "effects"));
3697 ERROR_IF_PRUNED(getEffects, "getEffects", op);
3698 auto &body = getEffects->body();
3699
3700 // Add effect instances for each of the locations marked on the operation.
3701 for (auto &location : it.second) {
3702 StringRef effect = location.effect.getName();
3703 StringRef resource = location.effect.getResource();
3704 int stage = (int)location.effect.getStage();
3705 bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
3706 if (location.kind == EffectKind::Static) {
3707 // A static instance has no attached value.
3708 body << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "", Vals&: stage,
3709 Vals&: effectOnFullRegion, Vals&: resource)
3710 .str();
3711 } else if (location.kind == EffectKind::Symbol) {
3712 // A symbol reference requires adding the proper attribute.
3713 const auto *attr = cast<NamedAttribute *>(Val: op.getArg(index: location.index));
3714 std::string argName = op.getGetterName(name: attr->name);
3715 if (attr->attr.isOptional()) {
3716 body << " if (auto symbolRef = " << argName << "Attr())\n "
3717 << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "symbolRef, ", Vals&: stage,
3718 Vals&: effectOnFullRegion, Vals&: resource)
3719 .str();
3720 } else {
3721 body << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: argName + "Attr(), ",
3722 Vals&: stage, Vals&: effectOnFullRegion, Vals&: resource)
3723 .str();
3724 }
3725 } else {
3726 // Otherwise this is an operand/result, so we need to attach the Value.
3727 body << " {\n auto valueRange = getODS"
3728 << (location.kind == EffectKind::Operand ? "Operand" : "Result")
3729 << "IndexAndLength(" << location.index << ");\n"
3730 << " for (unsigned idx = valueRange.first; idx < "
3731 "valueRange.first"
3732 << " + valueRange.second; idx++) {\n "
3733 << llvm::formatv(Fmt: addEffectCode, Vals&: effect,
3734 Vals: (location.kind == EffectKind::Operand
3735 ? "&getOperation()->getOpOperand(idx), "
3736 : "getOperation()->getOpResult(idx), "),
3737 Vals&: stage, Vals&: effectOnFullRegion, Vals&: resource)
3738 << " }\n }\n";
3739 }
3740 }
3741 }
3742}
3743
3744void OpEmitter::genTypeInterfaceMethods() {
3745 if (!op.allResultTypesKnown())
3746 return;
3747 // Generate 'inferReturnTypes' method declaration using the interface method
3748 // declared in 'InferTypeOpInterface' op interface.
3749 const auto *trait =
3750 cast<InterfaceTrait>(Val: op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait"));
3751 Interface interface = trait->getInterface();
3752 Method *method = [&]() -> Method * {
3753 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
3754 if (interfaceMethod.getName() == "inferReturnTypes") {
3755 return genOpInterfaceMethod(method: interfaceMethod, /*declaration=*/false);
3756 }
3757 }
3758 assert(0 && "unable to find inferReturnTypes interface method");
3759 return nullptr;
3760 }();
3761 ERROR_IF_PRUNED(method, "inferReturnTypes", op);
3762 auto &body = method->body();
3763 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
3764
3765 FmtContext fctx;
3766 fctx.withBuilder(subst: "odsBuilder");
3767 fctx.addSubst(placeholder: "_ctxt", subst: "context");
3768 body << " ::mlir::Builder odsBuilder(context);\n";
3769
3770 // Preprocessing stage to verify all accesses to operands are valid.
3771 int maxAccessedIndex = -1;
3772 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3773 const InferredResultType &infer = op.getInferredResultType(index: i);
3774 if (!infer.isArg())
3775 continue;
3776 Operator::OperandOrAttribute arg =
3777 op.getArgToOperandOrAttribute(index: infer.getIndex());
3778 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3779 maxAccessedIndex =
3780 std::max(a: maxAccessedIndex, b: arg.operandOrAttributeIndex());
3781 }
3782 }
3783 if (maxAccessedIndex != -1) {
3784 body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
3785 body << " return ::mlir::failure();\n";
3786 }
3787
3788 // Process the type inference graph in topological order, starting from types
3789 // that are always fully-inferred: operands and results with constructible
3790 // types. The type inference graph here will always be a DAG, so this gives
3791 // us the correct order for generating the types. -1 is a placeholder to
3792 // indicate the type for a result has not been generated.
3793 SmallVector<int> constructedIndices(op.getNumResults(), -1);
3794 int inferredTypeIdx = 0;
3795 for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
3796 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3797 if (constructedIndices[i] >= 0)
3798 continue;
3799 const InferredResultType &infer = op.getInferredResultType(index: i);
3800 std::string typeStr;
3801 if (infer.isArg()) {
3802 // If this is an operand, just index into operand list to access the
3803 // type.
3804 Operator::OperandOrAttribute arg =
3805 op.getArgToOperandOrAttribute(index: infer.getIndex());
3806 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3807 typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
3808 "].getType()")
3809 .str();
3810
3811 // If this is an attribute, index into the attribute dictionary.
3812 } else {
3813 auto *attr =
3814 cast<NamedAttribute *>(Val: op.getArg(index: arg.operandOrAttributeIndex()));
3815 body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3816 << " = ";
3817 if (op.getDialect().usePropertiesForAttributes()) {
3818 body << "(properties ? properties.as<Properties *>()->"
3819 << attr->name
3820 << " : "
3821 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3822 "get(\"" +
3823 attr->name + "\")));\n";
3824 } else {
3825 body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3826 "get(\"" +
3827 attr->name + "\"));\n";
3828 }
3829 body << " if (!odsInferredTypeAttr" << inferredTypeIdx
3830 << ") return ::mlir::failure();\n";
3831 typeStr =
3832 ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
3833 .str();
3834 }
3835 } else if (std::optional<StringRef> builder =
3836 op.getResult(index: infer.getResultIndex())
3837 .constraint.getBuilderCall()) {
3838 typeStr = tgfmt(fmt: *builder, ctx: &fctx).str();
3839 } else if (int index = constructedIndices[infer.getResultIndex()];
3840 index >= 0) {
3841 typeStr = ("odsInferredType" + Twine(index)).str();
3842 } else {
3843 continue;
3844 }
3845 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
3846 << tgfmt(fmt: infer.getTransformer(), ctx: &fctx.withSelf(subst: typeStr)) << ";\n";
3847 constructedIndices[i] = inferredTypeIdx - 1;
3848 }
3849 }
3850 for (auto [i, index] : llvm::enumerate(First&: constructedIndices))
3851 body << " inferredReturnTypes[" << i << "] = odsInferredType" << index
3852 << ";\n";
3853 body << " return ::mlir::success();";
3854}
3855
3856void OpEmitter::genParser() {
3857 if (hasStringAttribute(record: def, fieldName: "assemblyFormat"))
3858 return;
3859
3860 if (!def.getValueAsBit(FieldName: "hasCustomAssemblyFormat"))
3861 return;
3862
3863 SmallVector<MethodParameter> paramList;
3864 paramList.emplace_back(Args: "::mlir::OpAsmParser &", Args: "parser");
3865 paramList.emplace_back(Args: "::mlir::OperationState &", Args: "result");
3866
3867 auto *method = opClass.declareStaticMethod(retType: "::mlir::ParseResult", name: "parse",
3868 args: std::move(paramList));
3869 ERROR_IF_PRUNED(method, "parse", op);
3870}
3871
3872void OpEmitter::genPrinter() {
3873 if (hasStringAttribute(record: def, fieldName: "assemblyFormat"))
3874 return;
3875
3876 // Check to see if this op uses a c++ format.
3877 if (!def.getValueAsBit(FieldName: "hasCustomAssemblyFormat"))
3878 return;
3879 auto *method = opClass.declareMethod(
3880 retType: "void", name: "print", args: MethodParameter("::mlir::OpAsmPrinter &", "p"));
3881 ERROR_IF_PRUNED(method, "print", op);
3882}
3883
3884void OpEmitter::genVerifier() {
3885 auto *implMethod =
3886 opClass.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariantsImpl");
3887 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
3888 auto &implBody = implMethod->body();
3889 bool useProperties = emitHelper.hasProperties();
3890
3891 populateSubstitutions(emitHelper, ctx&: verifyCtx);
3892 genPropertyVerifier(emitHelper, ctx&: verifyCtx, body&: implBody, staticVerifierEmitter);
3893 genAttributeVerifier(emitHelper, ctx&: verifyCtx, body&: implBody, staticVerifierEmitter,
3894 useProperties);
3895 genOperandResultVerifier(body&: implBody, values: op.getOperands(), valueKind: "operand");
3896 genOperandResultVerifier(body&: implBody, values: op.getResults(), valueKind: "result");
3897
3898 for (auto &trait : op.getTraits()) {
3899 if (auto *t = dyn_cast<tblgen::PredTrait>(Val: &trait)) {
3900 implBody << tgfmt(fmt: " if (!($0))\n "
3901 "return emitOpError(\"failed to verify that $1\");\n",
3902 ctx: &verifyCtx, vals: tgfmt(fmt: t->getPredTemplate(), ctx: &verifyCtx),
3903 vals: t->getSummary());
3904 }
3905 }
3906
3907 genRegionVerifier(body&: implBody);
3908 genSuccessorVerifier(body&: implBody);
3909
3910 implBody << " return ::mlir::success();\n";
3911
3912 // TODO: Some places use the `verifyInvariants` to do operation verification.
3913 // This may not act as their expectation because this doesn't call any
3914 // verifiers of native/interface traits. Needs to review those use cases and
3915 // see if we should use the mlir::verify() instead.
3916 auto *method = opClass.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariants");
3917 ERROR_IF_PRUNED(method, "verifyInvariants", op);
3918 auto &body = method->body();
3919 if (def.getValueAsBit(FieldName: "hasVerifier")) {
3920 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3921 "::mlir::succeeded(verify()))\n";
3922 body << " return ::mlir::success();\n";
3923 body << " return ::mlir::failure();";
3924 } else {
3925 body << " return verifyInvariantsImpl();";
3926 }
3927}
3928
3929void OpEmitter::genCustomVerifier() {
3930 if (def.getValueAsBit(FieldName: "hasVerifier")) {
3931 auto *method = opClass.declareMethod(retType: "::llvm::LogicalResult", name: "verify");
3932 ERROR_IF_PRUNED(method, "verify", op);
3933 }
3934
3935 if (def.getValueAsBit(FieldName: "hasRegionVerifier")) {
3936 auto *method =
3937 opClass.declareMethod(retType: "::llvm::LogicalResult", name: "verifyRegions");
3938 ERROR_IF_PRUNED(method, "verifyRegions", op);
3939 }
3940}
3941
3942void OpEmitter::genOperandResultVerifier(MethodBody &body,
3943 Operator::const_value_range values,
3944 StringRef valueKind) {
3945 // Check that an optional value is at most 1 element.
3946 //
3947 // {0}: Value index.
3948 // {1}: "operand" or "result"
3949 const char *const verifyOptional = R"(
3950 if (valueGroup{0}.size() > 1) {
3951 return emitOpError("{1} group starting at #") << index
3952 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
3953 }
3954)";
3955 // Check the types of a range of values.
3956 //
3957 // {0}: Value index.
3958 // {1}: Type constraint function.
3959 // {2}: "operand" or "result"
3960 const char *const verifyValues = R"(
3961 for (auto v : valueGroup{0}) {
3962 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
3963 return ::mlir::failure();
3964 }
3965)";
3966
3967 const auto canSkip = [](const NamedTypeConstraint &value) {
3968 return !value.hasPredicate() && !value.isOptional() &&
3969 !value.isVariadicOfVariadic();
3970 };
3971 if (values.empty() || llvm::all_of(Range&: values, P: canSkip))
3972 return;
3973
3974 FmtContext fctx;
3975
3976 body << " {\n unsigned index = 0; (void)index;\n";
3977
3978 for (const auto &staticValue : llvm::enumerate(First&: values)) {
3979 const NamedTypeConstraint &value = staticValue.value();
3980
3981 bool hasPredicate = value.hasPredicate();
3982 bool isOptional = value.isOptional();
3983 bool isVariadicOfVariadic = value.isVariadicOfVariadic();
3984 if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
3985 continue;
3986 body << formatv(Fmt: " auto valueGroup{2} = getODS{0}{1}s({2});\n",
3987 // Capitalize the first letter to match the function name
3988 Vals: valueKind.substr(Start: 0, N: 1).upper(), Vals: valueKind.substr(Start: 1),
3989 Vals: staticValue.index());
3990
3991 // If the constraint is optional check that the value group has at most 1
3992 // value.
3993 if (isOptional) {
3994 body << formatv(Fmt: verifyOptional, Vals: staticValue.index(), Vals&: valueKind);
3995 } else if (isVariadicOfVariadic) {
3996 body << formatv(
3997 Fmt: " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
3998 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
3999 " return ::mlir::failure();\n",
4000 Vals: value.constraint.getVariadicOfVariadicSegmentSizeAttr(), Vals: value.name,
4001 Vals: staticValue.index());
4002 }
4003
4004 // Otherwise, if there is no predicate there is nothing left to do.
4005 if (!hasPredicate)
4006 continue;
4007 // Emit a loop to check all the dynamic values in the pack.
4008 StringRef constraintFn =
4009 staticVerifierEmitter.getTypeConstraintFn(constraint: value.constraint);
4010 body << formatv(Fmt: verifyValues, Vals: staticValue.index(), Vals&: constraintFn, Vals&: valueKind);
4011 }
4012
4013 body << " }\n";
4014}
4015
4016void OpEmitter::genRegionVerifier(MethodBody &body) {
4017 /// Code to verify a region.
4018 ///
4019 /// {0}: Getter for the regions.
4020 /// {1}: The region constraint.
4021 /// {2}: The region's name.
4022 /// {3}: The region description.
4023 const char *const verifyRegion = R"(
4024 for (auto &region : {0})
4025 if (::mlir::failed({1}(*this, region, "{2}", index++)))
4026 return ::mlir::failure();
4027)";
4028 /// Get a single region.
4029 ///
4030 /// {0}: The region's index.
4031 const char *const getSingleRegion =
4032 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
4033
4034 // If we have no regions, there is nothing more to do.
4035 const auto canSkip = [](const NamedRegion &region) {
4036 return region.constraint.getPredicate().isNull();
4037 };
4038 auto regions = op.getRegions();
4039 if (regions.empty() && llvm::all_of(Range&: regions, P: canSkip))
4040 return;
4041
4042 body << " {\n unsigned index = 0; (void)index;\n";
4043 for (const auto &it : llvm::enumerate(First&: regions)) {
4044 const auto &region = it.value();
4045 if (canSkip(region))
4046 continue;
4047
4048 auto getRegion = region.isVariadic()
4049 ? formatv(Fmt: "{0}()", Vals: op.getGetterName(name: region.name)).str()
4050 : formatv(Fmt: getSingleRegion, Vals: it.index()).str();
4051 auto constraintFn =
4052 staticVerifierEmitter.getRegionConstraintFn(constraint: region.constraint);
4053 body << formatv(Fmt: verifyRegion, Vals&: getRegion, Vals&: constraintFn, Vals: region.name);
4054 }
4055 body << " }\n";
4056}
4057
4058void OpEmitter::genSuccessorVerifier(MethodBody &body) {
4059 const char *const verifySuccessor = R"(
4060 for (auto *successor : {0})
4061 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
4062 return ::mlir::failure();
4063)";
4064 /// Get a single successor.
4065 ///
4066 /// {0}: The successor's name.
4067 const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())";
4068
4069 // If we have no successors, there is nothing more to do.
4070 const auto canSkip = [](const NamedSuccessor &successor) {
4071 return successor.constraint.getPredicate().isNull();
4072 };
4073 auto successors = op.getSuccessors();
4074 if (successors.empty() && llvm::all_of(Range&: successors, P: canSkip))
4075 return;
4076
4077 body << " {\n unsigned index = 0; (void)index;\n";
4078
4079 for (auto it : llvm::enumerate(First&: successors)) {
4080 const auto &successor = it.value();
4081 if (canSkip(successor))
4082 continue;
4083
4084 auto getSuccessor =
4085 formatv(Fmt: successor.isVariadic() ? "{0}()" : getSingleSuccessor,
4086 Vals: successor.name)
4087 .str();
4088 auto constraintFn =
4089 staticVerifierEmitter.getSuccessorConstraintFn(constraint: successor.constraint);
4090 body << formatv(Fmt: verifySuccessor, Vals&: getSuccessor, Vals&: constraintFn,
4091 Vals: successor.name);
4092 }
4093 body << " }\n";
4094}
4095
4096/// Add a size count trait to the given operation class.
4097static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
4098 int numTotal, int numVariadic) {
4099 if (numVariadic != 0) {
4100 if (numTotal == numVariadic)
4101 opClass.addTrait(trait: "::mlir::OpTrait::Variadic" + traitKind + "s");
4102 else
4103 opClass.addTrait(trait: "::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
4104 Twine(numTotal - numVariadic) + ">::Impl");
4105 return;
4106 }
4107 switch (numTotal) {
4108 case 0:
4109 opClass.addTrait(trait: "::mlir::OpTrait::Zero" + traitKind + "s");
4110 break;
4111 case 1:
4112 opClass.addTrait(trait: "::mlir::OpTrait::One" + traitKind);
4113 break;
4114 default:
4115 opClass.addTrait(trait: "::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
4116 ">::Impl");
4117 break;
4118 }
4119}
4120
4121void OpEmitter::genTraits() {
4122 // Add region size trait.
4123 unsigned numRegions = op.getNumRegions();
4124 unsigned numVariadicRegions = op.getNumVariadicRegions();
4125 addSizeCountTrait(opClass, traitKind: "Region", numTotal: numRegions, numVariadic: numVariadicRegions);
4126
4127 // Add result size traits.
4128 int numResults = op.getNumResults();
4129 int numVariadicResults = op.getNumVariableLengthResults();
4130 addSizeCountTrait(opClass, traitKind: "Result", numTotal: numResults, numVariadic: numVariadicResults);
4131
4132 // For single result ops with a known specific type, generate a OneTypedResult
4133 // trait.
4134 if (numResults == 1 && numVariadicResults == 0) {
4135 auto cppName = op.getResults().begin()->constraint.getCppType();
4136 opClass.addTrait(trait: "::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
4137 }
4138
4139 // Add successor size trait.
4140 unsigned numSuccessors = op.getNumSuccessors();
4141 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
4142 addSizeCountTrait(opClass, traitKind: "Successor", numTotal: numSuccessors, numVariadic: numVariadicSuccessors);
4143
4144 // Add variadic size trait and normal op traits.
4145 int numOperands = op.getNumOperands();
4146 int numVariadicOperands = op.getNumVariableLengthOperands();
4147
4148 // Add operand size trait.
4149 addSizeCountTrait(opClass, traitKind: "Operand", numTotal: numOperands, numVariadic: numVariadicOperands);
4150
4151 // The op traits defined internal are ensured that they can be verified
4152 // earlier.
4153 for (const auto &trait : op.getTraits()) {
4154 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
4155 if (opTrait->isStructuralOpTrait())
4156 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
4157 }
4158 }
4159
4160 // OpInvariants wrapps the verifyInvariants which needs to be run before
4161 // native/interface traits and after all the traits with `StructuralOpTrait`.
4162 opClass.addTrait(trait: "::mlir::OpTrait::OpInvariants");
4163
4164 if (emitHelper.hasNonEmptyPropertiesStruct())
4165 opClass.addTrait(trait: "::mlir::BytecodeOpInterface::Trait");
4166
4167 // Add the native and interface traits.
4168 for (const auto &trait : op.getTraits()) {
4169 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
4170 if (!opTrait->isStructuralOpTrait())
4171 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
4172 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(Val: &trait)) {
4173 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
4174 }
4175 }
4176}
4177
4178void OpEmitter::genOpNameGetter() {
4179 auto *method = opClass.addStaticMethod<Method::Constexpr>(
4180 retType: "::llvm::StringLiteral", name: "getOperationName");
4181 ERROR_IF_PRUNED(method, "getOperationName", op);
4182 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
4183 << "\");";
4184}
4185
4186void OpEmitter::genOpAsmInterface() {
4187 // If the user only has one results or specifically added the Asm trait,
4188 // then don't generate it for them. We specifically only handle multi result
4189 // operations, because the name of a single result in the common case is not
4190 // interesting(generally 'result'/'output'/etc.).
4191 // TODO: We could also add a flag to allow operations to opt in to this
4192 // generation, even if they only have a single operation.
4193 int numResults = op.getNumResults();
4194 if (numResults <= 1 || op.getTrait(trait: "::mlir::OpAsmOpInterface::Trait"))
4195 return;
4196
4197 SmallVector<StringRef, 4> resultNames(numResults);
4198 for (int i = 0; i != numResults; ++i)
4199 resultNames[i] = op.getResultName(index: i);
4200
4201 // Don't add the trait if none of the results have a valid name.
4202 if (llvm::all_of(Range&: resultNames, P: [](StringRef name) { return name.empty(); }))
4203 return;
4204 opClass.addTrait(trait: "::mlir::OpAsmOpInterface::Trait");
4205
4206 // Generate the right accessor for the number of results.
4207 auto *method = opClass.addMethod(
4208 retType: "void", name: "getAsmResultNames",
4209 args: MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
4210 ERROR_IF_PRUNED(method, "getAsmResultNames", op);
4211 auto &body = method->body();
4212 for (int i = 0; i != numResults; ++i) {
4213 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
4214 << " if (!resultGroup" << i << ".empty())\n"
4215 << " setNameFn(*resultGroup" << i << ".begin(), \""
4216 << resultNames[i] << "\");\n";
4217 }
4218}
4219
4220//===----------------------------------------------------------------------===//
4221// OpOperandAdaptor emitter
4222//===----------------------------------------------------------------------===//
4223
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around random access ranges that provide named operand
// getters identical to those defined in the Op.
// This currently generates 3 classes per Op:
// * A Base class within the 'detail' namespace, which contains all logic and
//   members independent of the random access range that is indexed into.
//   In other words, it contains all the attribute and region getters.
// * A templated class named '{OpName}GenericAdaptor' with a template parameter
//   'RangeT' that is indexed into by the getters to access the operands.
//   It contains all getters to access operands and inherits from the previous
//   class.
// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
//   with 'mlir::ValueRange' as template parameter. It adds a constructor from
//   an instance of the op type and a verify function.
//
// All three classes are built eagerly by the private constructor. The static
// emitDecl/emitDef entry points construct a fresh emitter and write the
// classes out.
class OpOperandAdaptorEmitter {
public:
  // Writes the declarations of the adaptor classes for `op` to `os`; the base
  // class is wrapped in a `detail` namespace.
  static void
  emitDecl(const Operator &op,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter,
           raw_ostream &os);
  // Writes the out-of-line definitions of the adaptor classes for `op` to
  // `os`; the base class is wrapped in a `detail` namespace.
  static void
  emitDef(const Operator &op,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter,
          raw_ostream &os);

private:
  // Builds and finalizes all three adaptor classes for `op`.
  explicit OpOperandAdaptorEmitter(
      const Operator &op,
      const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // NOTE: the declaration order of the members below is the order in which
  // the constructor's initializer list actually initializes them.

  // The operation for which to emit an adaptor.
  const Operator &op;

  // The generated adaptor classes.
  Class genericAdaptorBase;
  Class genericAdaptor;
  Class adaptor;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting adaptor code (constructed with emitForOp=false).
  OpOrAdaptorHelper emitHelper;
};
} // namespace
4274
// Builds the three adaptor classes for `op`:
//   * `detail::<GenericAdaptor>Base` (range-independent state and getters),
//   * `<GenericAdaptor>`, templated on `RangeT` (operand getters), and
//   * `<Adaptor>`, i.e. the generic adaptor instantiated on
//     `::mlir::ValueRange`.
// Everything is generated here eagerly: the optional `Properties` struct,
// fields, constructors, operand/property/attribute/region accessors and the
// adaptor's `verify` method. All three classes are finalized before
// returning; emitDecl/emitDef only write them out.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
    const Operator &op,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
      genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/false) {

  // Substitution context for default-value templates; `$_builder` initially
  // expands to `odsBuilder` (it is re-pointed at an attrs-derived Builder
  // further below, before the attribute getters are generated).
  FmtContext fctx;
  fctx.withBuilder(subst: odsBuilder);

  genericAdaptorBase.declare<VisibilityDeclaration>(args: Visibility::Public);
  bool useProperties = emitHelper.hasProperties();
  if (useProperties) {
    // Define the properties struct with multiple members.
    // Members are: non-derived attributes (an entry without a constraint must
    // be one of the segment-size attributes, see the check below), every ODS
    // property, and the operand/result segment sizes when present.
    using ConstArgument =
        llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
    SmallVector<ConstArgument> attrOrProperties;
    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
        attrOrProperties.push_back(Elt: &it.second);
    }
    for (const NamedProperty &prop : op.getProperties())
      attrOrProperties.push_back(Elt: &prop);
    if (emitHelper.getOperandSegmentsSize())
      attrOrProperties.push_back(Elt: &emitHelper.getOperandSegmentsSize().value());
    if (emitHelper.getResultSegmentsSize())
      attrOrProperties.push_back(Elt: &emitHelper.getResultSegmentsSize().value());
    // `declarations` accumulates the struct body; `comparator` accumulates an
    // `operator==` made of `rhs.<m> == this-><m> &&` clauses, closed with
    // `true;` after the loop and appended to the struct.
    std::string declarations = " struct Properties {\n";
    llvm::raw_string_ostream os(declarations);
    std::string comparator =
        " bool operator==(const Properties &rhs) const {\n"
        " return \n";
    llvm::raw_string_ostream comparatorOs(comparator);
    for (const auto &attrOrProp : attrOrProperties) {
      if (const auto *namedProperty =
              llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
        StringRef name = namedProperty->name;
        if (name.empty())
          report_fatal_error(reason: "missing name for property");
        std::string camelName =
            convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
        auto &prop = namedProperty->prop;
        // Generate the data member using the storage type.
        // Initializer precedence: explicit storage-value override, then the
        // property's default value, otherwise no initializer.
        os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
           << " " << name << "Ty " << name;
        if (prop.hasStorageTypeValueOverride())
          os << " = " << prop.getStorageTypeValueOverride();
        else if (prop.hasDefaultValue())
          os << " = " << tgfmt(fmt: prop.getDefaultValue(), ctx: &fctx);
        comparatorOs << " rhs." << name << " == this->" << name
                     << " &&\n";
        // Emit accessors using the interface type.
        // The leading `;` of the template terminates the member declaration
        // emitted just above.
        const char *accessorFmt = R"decl(;
{0} get{1}() const {
auto &propStorage = this->{2};
return {3};
}
void set{1}({0} propValue) {
auto &propStorage = this->{2};
{4};
}
)decl";
        // This local context shadows the outer `fctx` (which carries the
        // builder substitution); only `_storage` / `_value` are needed here.
        FmtContext fctx;
        os << formatv(Fmt: accessorFmt, Vals: prop.getInterfaceType(), Vals&: camelName, Vals&: name,
                      Vals: tgfmt(fmt: prop.getConvertFromStorageCall(),
                                 ctx: &fctx.addSubst(placeholder: "_storage", subst: propertyStorage)),
                      Vals: tgfmt(fmt: prop.getAssignToStorageCall(),
                                 ctx: &fctx.addSubst(placeholder: "_value", subst: propertyValue)
                                          .addSubst(placeholder: "_storage", subst: propertyStorage)));
        continue;
      }
      // Not a property, so this entry is attribute metadata.
      const auto *namedAttr =
          llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
      const Attribute *attr = nullptr;
      if (namedAttr->constraint)
        attr = &*namedAttr->constraint;
      StringRef name = namedAttr->attrName;
      if (name.empty())
        report_fatal_error(reason: "missing name for property attr");
      std::string camelName =
          convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
      // Generate the data member using the storage type.
      StringRef storageType;
      if (attr) {
        storageType = attr->getStorageType();
      } else {
        // Only the segment-size attributes may lack a constraint.
        if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
          report_fatal_error(reason: "unexpected AttributeMetadata");
        }
        // TODO: update to use native integers.
        storageType = "::mlir::DenseI32ArrayAttr";
      }
      os << " using " << name << "Ty = " << storageType << ";\n"
         << " " << name << "Ty " << name << ";\n";
      comparatorOs << " rhs." << name << " == this->" << name << " &&\n";

      // Emit accessors using the interface type.
      // Optional or defaulted attributes may be absent, hence the null-safe
      // cast in the getter.
      if (attr) {
        const char *accessorFmt = R"decl(
auto get{0}() const {
auto &propStorage = this->{1};
return ::llvm::{2}<{3}>(propStorage);
}
void set{0}(const {3} &propValue) {
this->{1} = propValue;
}
)decl";
        os << formatv(Fmt: accessorFmt, Vals&: camelName, Vals&: name,
                      Vals: attr->isOptional() || attr->hasDefaultValue()
                          ? "dyn_cast_or_null"
                          : "cast",
                      Vals&: storageType);
      }
    }
    comparatorOs << " true;\n }\n"
                    " bool operator!=(const Properties &rhs) const {\n"
                    " return !(*this == rhs);\n"
                    " }\n";
    os << comparator;
    os << " };\n";

    // With no members, alias the shared empty properties type instead of
    // emitting the struct built above.
    if (attrOrProperties.empty())
      genericAdaptorBase.declare<UsingDeclaration>(args: "Properties",
                                                   args: "::mlir::EmptyProperties");
    else
      genericAdaptorBase.declare<ExtraClassDeclaration>(
          args: std::move(declarations));
  }
  // Range-independent state shared by all adaptors, stored in the base class.
  genericAdaptorBase.declare<VisibilityDeclaration>(args: Visibility::Protected);
  genericAdaptorBase.declare<Field>(args: "::mlir::DictionaryAttr", args: "odsAttrs");
  genericAdaptorBase.declare<Field>(args: "::std::optional<::mlir::OperationName>",
                                    args: "odsOpName");
  if (useProperties)
    genericAdaptorBase.declare<Field>(args: "Properties", args: "properties");
  genericAdaptorBase.declare<Field>(args: "::mlir::RegionRange", args: "odsRegions");

  // The generic adaptor holds the operand range and derives from the base.
  genericAdaptor.addTemplateParam(param: "RangeT");
  genericAdaptor.addField(type: "RangeT", name: "odsOperands");
  genericAdaptor.addParent(
      parent: ParentClass("detail::" + genericAdaptorBase.getClassName()));
  genericAdaptor.declare<UsingDeclaration>(
      args: "ValueT", args: "::llvm::detail::ValueOfRange<RangeT>");
  genericAdaptor.declare<UsingDeclaration>(
      args: "Base", args: "detail::" + genericAdaptorBase.getClassName());

  const auto *attrSizedOperands =
      op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
  {
    // Parameter list shared (and progressively edited) by the base
    // constructor and the generic adaptor constructors below:
    // (attrs, properties, regions).
    SmallVector<MethodParameter> paramList;
    if (useProperties) {
      // Properties can't be given a default constructor here due to Properties
      // struct being defined in the enclosing class which isn't complete by
      // here.
      paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs");
      paramList.emplace_back(Args: "const Properties &", Args: "properties");
    } else {
      paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs", Args: "{}");
      paramList.emplace_back(Args: "const ::mlir::EmptyProperties &", Args: "properties",
                             Args: "{}");
    }
    paramList.emplace_back(Args: "::mlir::RegionRange", Args: "regions", Args: "{}");
    auto *baseConstructor =
        genericAdaptorBase.addConstructor<Method::Inline>(args&: paramList);
    baseConstructor->addMemberInitializer(name: "odsAttrs", value: "attrs");
    if (useProperties)
      baseConstructor->addMemberInitializer(name: "properties", value: "properties");
    baseConstructor->addMemberInitializer(name: "odsRegions", value: "regions");

    // The op name is only materialized when attributes are available, since
    // it needs their MLIRContext.
    MethodBody &body = baseConstructor->body();
    body.indent() << "if (odsAttrs)\n";
    body.indent() << formatv(
        Fmt: "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
        Vals: op.getOperationName());

    // Generic adaptor constructor: (values, attrs, properties, regions),
    // forwarding everything but `values` to the base.
    paramList.insert(I: paramList.begin(), Elt: MethodParameter("RangeT", "values"));
    auto *constructor = genericAdaptor.addConstructor(args&: paramList);
    constructor->addMemberInitializer(name: "Base", value: "attrs, properties, regions");
    constructor->addMemberInitializer(name: "odsOperands", value: "values");

    // Add a forwarding constructor to the previous one that accepts
    // OpaqueProperties instead and check for null and perform the cast to the
    // actual properties type.
    // After the insertion above, index 1 is `attrs` (re-declared here without
    // a default) and index 2 is `properties`; `regions` keeps its default.
    paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
    paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
    auto *opaquePropertiesConstructor =
        genericAdaptor.addConstructor(args: std::move(paramList));
    if (useProperties) {
      opaquePropertiesConstructor->addMemberInitializer(
          name: genericAdaptor.getClassName(),
          value: "values, "
                 "attrs, "
                 "(properties ? *properties.as<Properties *>() : Properties{}), "
                 "regions");
    } else {
      opaquePropertiesConstructor->addMemberInitializer(
          name: genericAdaptor.getClassName(),
          value: "values, "
                 "attrs, "
                 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
                 "::mlir::EmptyProperties{}), "
                 "regions");
    }

    // Add forwarding constructor that constructs Properties.
    if (useProperties) {
      // Fresh list: it shadows the outer `paramList`, which was moved from
      // above.
      SmallVector<MethodParameter> paramList;
      paramList.emplace_back(Args: "RangeT", Args: "values");
      // NOTE(review): `attrs` defaults to `nullptr` unless the op has
      // AttrSizedOperandSegments; an empty default string presumably means
      // "no default", forcing callers to pass attrs (segment sizes) — confirm.
      paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs",
                             Args: attrSizedOperands ? "" : "nullptr");
      auto *noPropertiesConstructor =
          genericAdaptor.addConstructor(args: std::move(paramList));
      noPropertiesConstructor->addMemberInitializer(
          name: genericAdaptor.getClassName(), value: "values, "
                                                   "attrs, "
                                                   "Properties{}, "
                                                   "{}");
    }
  }

  // Create a constructor that creates a new generic adaptor by copying
  // everything from another adaptor, except for the values.
  {
    SmallVector<MethodParameter> paramList;
    paramList.emplace_back(Args: "RangeT", Args: "values");
    paramList.emplace_back(Args: "const " + op.getGenericAdaptorName() + "Base &",
                           Args: "base");
    auto *constructor =
        genericAdaptor.addConstructor<Method::Inline>(args&: paramList);
    constructor->addMemberInitializer(name: "Base", value: "base");
    constructor->addMemberInitializer(name: "odsOperands", value: "values");
  }

  // Create constructors constructing the adaptor from an instance of the op.
  // This takes the attributes, properties and regions from the op instance
  // and the value range from the parameter.
  {
    // Base class is in the cpp file and can simply access the members of the op
    // class to initialize the template independent fields. If the op doesn't
    // have properties, we can emit a generic constructor inline. Otherwise,
    // emit it out-of-line because we need the op to be defined.
    Constructor *constructor;
    if (useProperties) {
      constructor = genericAdaptorBase.addConstructor(
          args: MethodParameter(op.getCppClassName(), "op"));
    } else {
      constructor = genericAdaptorBase.addConstructor<Method::Inline>(
          args: MethodParameter("::mlir::Operation *", "op"));
    }
    constructor->addMemberInitializer(name: "odsAttrs",
                                      value: "op->getRawDictionaryAttrs()");
    // Retrieve the operation name from the op directly.
    constructor->addMemberInitializer(name: "odsOpName", value: "op->getName()");
    if (useProperties)
      constructor->addMemberInitializer(name: "properties", value: "op.getProperties()");
    constructor->addMemberInitializer(name: "odsRegions", value: "op->getRegions()");

    // Generic adaptor is templated and therefore defined inline in the header.
    // We cannot use the Op class here as it is an incomplete type (we have a
    // circular reference between the two).
    // Use a template trick to make the constructor be instantiated at call site
    // when the op class is complete.
    constructor = genericAdaptor.addConstructor(
        args: MethodParameter("RangeT", "values"), args: MethodParameter("LateInst", "op"));
    constructor->addTemplateParam(param: "LateInst = " + op.getCppClassName());
    constructor->addTemplateParam(
        param: "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
               ">>");
    constructor->addMemberInitializer(name: "Base", value: "op");
    constructor->addMemberInitializer(name: "odsOperands", value: "values");
  }

  // For AttrSizedOperandSegments ops, the operand getters need code that
  // reads the segment sizes, either from Properties or from the attribute
  // dictionary depending on the dialect's storage choice.
  std::string sizeAttrInit;
  if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
    if (op.getDialect().usePropertiesForAttributes())
      sizeAttrInit =
          formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
                  Vals: llvm::formatv(Fmt: "getProperties().operandSegmentSizes"));
    else
      sizeAttrInit = formatv(Fmt: adapterSegmentSizeAttrInitCode,
                             Vals: emitHelper.getAttr(attrName: operandSegmentAttrName));
  }
  generateNamedOperandGetters(op, opClass&: genericAdaptor,
                              /*genericAdaptorBase=*/&genericAdaptorBase,
                              /*sizeAttrInit=*/sizeAttrInit,
                              /*rangeType=*/"RangeT",
                              /*rangeElementType=*/"ValueT",
                              /*rangeBeginCall=*/"odsOperands.begin()",
                              /*rangeSizeCall=*/"odsOperands.size()",
                              /*getOperandCallPattern=*/"odsOperands[{0}]");

  // Any invalid overlap for `getOperands` will have been diagnosed before
  // here already.
  if (auto *m = genericAdaptor.addMethod(retType: "RangeT", name: "getOperands"))
    m->body() << " return odsOperands;";

  // From here on, `$_builder` in default-value templates builds from the
  // context of the adaptor's attribute dictionary.
  fctx.withBuilder(subst: "::mlir::Builder(odsAttrs.getContext())");

  // Generate named accessor with Attribute return type.
  auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
                                     Attribute attr) {
    // The method body is trivial if the attribute does not have a default
    // value, in which case the default value may be arbitrary code.
    auto *method = genericAdaptorBase.addMethod(
        retType: attr.getStorageType(), name: emitName + "Attr",
        properties: attr.hasDefaultValue() || !useProperties ? Method::Properties::None
                                                             : Method::Properties::Inline);
    ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
    auto &body = method->body().indent();
    if (!useProperties)
      body << "assert(odsAttrs && \"no attributes when constructing "
              "adapter\");\n";
    body << formatv(
        Fmt: "auto attr = ::llvm::{1}<{2}>({0});\n", Vals: emitHelper.getAttr(attrName: name),
        Vals: attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null"
                                                          : "cast",
        Vals: attr.getStorageType());

    if (attr.hasDefaultValue() && attr.isOptional()) {
      // Use the default value if attribute is not set.
      // TODO: this is inefficient, we are recreating the attribute for every
      // call. This should be set instead.
      std::string defaultValue =
          std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
                            vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
      body << "if (!attr)\n attr = " << defaultValue << ";\n";
    }
    body << "return attr;\n";
  };

  if (useProperties) {
    auto *m = genericAdaptorBase.addInlineMethod(retType: "const Properties &",
                                                 name: "getProperties");
    ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
    m->body() << " return properties;";
  }
  {
    auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::DictionaryAttr",
                                                 name: "getAttributes");
    ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
    m->body() << " return odsAttrs;";
  }
  for (auto &namedProp : op.getProperties()) {
    std::string name = op.getGetterName(name: namedProp.name);
    emitPropGetter(opClass&: genericAdaptorBase, op, name, prop: namedProp.prop);
  }

  // Derived attributes are computed from the op itself, so the adaptor gets
  // no getters for them.
  for (auto &namedAttr : op.getAttributes()) {
    const auto &name = namedAttr.name;
    const auto &attr = namedAttr.attr;
    if (attr.isDerivedAttr())
      continue;
    std::string emitName = op.getGetterName(name);
    emitAttrWithStorageType(name, emitName, attr);
    emitAttrGetterWithReturnType(fctx, opClass&: genericAdaptorBase, op, name: emitName, attr);
  }

  // Named regions get accessors; a variadic region returns every region from
  // its index onward.
  unsigned numRegions = op.getNumRegions();
  for (unsigned i = 0; i < numRegions; ++i) {
    const auto &region = op.getRegion(index: i);
    if (region.name.empty())
      continue;

    // Generate the accessors for a variadic region.
    std::string name = op.getGetterName(name: region.name);
    if (region.isVariadic()) {
      auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::RegionRange", name);
      ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
      m->body() << formatv(Fmt: " return odsRegions.drop_front({0});", Vals&: i);
      continue;
    }

    auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::Region &", name);
    ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
    m->body() << formatv(Fmt: " return *odsRegions[{0}];", Vals&: i);
  }
  if (numRegions > 0) {
    // Any invalid overlap for `getRegions` will have been diagnosed before
    // here already.
    if (auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::RegionRange",
                                                     name: "getRegions"))
      m->body() << " return odsRegions;";
  }

  // `<Adaptor>` derives from `<GenericAdaptor><::mlir::ValueRange>`, inherits
  // its constructors via a using-declaration, and adds one taking the op.
  StringRef genericAdaptorClassName = genericAdaptor.getClassName();
  adaptor.addParent(parent: ParentClass(genericAdaptorClassName))
      .addTemplateParam(param: "::mlir::ValueRange");
  adaptor.declare<VisibilityDeclaration>(args: Visibility::Public);
  adaptor.declare<UsingDeclaration>(args: genericAdaptorClassName +
                                          "::" + genericAdaptorClassName);
  {
    // Constructor taking the Op as single parameter.
    auto *constructor =
        adaptor.addConstructor(args: MethodParameter(op.getCppClassName(), "op"));
    constructor->addMemberInitializer(name&: genericAdaptorClassName,
                                      value: "op->getOperands(), op");
  }

  // Add verification function.
  addVerification();

  genericAdaptorBase.finalize();
  genericAdaptor.finalize();
  adaptor.finalize();
}
4681
4682void OpOperandAdaptorEmitter::addVerification() {
4683 auto *method = adaptor.addMethod(retType: "::llvm::LogicalResult", name: "verify",
4684 args: MethodParameter("::mlir::Location", "loc"));
4685 ERROR_IF_PRUNED(method, "verify", op);
4686 auto &body = method->body();
4687 bool useProperties = emitHelper.hasProperties();
4688
4689 FmtContext verifyCtx;
4690 populateSubstitutions(emitHelper, ctx&: verifyCtx);
4691 genAttributeVerifier(emitHelper, ctx&: verifyCtx, body, staticVerifierEmitter,
4692 useProperties);
4693
4694 body << " return ::mlir::success();";
4695}
4696
4697void OpOperandAdaptorEmitter::emitDecl(
4698 const Operator &op,
4699 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4700 raw_ostream &os) {
4701 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4702 {
4703 NamespaceEmitter ns(os, "detail");
4704 emitter.genericAdaptorBase.writeDeclTo(rawOs&: os);
4705 }
4706 emitter.genericAdaptor.writeDeclTo(rawOs&: os);
4707 emitter.adaptor.writeDeclTo(rawOs&: os);
4708}
4709
4710void OpOperandAdaptorEmitter::emitDef(
4711 const Operator &op,
4712 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4713 raw_ostream &os) {
4714 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4715 {
4716 NamespaceEmitter ns(os, "detail");
4717 emitter.genericAdaptorBase.writeDefTo(rawOs&: os);
4718 }
4719 emitter.genericAdaptor.writeDefTo(rawOs&: os);
4720 emitter.adaptor.writeDefTo(rawOs&: os);
4721}
4722
4723/// Emit the class declarations or definitions for the given op defs.
4724static void
4725emitOpClasses(const RecordKeeper &records,
4726 const std::vector<const Record *> &defs, raw_ostream &os,
4727 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4728 bool emitDecl) {
4729 if (defs.empty())
4730 return;
4731
4732 for (auto *def : defs) {
4733 Operator op(*def);
4734 if (emitDecl) {
4735 {
4736 NamespaceEmitter emitter(os, op.getCppNamespace());
4737 os << formatv(Fmt: opCommentHeader, Vals: op.getQualCppClassName(),
4738 Vals: "declarations");
4739 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
4740 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
4741 }
4742 // Emit the TypeID explicit specialization to have a single definition.
4743 if (!op.getCppNamespace().empty())
4744 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4745 << "::" << op.getCppClassName() << ")\n\n";
4746 } else {
4747 {
4748 NamespaceEmitter emitter(os, op.getCppNamespace());
4749 os << formatv(Fmt: opCommentHeader, Vals: op.getQualCppClassName(), Vals: "definitions");
4750 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
4751 OpEmitter::emitDef(op, os, staticVerifierEmitter);
4752 }
4753 // Emit the TypeID explicit specialization to have a single definition.
4754 if (!op.getCppNamespace().empty())
4755 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4756 << "::" << op.getCppClassName() << ")\n\n";
4757 }
4758 }
4759}
4760
4761/// Emit the declarations for the provided op classes.
4762static void emitOpClassDecls(const RecordKeeper &records,
4763 const std::vector<const Record *> &defs,
4764 raw_ostream &os) {
4765 // First emit forward declaration for each class, this allows them to refer
4766 // to each others in traits for example.
4767 for (auto *def : defs) {
4768 Operator op(*def);
4769 NamespaceEmitter emitter(os, op.getCppNamespace());
4770 std::string comments = tblgen::emitSummaryAndDescComments(
4771 summary: op.getSummary(), description: op.getDescription());
4772 if (!comments.empty()) {
4773 os << comments << "\n";
4774 }
4775 os << "class " << op.getCppClassName() << ";\n";
4776 }
4777
4778 // Emit the op class declarations.
4779 IfDefScope scope("GET_OP_CLASSES", os);
4780 if (defs.empty())
4781 return;
4782 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records);
4783 staticVerifierEmitter.collectOpConstraints(opDefs: defs);
4784 emitOpClasses(records, defs, os, staticVerifierEmitter,
4785 /*emitDecl=*/true);
4786}
4787
4788/// Emit the definitions for the provided op classes.
4789static void emitOpClassDefs(const RecordKeeper &records,
4790 ArrayRef<const Record *> defs, raw_ostream &os,
4791 StringRef constraintPrefix = "") {
4792 if (defs.empty())
4793 return;
4794
4795 // Generate all of the locally instantiated methods first.
4796 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records,
4797 constraintPrefix);
4798 os << formatv(Fmt: opCommentHeader, Vals: "Local Utility Method", Vals: "Definitions");
4799 staticVerifierEmitter.collectOpConstraints(opDefs: defs);
4800 staticVerifierEmitter.emitOpConstraints(opDefs: defs);
4801
4802 // Emit the classes.
4803 emitOpClasses(records, defs, os, staticVerifierEmitter,
4804 /*emitDecl=*/false);
4805}
4806
4807/// Emit op declarations for all op records.
4808static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) {
4809 emitSourceFileHeader(Desc: "Op Declarations", OS&: os, Record: records);
4810
4811 std::vector<const Record *> defs = getRequestedOpDefinitions(records);
4812 emitOpClassDecls(records, defs, os);
4813
4814 // If we are generating sharded op definitions, emit the sharded op
4815 // registration hooks.
4816 SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
4817 shardOpDefinitions(defs, shardedDefs);
4818 if (defs.empty() || shardedDefs.size() <= 1)
4819 return false;
4820
4821 Dialect dialect = Operator(defs.front()).getDialect();
4822 NamespaceEmitter ns(os, dialect);
4823
4824 const char *const opRegistrationHook =
4825 "void register{0}Operations{1}({2}::{0} *dialect);\n";
4826 os << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppClassName(), Vals: "",
4827 Vals: dialect.getCppNamespace());
4828 for (unsigned i = 0; i < shardedDefs.size(); ++i) {
4829 os << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppClassName(), Vals&: i,
4830 Vals: dialect.getCppNamespace());
4831 }
4832
4833 return false;
4834}
4835
4836/// Generate the dialect op registration hook and the op class definitions for a
4837/// shard of ops.
4838static void emitOpDefShard(const RecordKeeper &records,
4839 ArrayRef<const Record *> defs,
4840 const Dialect &dialect, unsigned shardIndex,
4841 unsigned shardCount, raw_ostream &os) {
4842 std::string shardGuard = "GET_OP_DEFS_";
4843 std::string indexStr = std::to_string(val: shardIndex);
4844 shardGuard += indexStr;
4845 IfDefScope scope(shardGuard, os);
4846
4847 // Emit the op registration hook in the first shard.
4848 const char *const opRegistrationHook =
4849 "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
4850 if (shardIndex == 0) {
4851 os << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppNamespace(),
4852 Vals: dialect.getCppClassName(), Vals: "");
4853 for (unsigned i = 0; i < shardCount; ++i) {
4854 os << formatv(Fmt: " {0}::register{1}Operations{2}(dialect);\n",
4855 Vals: dialect.getCppNamespace(), Vals: dialect.getCppClassName(), Vals&: i);
4856 }
4857 os << "}\n";
4858 }
4859
4860 // Generate the per-shard op registration hook.
4861 os << formatv(Fmt: opCommentHeader, Vals: dialect.getCppClassName(),
4862 Vals: "Op Registration Hook")
4863 << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppNamespace(),
4864 Vals: dialect.getCppClassName(), Vals&: shardIndex);
4865 for (const Record *def : defs) {
4866 os << formatv(Fmt: " ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
4867 Vals: Operator(def).getQualCppClassName());
4868 }
4869 os << "}\n";
4870
4871 // Generate the per-shard op definitions.
4872 emitOpClassDefs(records, defs, os, constraintPrefix: indexStr);
4873}
4874
4875/// Emit op definitions for all op records.
4876static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) {
4877 emitSourceFileHeader(Desc: "Op Definitions", OS&: os, Record: records);
4878
4879 std::vector<const Record *> defs = getRequestedOpDefinitions(records);
4880 SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
4881 shardOpDefinitions(defs, shardedDefs);
4882
4883 // If no shard was requested, emit the regular op list and class definitions.
4884 if (shardedDefs.size() == 1) {
4885 {
4886 IfDefScope scope("GET_OP_LIST", os);
4887 interleave(
4888 c: defs, os,
4889 each_fn: [&](const Record *def) { os << Operator(def).getQualCppClassName(); },
4890 separator: ",\n");
4891 }
4892 {
4893 IfDefScope scope("GET_OP_CLASSES", os);
4894 emitOpClassDefs(records, defs, os);
4895 }
4896 return false;
4897 }
4898
4899 if (defs.empty())
4900 return false;
4901 Dialect dialect = Operator(defs.front()).getDialect();
4902 for (auto [idx, value] : llvm::enumerate(First&: shardedDefs)) {
4903 emitOpDefShard(records, defs: value, dialect, shardIndex: idx, shardCount: shardedDefs.size(), os);
4904 }
4905 return false;
4906}
4907
4908static mlir::GenRegistration
4909 genOpDecls("gen-op-decls", "Generate op declarations",
4910 [](const RecordKeeper &records, raw_ostream &os) {
4911 return emitOpDecls(records, os);
4912 });
4913
4914static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
4915 [](const RecordKeeper &records,
4916 raw_ostream &os) {
4917 return emitOpDefs(records, os);
4918 });
4919

source code of mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp