1//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpDefinitionsGen uses the description of operations to generate C++
10// definitions for ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "CppGenUtilities.h"
15#include "OpClass.h"
16#include "OpFormatGen.h"
17#include "OpGenHelpers.h"
18#include "mlir/TableGen/Argument.h"
19#include "mlir/TableGen/Attribute.h"
20#include "mlir/TableGen/Class.h"
21#include "mlir/TableGen/CodeGenHelpers.h"
22#include "mlir/TableGen/Format.h"
23#include "mlir/TableGen/GenInfo.h"
24#include "mlir/TableGen/Interfaces.h"
25#include "mlir/TableGen/Operator.h"
26#include "mlir/TableGen/Property.h"
27#include "mlir/TableGen/SideEffects.h"
28#include "mlir/TableGen/Trait.h"
29#include "llvm/ADT/BitVector.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/Sequence.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringSet.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/ErrorHandling.h"
37#include "llvm/Support/Signals.h"
38#include "llvm/Support/raw_ostream.h"
39#include "llvm/TableGen/Error.h"
40#include "llvm/TableGen/Record.h"
41#include "llvm/TableGen/TableGenBackend.h"
42
43#define DEBUG_TYPE "mlir-tblgen-opdefgen"
44
45using namespace llvm;
46using namespace mlir;
47using namespace mlir::tblgen;
48
// Identifier fragments that get spliced into generated C++ code.
static constexpr const char *tblgenNamePrefix = "tblgen_";
// Base name for operands that have no ODS name (see getArgumentName()).
static constexpr const char *generatedArgName = "odsArg";
// Parameter names presumably used by generated `build()` methods; their uses
// are outside this excerpt.
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
// Local names presumably used by generated property-handling code.
static constexpr const char *propertyStorage = "propStorage";
static constexpr const char *propertyValue = "propValue";
static constexpr const char *propertyAttr = "propAttr";
static constexpr const char *propertyDiag = "emitError";

/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
static constexpr const char *operandSegmentAttrName = "operandSegmentSizes";
static constexpr const char *resultSegmentAttrName = "resultSegmentSizes";
62
/// formatv template an Op uses to look up one of its attributes. The lookup
/// goes through cached attribute names and only searches the part of the
/// sorted attribute array where the attribute can appear.
///
/// {0}: Code snippet producing the attribute's name or identifier.
/// {1}: Number of entries to skip at the front of the sorted range.
/// {2}: Number of entries to skip at the back of the sorted range.
/// {3}: Code snippet producing the array of named attributes.
/// {4}: "Named" to fetch the named attribute, empty otherwise.
static constexpr const char *subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";
74
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// This is a formatv template, so literal braces in the generated code are
/// written doubled (`{{`).
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes the
/// op has an attribute specifying the size of each operand/result segment
/// (variadic or not), made available to the generated code as `sizeAttr`.
///
/// NOTE(review): unlike the template above, the braces here are not doubled,
/// so this snippet is presumably streamed verbatim rather than passed through
/// formatv; confirm at the use site before adding placeholders.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttr[i];
  return {start, sizeAttr[index]};
)";
/// The code snippet an adaptor uses to initialize the segment sizes for the
/// value range calculation when the sizes are stored in an attribute. The
/// generated code asserts that the attribute is present.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert({0} && "missing segment size attribute for op");
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";

/// Variant of the snippet above for when the segment sizes are available as an
/// `::llvm::ArrayRef<int32_t>` (presumably read from the op's properties
/// rather than from an attribute). The local is still named `sizeAttr` so the
/// downstream range calculation code can be shared.
///
/// {0}: The code to get the segment sizes.
static const char *const adapterSegmentSizeAttrInitCodeProperties = R"(
  ::llvm::ArrayRef<int32_t> sizeAttr = {0};
)";

/// The code snippet an op uses to initialize the sizes for the value range
/// calculation. Unlike the adaptor variant, no presence assertion is emitted.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
130
/// formatv template for the OpAdaptor accessor of a variadic-of-variadic
/// operand: it splits the operand's flat value range into groups whose sizes
/// come from the segment accessor.
///
/// {0}: Name of the accessor returning the segment sizes.
/// {1}: Index of the main operand.
/// {2}: The adaptor's range type.
static constexpr const char *variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizes = {0}();

  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
  for (int i = 0, e = sizes.size(); i < e; ++i) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
  }
  return tblgenTmpOperandGroups;
)";

/// formatv template that turns a (start, length) pair into a range of either
/// operand or result values.
///
/// {0}: Iterator to the first actual value.
/// {1}: Expression yielding the start and length of the range.
static constexpr const char *valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
158
/// Parse operand/result segment_size property from its textual form: a
/// square-bracketed, comma-separated list of exactly {0} integers. Too many
/// or too few elements are reported as parse errors.
///
/// {0}: Number of elements in the segment array
static const char *const parseTextualSegmentSizeFormat = R"(
  size_t i = 0;
  auto parseElem = [&]() -> ::mlir::ParseResult {
    if (i >= {0})
      return $_parser.emitError($_parser.getCurrentLocation(),
        "expected `]` after {0} segment sizes");
    if (failed($_parser.parseInteger($_storage[i])))
      return ::mlir::failure();
    i += 1;
    return ::mlir::success();
  };
  if (failed($_parser.parseCommaSeparatedList(
      ::mlir::AsmParser::Delimiter::Square, parseElem)))
    return failure();
  if (i < {0})
    return $_parser.emitError($_parser.getCurrentLocation(),
      "expected {0} segment sizes, found only ") << i;
  return success();
)";
180
/// Print operand/result segment_size property as `[a, b, ...]`. The code is
/// an immediately-invoked lambda so it can be used where an expression is
/// expected.
static const char *const printTextualSegmentSize = R"(
  [&]() {
    $_printer << '[';
    ::llvm::interleaveComma($_storage, $_printer);
    $_printer << ']';
  }()
)";

/// Read operand/result segment_size from bytecode version 6 or newer, where
/// it is encoded natively as a sparse array.
static const char *const readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Read operand/result segment_size from bytecode older than version 6, where
/// it was encoded as a DenseI32ArrayAttr. The attribute is rejected if it has
/// more elements than the storage can hold. Otherwise its values are copied
/// into the front of the storage.
static const char *const readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";

/// Write operand/result segment_size to bytecode version 6 or newer as a
/// native sparse array.
static const char *const writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Write operand/result segment_size to bytecode older than version 6, using
/// the legacy DenseI32ArrayAttr encoding.
static const char *const writeBytecodeSegmentSizeLegacy = R"(
if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
  auto &$_storage = prop.$_propName;
  $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
}
)";
221
/// Banner comment emitted to separate sections of generated code.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static constexpr const char *opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";

/// formatv template for an inline `create` helper. It builds an
/// OperationState, fills it through `build()`, creates the op through the
/// builder and casts the result to the concrete op class.
///
/// {0}: The location argument used for the OperationState.
/// {1}: Extra `build()` arguments, each preceded by ", ".
/// {2}: The op class to cast the created operation to.
static constexpr const char *inlineCreateBody = R"(
  ::mlir::OperationState __state__({0}, getOperationName());
  build(builder, __state__{1});
  auto __res__ = ::llvm::dyn_cast<{2}>(builder.create(__state__));
  assert(__res__ && "builder didn't return the right type");
  return __res__;
)";

/// formatv template for the `create` overload that takes its location from
/// the builder (presumably an ImplicitLocOpBuilder) and forwards to `create`.
///
/// {0}: Forwarded arguments, each preceded by ", ".
static constexpr const char *inlineCreateBodyImplicitLoc = R"(
  return create(builder, builder.getLoc(){0});
)";
244
245//===----------------------------------------------------------------------===//
246// Utility structs and functions
247//===----------------------------------------------------------------------===//
248
// Replaces all occurrences of `match` in `str` with `substitute`, scanning
// left to right. Text produced by a substitution is never rescanned, so a
// `substitute` that contains `match` does not cause repeated expansion.
// An empty `match` leaves `str` unchanged: every position would otherwise
// match, and the scan would never terminate.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;
  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume after the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
259
260// Returns whether the record has a value of the given name that can be returned
261// via getValueAsString.
262static inline bool hasStringAttribute(const Record &record,
263 StringRef fieldName) {
264 auto *valueInit = record.getValueInit(FieldName: fieldName);
265 return isa<StringInit>(Val: valueInit);
266}
267
268static std::string getArgumentName(const Operator &op, int index) {
269 const auto &operand = op.getOperand(index);
270 if (!operand.name.empty())
271 return std::string(operand.name);
272 return std::string(formatv(Fmt: "{0}_{1}", Vals: generatedArgName, Vals&: index));
273}
274
275// Returns true if we can use unwrapped value for the given `attr` in builders.
276static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
277 return attr.getReturnType() != attr.getStorageType() &&
278 // We need to wrap the raw value into an attribute in the builder impl
279 // so we need to make sure that the attribute specifies how to do that.
280 !attr.getConstBuilderTemplate().empty();
281}
282
283/// Build an attribute from a parameter value using the constant builder.
284static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
285 FmtContext &fctx,
286 StringRef paramName) {
287 std::string builderTemplate = attr.getConstBuilderTemplate().str();
288
289 // For StringAttr, its constant builder call will wrap the input in
290 // quotes, which is correct for normal string literals, but incorrect
291 // here given we use function arguments. So we need to strip the
292 // wrapping quotes.
293 if (StringRef(builderTemplate).contains(Other: "\"$0\""))
294 builderTemplate = replaceAllSubstrs(str: builderTemplate, match: "\"$0\"", substitute: "$0");
295
296 return tgfmt(fmt: builderTemplate, ctx: &fctx, vals&: paramName).str();
297}
298
299namespace {
300/// Metadata on a registered attribute. Given that attributes are stored in
301/// sorted order on operations, we can use information from ODS to deduce the
302/// number of required attributes less and and greater than each attribute,
303/// allowing us to search only a subrange of the attributes in ODS-generated
304/// getters.
305struct AttributeMetadata {
306 /// The attribute name.
307 StringRef attrName;
308 /// Whether the attribute is required.
309 bool isRequired;
310 /// The ODS attribute constraint. Not present for implicit attributes.
311 std::optional<Attribute> constraint;
312 /// The number of required attributes less than this attribute.
313 unsigned lowerBound = 0;
314 /// The number of required attributes greater than this attribute.
315 unsigned upperBound = 0;
316};
317
318/// Helper class to select between OpAdaptor and Op code templates.
319class OpOrAdaptorHelper {
320public:
321 OpOrAdaptorHelper(const Operator &op, bool emitForOp)
322 : op(op), emitForOp(emitForOp) {
323 computeAttrMetadata();
324 }
325
326 /// Object that wraps a functor in a stream operator for interop with
327 /// llvm::formatv.
328 class Formatter {
329 public:
330 template <typename Functor>
331 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
332
333 std::string str() const {
334 std::string result;
335 llvm::raw_string_ostream os(result);
336 os << *this;
337 return os.str();
338 }
339
340 private:
341 std::function<raw_ostream &(raw_ostream &)> func;
342
343 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
344 return fmt.func(os);
345 }
346 };
347
348 // Generate code for getting an attribute.
349 Formatter getAttr(StringRef attrName, bool isNamed = false) const {
350 assert(attrMetadata.count(attrName) && "expected attribute metadata");
351 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
352 const AttributeMetadata &attr = attrMetadata.find(Key: attrName)->second;
353 if (hasProperties()) {
354 assert(!isNamed);
355 return os << "getProperties()." << attrName;
356 }
357 return os << formatv(Fmt: subrangeGetAttr, Vals: getAttrName(attrName),
358 Vals: attr.lowerBound, Vals: attr.upperBound, Vals: getAttrRange(),
359 Vals: isNamed ? "Named" : "");
360 };
361 }
362
363 // Generate code for getting the name of an attribute.
364 Formatter getAttrName(StringRef attrName) const {
365 return [this, attrName](raw_ostream &os) -> raw_ostream & {
366 if (emitForOp)
367 return os << op.getGetterName(name: attrName) << "AttrName()";
368 return os << formatv(Fmt: "{0}::{1}AttrName(*odsOpName)", Vals: op.getCppClassName(),
369 Vals: op.getGetterName(name: attrName));
370 };
371 }
372
373 // Get the code snippet for getting the named attribute range.
374 StringRef getAttrRange() const {
375 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
376 }
377
378 // Get the prefix code for emitting an error.
379 Formatter emitErrorPrefix() const {
380 return [this](raw_ostream &os) -> raw_ostream & {
381 if (emitForOp)
382 return os << "emitOpError(";
383 return os << formatv(Fmt: "emitError(loc, \"'{0}' op \"",
384 Vals: op.getOperationName());
385 };
386 }
387
388 // Get the call to get an operand or segment of operands.
389 Formatter getOperand(unsigned index) const {
390 return [this, index](raw_ostream &os) -> raw_ostream & {
391 return os << formatv(Fmt: op.getOperand(index).isVariadic()
392 ? "this->getODSOperands({0})"
393 : "(*this->getODSOperands({0}).begin())",
394 Vals: index);
395 };
396 }
397
398 // Get the call to get a result of segment of results.
399 Formatter getResult(unsigned index) const {
400 return [this, index](raw_ostream &os) -> raw_ostream & {
401 if (!emitForOp)
402 return os << "<no results should be generated>";
403 return os << formatv(Fmt: op.getResult(index).isVariadic()
404 ? "this->getODSResults({0})"
405 : "(*this->getODSResults({0}).begin())",
406 Vals: index);
407 };
408 }
409
410 // Return whether an op instance is available.
411 bool isEmittingForOp() const { return emitForOp; }
412
413 // Return the ODS operation wrapper.
414 const Operator &getOp() const { return op; }
415
416 // Get the attribute metadata sorted by name.
417 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
418 return attrMetadata;
419 }
420
421 /// Returns whether to emit a `Properties` struct for this operation or not.
422 bool hasProperties() const {
423 if (!op.getProperties().empty())
424 return true;
425 if (!op.getDialect().usePropertiesForAttributes())
426 return false;
427 return true;
428 }
429
430 /// Returns whether the operation will have a non-empty `Properties` struct.
431 bool hasNonEmptyPropertiesStruct() const {
432 if (!op.getProperties().empty())
433 return true;
434 if (!hasProperties())
435 return false;
436 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments") ||
437 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments"))
438 return true;
439 return llvm::any_of(Range: getAttrMetadata(),
440 P: [](const std::pair<StringRef, AttributeMetadata> &it) {
441 return !it.second.constraint ||
442 !it.second.constraint->isDerivedAttr();
443 });
444 }
445
446 std::optional<NamedProperty> &getOperandSegmentsSize() {
447 return operandSegmentsSize;
448 }
449
450 std::optional<NamedProperty> &getResultSegmentsSize() {
451 return resultSegmentsSize;
452 }
453
454 uint32_t getOperandSegmentSizesLegacyIndex() {
455 return operandSegmentSizesLegacyIndex;
456 }
457
458 uint32_t getResultSegmentSizesLegacyIndex() {
459 return resultSegmentSizesLegacyIndex;
460 }
461
462private:
463 // Compute the attribute metadata.
464 void computeAttrMetadata();
465
466 // The operation ODS wrapper.
467 const Operator &op;
468 // True if code is being generate for an op. False for an adaptor.
469 const bool emitForOp;
470
471 // The attribute metadata, mapped by name.
472 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
473
474 // Property
475 std::optional<NamedProperty> operandSegmentsSize;
476 std::string operandSegmentsSizeStorage;
477 std::string operandSegmentsSizeParser;
478 std::optional<NamedProperty> resultSegmentsSize;
479 std::string resultSegmentsSizeStorage;
480 std::string resultSegmentsSizeParser;
481
482 // Indices to store the position in the emission order of the operand/result
483 // segment sizes attribute if emitted as part of the properties for legacy
484 // bytecode encodings, i.e. versions less than 6.
485 uint32_t operandSegmentSizesLegacyIndex = 0;
486 uint32_t resultSegmentSizesLegacyIndex = 0;
487
488 // The number of required attributes.
489 unsigned numRequired;
490};
491
492} // namespace
493
494void OpOrAdaptorHelper::computeAttrMetadata() {
495 // Enumerate the attribute names of this op, ensuring the attribute names are
496 // unique in case implicit attributes are explicitly registered.
497 for (const NamedAttribute &namedAttr : op.getAttributes()) {
498 Attribute attr = namedAttr.attr;
499 bool isOptional =
500 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
501 attrMetadata.insert(
502 KV: {namedAttr.name, AttributeMetadata{.attrName: namedAttr.name, .isRequired: !isOptional, .constraint: attr}});
503 }
504
505 auto makeProperty = [&](StringRef storageType, StringRef parserCall) {
506 return Property(/*maybeDef=*/nullptr,
507 /*summary=*/"",
508 /*description=*/"",
509 /*storageType=*/storageType,
510 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
511 /*convertFromStorageCall=*/"$_storage",
512 /*assignToStorageCall=*/
513 "::llvm::copy($_value, $_storage.begin())",
514 /*convertToAttributeCall=*/
515 "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);",
516 /*convertFromAttributeCall=*/
517 "return convertFromAttribute($_storage, $_attr, $_diag);",
518 /*parserCall=*/parserCall,
519 /*optionalParserCall=*/"",
520 /*printerCall=*/printTextualSegmentSize,
521 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative,
522 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative,
523 /*hashPropertyCall=*/
524 "::llvm::hash_combine_range(std::begin($_storage), "
525 "std::end($_storage));",
526 /*StringRef defaultValue=*/"",
527 /*storageTypeValueOverride=*/"");
528 };
529 // Include key attributes from several traits as implicitly registered.
530 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
531 if (op.getDialect().usePropertiesForAttributes()) {
532 operandSegmentsSizeStorage =
533 llvm::formatv(Fmt: "std::array<int32_t, {0}>", Vals: op.getNumOperands());
534 operandSegmentsSizeParser =
535 llvm::formatv(Fmt: parseTextualSegmentSizeFormat, Vals: op.getNumOperands());
536 operandSegmentsSize = {
537 .name: "operandSegmentSizes",
538 .prop: makeProperty(operandSegmentsSizeStorage, operandSegmentsSizeParser)};
539 } else {
540 attrMetadata.insert(
541 KV: {operandSegmentAttrName, AttributeMetadata{.attrName: operandSegmentAttrName,
542 /*isRequired=*/true,
543 /*attr=*/.constraint: std::nullopt}});
544 }
545 }
546 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
547 if (op.getDialect().usePropertiesForAttributes()) {
548 resultSegmentsSizeStorage =
549 llvm::formatv(Fmt: "std::array<int32_t, {0}>", Vals: op.getNumResults());
550 resultSegmentsSizeParser =
551 llvm::formatv(Fmt: parseTextualSegmentSizeFormat, Vals: op.getNumResults());
552 resultSegmentsSize = {
553 .name: "resultSegmentSizes",
554 .prop: makeProperty(resultSegmentsSizeStorage, resultSegmentsSizeParser)};
555 } else {
556 attrMetadata.insert(
557 KV: {resultSegmentAttrName,
558 AttributeMetadata{.attrName: resultSegmentAttrName, /*isRequired=*/true,
559 /*attr=*/.constraint: std::nullopt}});
560 }
561 }
562
563 // Store the metadata in sorted order.
564 SmallVector<AttributeMetadata> sortedAttrMetadata =
565 llvm::to_vector(Range: llvm::make_second_range(c: attrMetadata.takeVector()));
566 llvm::sort(C&: sortedAttrMetadata,
567 Comp: [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
568 return lhs.attrName < rhs.attrName;
569 });
570
571 // Store the position of the legacy operand_segment_sizes /
572 // result_segment_sizes so we can emit a backward compatible property readers
573 // and writers.
574 StringRef legacyOperandSegmentSizeName =
575 StringLiteral("operand_segment_sizes");
576 StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes");
577 operandSegmentSizesLegacyIndex = 0;
578 resultSegmentSizesLegacyIndex = 0;
579 for (auto item : sortedAttrMetadata) {
580 if (item.attrName < legacyOperandSegmentSizeName)
581 ++operandSegmentSizesLegacyIndex;
582 if (item.attrName < legacyResultSegmentSizeName)
583 ++resultSegmentSizesLegacyIndex;
584 }
585
586 // Compute the subrange bounds for each attribute.
587 numRequired = 0;
588 for (AttributeMetadata &attr : sortedAttrMetadata) {
589 attr.lowerBound = numRequired;
590 numRequired += attr.isRequired;
591 };
592 for (AttributeMetadata &attr : sortedAttrMetadata)
593 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
594
595 // Store the results back into the map.
596 for (const AttributeMetadata &attr : sortedAttrMetadata)
597 attrMetadata.insert(KV: {attr.attrName, attr});
598}
599
600//===----------------------------------------------------------------------===//
601// Op emitter
602//===----------------------------------------------------------------------===//
603
604namespace {
605// Helper class to emit a record into the given output stream.
606class OpEmitter {
607 using ConstArgument =
608 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
609
610public:
611 static void
612 emitDecl(const Operator &op, raw_ostream &os,
613 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
614 static void
615 emitDef(const Operator &op, raw_ostream &os,
616 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
617
618private:
619 OpEmitter(const Operator &op,
620 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
621
622 void emitDecl(raw_ostream &os);
623 void emitDef(raw_ostream &os);
624
625 // Generate methods for accessing the attribute names of this operation.
626 void genAttrNameGetters();
627
628 // Generates the OpAsmOpInterface for this operation if possible.
629 void genOpAsmInterface();
630
631 // Generates the `getOperationName` method for this op.
632 void genOpNameGetter();
633
634 // Generates code to manage the properties, if any!
635 void genPropertiesSupport();
636
637 // Generates code to manage the encoding of properties to bytecode.
638 void
639 genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties);
640
641 // Generates getters for the properties.
642 void genPropGetters();
643
644 // Generates seters for the properties.
645 void genPropSetters();
646
647 // Generates getters for the attributes.
648 void genAttrGetters();
649
650 // Generates setter for the attributes.
651 void genAttrSetters();
652
653 // Generates removers for optional attributes.
654 void genOptionalAttrRemovers();
655
656 // Generates getters for named operands.
657 void genNamedOperandGetters();
658
659 // Generates setters for named operands.
660 void genNamedOperandSetters();
661
662 // Generates getters for named results.
663 void genNamedResultGetters();
664
665 // Generates getters for named regions.
666 void genNamedRegionGetters();
667
668 // Generates getters for named successors.
669 void genNamedSuccessorGetters();
670
671 // Generates the method to populate default attributes.
672 void genPopulateDefaultAttributes();
673
674 // Generates builder methods for the operation.
675 void genBuilder();
676
677 // Generates the build() method that takes each operand/attribute
678 // as a stand-alone parameter.
679 void genSeparateArgParamBuilder();
680 void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
681
682 // Generates the build() method that takes each operand/attribute as a
683 // stand-alone parameter. The generated build() method uses first operand's
684 // type as all results' types.
685 void genUseOperandAsResultTypeSeparateParamBuilder();
686
687 // The kind of collective builder to generate
688 enum class CollectiveBuilderKind {
689 PropStruct, // Inherent attributes/properties are passed by `const
690 // Properties&`
691 AttrDict, // Inherent attributes/properties are passed by attribute
692 // dictionary
693 };
694
695 // Generates the build() method that takes all operands/attributes
696 // collectively as one parameter. The generated build() method uses first
697 // operand's type as all results' types.
698 void
699 genUseOperandAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
700
701 // Generates the build() method that takes aggregate operands/attributes
702 // parameters. This build() method uses inferred types as result types.
703 // Requires: The type needs to be inferable via InferTypeOpInterface.
704 void genInferredTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
705
706 // Generates the build() method that takesaggregate operands/attributes as
707 // parameters. The generated build() method uses first attribute's
708 // type as all result's types.
709 void genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind);
710
711 // Generates the build() method that takes all result types collectively as
712 // one parameter. Similarly for operands and attributes.
713 void genCollectiveParamBuilder(CollectiveBuilderKind kind);
714
715 // The kind of parameter to generate for result types in builders.
716 enum class TypeParamKind {
717 None, // No result type in parameter list.
718 Separate, // A separate parameter for each result type.
719 Collective, // An ArrayRef<Type> for all result types.
720 };
721
722 // The kind of parameter to generate for attributes in builders.
723 enum class AttrParamKind {
724 WrappedAttr, // A wrapped MLIR Attribute instance.
725 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
726 };
727
728 // Builds the parameter list for build() method of this op. This method writes
729 // to `paramList` the comma-separated parameter list and updates
730 // `resultTypeNames` with the names for parameters for specifying result
731 // types. `inferredAttributes` is populated with any attributes that are
732 // elided from the build list. The given `typeParamKind` and `attrParamKind`
733 // controls how result types and attributes are placed in the parameter list.
734 void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
735 llvm::StringSet<> &inferredAttributes,
736 SmallVectorImpl<std::string> &resultTypeNames,
737 TypeParamKind typeParamKind,
738 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
739
740 // Adds op arguments and regions into operation state for build() methods.
741 void
742 genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
743 llvm::StringSet<> &inferredAttributes,
744 bool isRawValueAttr = false);
745
746 // Generates canonicalizer declaration for the operation.
747 void genCanonicalizerDecls();
748
749 // Generates the folder declaration for the operation.
750 void genFolderDecls();
751
752 // Generates the parser for the operation.
753 void genParser();
754
755 // Generates the printer for the operation.
756 void genPrinter();
757
758 // Generates verify method for the operation.
759 void genVerifier();
760
761 // Generates custom verify methods for the operation.
762 void genCustomVerifier();
763
764 // Generates verify statements for operands and results in the operation.
765 // The generated code will be attached to `body`.
766 void genOperandResultVerifier(MethodBody &body,
767 Operator::const_value_range values,
768 StringRef valueKind);
769
770 // Generates verify statements for regions in the operation.
771 // The generated code will be attached to `body`.
772 void genRegionVerifier(MethodBody &body);
773
774 // Generates verify statements for successors in the operation.
775 // The generated code will be attached to `body`.
776 void genSuccessorVerifier(MethodBody &body);
777
778 // Generates the traits used by the object.
779 void genTraits();
780
781 // Generate the OpInterface methods for all interfaces.
782 void genOpInterfaceMethods();
783
784 // Generate op interface methods for the given interface.
785 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
786
787 // Generate op interface method for the given interface method. If
788 // 'declaration' is true, generates a declaration, else a definition.
789 Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
790 bool declaration = true);
791
792 // Generate the side effect interface methods.
793 void genSideEffectInterfaceMethods();
794
795 // Generate the type inference interface methods.
796 void genTypeInterfaceMethods();
797
798private:
799 // The TableGen record for this op.
800 // TODO: OpEmitter should not have a Record directly,
801 // it should rather go through the Operator for better abstraction.
802 const Record &def;
803
804 // The wrapper operator class for querying information from this op.
805 const Operator &op;
806
807 // The C++ code builder for this op
808 OpClass opClass;
809
810 // The format context for verification code generation.
811 FmtContext verifyCtx;
812
813 // The emitter containing all of the locally emitted verification functions.
814 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
815
816 // Helper for emitting op code.
817 OpOrAdaptorHelper emitHelper;
818};
819
820} // namespace
821
822// Populate the format context `ctx` with substitutions of attributes, operands
823// and results.
824static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
825 FmtContext &ctx) {
826 // Populate substitutions for attributes.
827 auto &op = emitHelper.getOp();
828 for (const auto &namedAttr : op.getAttributes())
829 ctx.addSubst(placeholder: namedAttr.name,
830 subst: emitHelper.getOp().getGetterName(name: namedAttr.name) + "()");
831
832 // Populate substitutions for named operands.
833 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
834 auto &value = op.getOperand(index: i);
835 if (!value.name.empty())
836 ctx.addSubst(placeholder: value.name, subst: emitHelper.getOperand(index: i).str());
837 }
838
839 // Populate substitutions for results.
840 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
841 auto &value = op.getResult(index: i);
842 if (!value.name.empty())
843 ctx.addSubst(placeholder: value.name, subst: emitHelper.getResult(index: i).str());
844 }
845}
846
847/// Generate verification on native traits requiring attributes.
848static void genNativeTraitAttrVerifier(MethodBody &body,
849 const OpOrAdaptorHelper &emitHelper) {
850 // Check that the variadic segment sizes attribute exists and contains the
851 // expected number of elements.
852 //
853 // {0}: Attribute name.
854 // {1}: Expected number of elements.
855 // {2}: "operand" or "result".
856 // {3}: Emit error prefix.
857 const char *const checkAttrSizedValueSegmentsCode = R"(
858 {
859 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
860 auto numElements = sizeAttr.asArrayRef().size();
861 if (numElements != {1})
862 return {3}"'{0}' attribute for specifying {2} segments must have {1} "
863 "elements, but got ") << numElements;
864 }
865 )";
866
867 // Verify a few traits first so that we can use getODSOperands() and
868 // getODSResults() in the rest of the verifier.
869 auto &op = emitHelper.getOp();
870 if (!op.getDialect().usePropertiesForAttributes()) {
871 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
872 body << formatv(Fmt: checkAttrSizedValueSegmentsCode, Vals: operandSegmentAttrName,
873 Vals: op.getNumOperands(), Vals: "operand",
874 Vals: emitHelper.emitErrorPrefix());
875 }
876 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
877 body << formatv(Fmt: checkAttrSizedValueSegmentsCode, Vals: resultSegmentAttrName,
878 Vals: op.getNumResults(), Vals: "result",
879 Vals: emitHelper.emitErrorPrefix());
880 }
881 }
882}
883
884// Return true if a verifier can be emitted for the attribute: it is not a
885// derived attribute, it has a predicate, its condition is not empty, and, for
886// adaptors, the condition does not reference the op.
887static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
888 if (attr.isDerivedAttr())
889 return false;
890 Pred pred = attr.getPredicate();
891 if (pred.isNull())
892 return false;
893 std::string condition = pred.getCondition();
894 return !condition.empty() &&
895 (!StringRef(condition).contains(Other: "$_op") || isEmittingForOp);
896}
897
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// Attribute verification is performed as follows:
//
// 1. Verify that all required attributes are present in sorted order. This
// ensures that we can use subrange lookup even with potentially missing
// attributes.
// 2. Verify native trait attributes so that other attributes may call methods
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
//
// When `useProperties` is true, step 1 instead loads every non-derived
// attribute from `getProperties()` into a `tblgen_<name>` local and reports
// any required attribute that is null. Otherwise it scans the attribute range
// once, loading required and optional attributes into `tblgen_<name>` locals.
static void
genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
                     MethodBody &body,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
                     bool useProperties) {
  // With no attribute metadata nothing is emitted at all, including the
  // native-trait segment checks further below.
  if (emitHelper.getAttrMetadata().empty())
    return;

  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
  if ({0} && !({1}))
    return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
  if (::mlir::failed({0}(*this, {1}, "{2}")))
    return ::mlir::failure();
)";

  // Traverse the array until the required attribute is found. Return an error
  // if the traversal reached the end.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The emit error prefix.
  // {2}: The name of the attribute.
  //
  // The snippet deliberately leaves the loop open: `checkOptionalAttr`
  // branches may be appended, and the caller closes it with `++namedAttrIt`.
  const char *const findRequiredAttr = R"(
while (true) {{
  if (namedAttrIt == namedAttrRange.end())
    return {1}"requires attribute '{2}'");
  if (namedAttrIt->getName() == {0}) {{
    tblgen_{2} = namedAttrIt->getValue();
    break;
  })";

  // Emit a check to see if the iteration has encountered an optional attribute.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The name of the attribute.
  const char *const checkOptionalAttr = R"(
  else if (namedAttrIt->getName() == {0}) {{
    tblgen_{1} = namedAttrIt->getValue();
  })";

  // Emit the start of the loop for checking trailing attributes.
  const char *const checkTrailingAttrs = R"(while (true) {
  if (namedAttrIt == namedAttrRange.end()) {
    break;
  })";

  // Emit the verifier for the attribute.
  // Uses the uniqued constraint function when emitting for an op and the
  // static verifier emitter has one; otherwise pastes the predicate inline
  // with `$_self` bound to the attribute's local variable.
  const auto emitVerifier = [&](Attribute attr, StringRef attrName,
                                StringRef varName) {
    std::string condition = attr.getPredicate().getCondition();

    std::optional<StringRef> constraintFn;
    if (emitHelper.isEmittingForOp() &&
        (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
      body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
    } else {
      body << formatv(verifyAttrInline, varName,
                      tgfmt(condition, &ctx.withSelf(varName)),
                      emitHelper.emitErrorPrefix(), attrName,
                      escapeString(attr.getSummary()));
    }
  };

  // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
  const auto getVarName = [&](StringRef attrName) {
    return (tblgenNamePrefix + attrName).str();
  };

  body.indent();
  if (useProperties) {
    // Load each stored (non-derived) attribute from the properties struct; the
    // `(void)` cast silences unused-variable warnings for attributes that have
    // no constraint to check.
    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      const AttributeMetadata &metadata = it.second;
      if (metadata.constraint && metadata.constraint->isDerivedAttr())
        continue;
      body << formatv(
          "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
          it.first);
      if (metadata.isRequired)
        body << formatv(
            "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
            it.first, emitHelper.emitErrorPrefix());
    }
  } else {
    body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange());
    body << "auto namedAttrIt = namedAttrRange.begin();\n";

    // Iterate over the attributes in sorted order. Keep track of the optional
    // attributes that may be encountered along the way.
    //
    // For each required attribute, emit one scanning loop that stops on it and
    // also captures any optional attributes ordered before it. This relies on
    // the runtime attribute range and getAttrMetadata() sharing the same
    // order (NOTE(review): presumably both sorted by name; confirm).
    SmallVector<const AttributeMetadata *> optionalAttrs;

    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      const AttributeMetadata &metadata = it.second;
      if (!metadata.isRequired) {
        optionalAttrs.push_back(&metadata);
        continue;
      }

      body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv("::mlir::Attribute {0};\n",
                        getVarName(optional->attrName));
      }
      body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
                      emitHelper.emitErrorPrefix(), it.first);
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(checkOptionalAttr,
                        emitHelper.getAttrName(optional->attrName),
                        optional->attrName);
      }
      // Close the scanning loop opened by `findRequiredAttr`.
      body << "\n  ++namedAttrIt;\n}\n";
      optionalAttrs.clear();
    }
    // Get trailing optional attributes.
    // These sort after the last required attribute, so scan to the end of
    // the range instead of stopping on a required attribute.
    if (!optionalAttrs.empty()) {
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv("::mlir::Attribute {0};\n",
                        getVarName(optional->attrName));
      }
      body << checkTrailingAttrs;
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(checkOptionalAttr,
                        emitHelper.getAttrName(optional->attrName),
                        optional->attrName);
      }
      body << "\n  ++namedAttrIt;\n}\n";
    }
  }
  body.unindent();

  // Emit the checks for segment attributes first so that the other
  // constraints can call operand and result getters.
  genNativeTraitAttrVerifier(body, emitHelper);

  bool isEmittingForOp = emitHelper.isEmittingForOp();
  for (const auto &namedAttr : emitHelper.getOp().getAttributes())
    if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp))
      emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
1067
1068static void genPropertyVerifier(
1069 const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
1070 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1071
1072 // Code to get a reference to a property into a variable to avoid multiple
1073 // evaluations while verifying a property.
1074 // {0}: Property variable name.
1075 // {1}: Property name, with the first letter capitalized, to find the getter.
1076 // {2}: Property interface type.
1077 const char *const fetchProperty = R"(
1078 [[maybe_unused]] {2} {0} = this->get{1}();
1079)";
1080
1081 // Code to verify that the predicate of a property holds. Embeds the
1082 // condition inline.
1083 // {0}: Property condition code, with tgfmt() applied.
1084 // {1}: Emit error prefix.
1085 // {2}: Property name.
1086 // {3}: Property description.
1087 const char *const verifyPropertyInline = R"(
1088 if (!({0}))
1089 return {1}"property '{2}' failed to satisfy constraint: {3}");
1090)";
1091
1092 // Verify the property using a uniqued constraint. Can only be used
1093 // within the context of an op.
1094 //
1095 // {0}: Unique constraint name.
1096 // {1}: Property variable name in interface type.
1097 // {2}: Property name.
1098 const char *const verifyPropertyUniqued = R"(
1099 if (::mlir::failed({0}(*this, {1}, "{2}")))
1100 return ::mlir::failure();
1101)";
1102
1103 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
1104 const auto getVarName = [&](const NamedProperty &prop) {
1105 std::string varName =
1106 convertToCamelFromSnakeCase(input: prop.name, /*capitalizeFirst=*/false);
1107 return (tblgenNamePrefix + Twine(varName)).str();
1108 };
1109
1110 for (const NamedProperty &prop : emitHelper.getOp().getProperties()) {
1111 Pred predicate = prop.prop.getPredicate();
1112 // Null predicate, nothing to verify.
1113 if (predicate == Pred())
1114 continue;
1115
1116 std::string rawCondition = predicate.getCondition();
1117 if (rawCondition == "true")
1118 continue;
1119 bool needsOp = StringRef(rawCondition).contains(Other: "$_op");
1120 if (needsOp && !emitHelper.isEmittingForOp())
1121 continue;
1122
1123 auto scope = body.scope(open: "{\n", close: "}\n", /*indent=*/true);
1124 std::string varName = getVarName(prop);
1125 std::string getterName =
1126 convertToCamelFromSnakeCase(input: prop.name, /*capitalizeFirst=*/true);
1127 body << formatv(Fmt: fetchProperty, Vals&: varName, Vals&: getterName,
1128 Vals: prop.prop.getInterfaceType());
1129 auto uniquedFn = staticVerifierEmitter.getPropConstraintFn(constraint: prop.prop);
1130 if (uniquedFn.has_value())
1131 body << formatv(Fmt: verifyPropertyUniqued, Vals&: *uniquedFn, Vals&: varName, Vals: prop.name);
1132 else
1133 body << formatv(
1134 Fmt: verifyPropertyInline, Vals: tgfmt(fmt: rawCondition, ctx: &ctx.withSelf(subst: varName)),
1135 Vals: emitHelper.emitErrorPrefix(), Vals: prop.name, Vals: prop.prop.getSummary());
1136 }
1137}
1138
1139/// Include declarations specified on NativeTrait
1140static std::string formatExtraDeclarations(const Operator &op) {
1141 SmallVector<StringRef> extraDeclarations;
1142 // Include extra class declarations from NativeTrait
1143 for (const auto &trait : op.getTraits()) {
1144 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
1145 StringRef value = opTrait->getExtraConcreteClassDeclaration();
1146 if (value.empty())
1147 continue;
1148 extraDeclarations.push_back(Elt: value);
1149 }
1150 }
1151 extraDeclarations.push_back(Elt: op.getExtraClassDeclaration());
1152 return llvm::join(R&: extraDeclarations, Separator: "\n");
1153}
1154
1155/// Op extra class definitions have a `$cppClass` substitution that is to be
1156/// replaced by the C++ class name.
1157/// Include declarations specified on NativeTrait
1158static std::string formatExtraDefinitions(const Operator &op) {
1159 SmallVector<StringRef> extraDefinitions;
1160 // Include extra class definitions from NativeTrait
1161 for (const auto &trait : op.getTraits()) {
1162 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
1163 StringRef value = opTrait->getExtraConcreteClassDefinition();
1164 if (value.empty())
1165 continue;
1166 extraDefinitions.push_back(Elt: value);
1167 }
1168 }
1169 extraDefinitions.push_back(Elt: op.getExtraClassDefinition());
1170 FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass", subst: op.getCppClassName());
1171 return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n"), ctx: &ctx).str();
1172}
1173
/// Build the in-memory C++ class model (`opClass`) for `op`. All code
/// generation happens here, eagerly; emitDecl/emitDef only finalize and print
/// the resulting class.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), formatExtraDeclarations(op),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // Inside op verifiers, `$_op` and `$_ctxt` in predicates refer to the
  // operation being verified and its MLIRContext.
  verifyCtx.addSubst("_op", "(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genPropertiesSupport();
  genPropGetters();
  genPropSetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genPopulateDefaultAttributes();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass, emitHelper.hasProperties());
  genSideEffectInterfaceMethods();
}
1215void OpEmitter::emitDecl(
1216 const Operator &op, raw_ostream &os,
1217 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1218 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
1219}
1220
1221void OpEmitter::emitDef(
1222 const Operator &op, raw_ostream &os,
1223 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1224 OpEmitter(op, staticVerifierEmitter).emitDef(os);
1225}
1226
1227void OpEmitter::emitDecl(raw_ostream &os) {
1228 opClass.finalize();
1229 opClass.writeDeclTo(rawOs&: os);
1230}
1231
1232void OpEmitter::emitDef(raw_ostream &os) {
1233 opClass.finalize();
1234 opClass.writeDefTo(rawOs&: os);
1235}
1236
1237static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
1238 const Operator &op) {
1239 if (m)
1240 return;
1241 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "Unexpected overlap when generating `" +
1242 methodName + "` for " +
1243 op.getOperationName() + " (from line " +
1244 Twine(line) + ")");
1245}
1246
1247#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
1248
/// Generate the attribute-name accessors of the op class:
///  - static `getAttributeNames()`, listing the op's attribute names (plus the
///    operand segment sizes name when it is stored as a property);
///  - private `getAttributeNameForIndex` in instance and static forms;
///  - for each attribute, `<getter>AttrName()` in instance and static forms
///    (and likewise for the operand segment sizes name when present).
/// When there are no attributes and no operand segment sizes property, only a
/// `getAttributeNames()` returning an empty list is generated.
void OpEmitter::genAttrNameGetters() {
  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
      emitHelper.getAttrMetadata();
  // The operand segment sizes name is exposed here only when the dialect uses
  // properties for attributes and the op has AttrSizedOperandSegments.
  bool hasOperandSegmentsSize =
      op.getDialect().usePropertiesForAttributes() &&
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (!hasOperandSegmentsSize && attributes.empty()) {
      body << "  return {};";
      // Nothing else to do if there are no registered attributes. Exit early.
      return;
    }
    body << "  static ::llvm::StringRef attrNames[] = {";
    llvm::interleaveComma(llvm::make_first_range(attributes), body,
                          [&](StringRef attrName) {
                            body << "::llvm::StringRef(\"" << attrName << "\")";
                          });
    // The segment sizes name is appended last; the `<name>AttrName` getters
    // generated at the end of this function rely on that via `.back()`.
    if (hasOperandSegmentsSize) {
      if (!attributes.empty())
        body << ", ";
      body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")";
    }
    body << "};\n    return ::llvm::ArrayRef(attrNames);";
  }

  // Emit the getAttributeNameForIndex methods.
  {
    auto *method = opClass.addInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    method->body()
        << "  return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("::mlir::OperationName", "name"),
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    if (attributes.empty()) {
      method->body() << "  return {};";
    } else {
      // Indexes into the attribute names registered with `name`, with debug
      // asserts on the index bound and the operation identity.
      const char *const getAttrName = R"(
  assert(index < {0} && "invalid attribute index");
  assert(name.getStringRef() == getOperationName() && "invalid operation name");
  assert(name.isRegistered() && "Operation isn't registered, missing a "
        "dependent dialect loading?");
  return name.getAttributeNames()[index];
)";
      method->body() << formatv(getAttrName, attributes.size());
    }
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users.
  // The index used is the attribute's position in `attributes`, i.e. its
  // position in the `attrNames` array emitted above.
  const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
  for (auto [index, attr] :
       llvm::enumerate(llvm::make_first_range(attributes))) {
    std::string name = op.getGetterName(attr);
    std::string methodName = name + "AttrName";

    // Generate the non-static variant.
    {
      auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << llvm::formatv(attrNameMethodBody, index);
    }

    // Generate the static variant.
    {
      auto *method = opClass.addStaticInlineMethod(
          "::mlir::StringAttr", methodName,
          MethodParameter("::mlir::OperationName", "name"));
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << llvm::formatv(attrNameMethodBody,
                                      "name, " + Twine(index));
    }
  }
  if (hasOperandSegmentsSize) {
    std::string name = op.getGetterName(operandSegmentAttrName);
    std::string methodName = name + "AttrName";
    // Generate the non-static variant.
    // `.back()` selects the segment sizes name, which was listed last in
    // getAttributeNames() (NOTE(review): assumes the OperationName's
    // registered names mirror that list; confirm).
    {
      auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
      ERROR_IF_PRUNED(method, methodName, op);
      method->body()
          << "  return (*this)->getName().getAttributeNames().back();";
    }

    // Generate the static variant.
    {
      auto *method = opClass.addStaticInlineMethod(
          "::mlir::StringAttr", methodName,
          MethodParameter("::mlir::OperationName", "name"));
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << "  return name.getAttributeNames().back();";
    }
  }
}
1355
1356// Emit the getter for a named property.
1357// It is templated to be shared between the Op and the adaptor class.
1358template <typename OpClassOrAdaptor>
1359static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op,
1360 StringRef name, const Property &prop) {
1361 auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name);
1362 ERROR_IF_PRUNED(method, name, op);
1363 method->body() << formatv(Fmt: " return getProperties().{0}();", Vals&: name);
1364}
1365
1366// Emit the getter for an attribute with the return type specified.
1367// It is templated to be shared between the Op and the adaptor class.
1368template <typename OpClassOrAdaptor>
1369static void emitAttrGetterWithReturnType(FmtContext &fctx,
1370 OpClassOrAdaptor &opClass,
1371 const Operator &op, StringRef name,
1372 Attribute attr) {
1373 auto *method = opClass.addMethod(attr.getReturnType(), name);
1374 ERROR_IF_PRUNED(method, name, op);
1375 auto &body = method->body();
1376 body << " auto attr = " << name << "Attr();\n";
1377 if (attr.hasDefaultValue() && attr.isOptional()) {
1378 // Returns the default value if not set.
1379 // TODO: this is inefficient, we are recreating the attribute for every
1380 // call. This should be set instead.
1381 if (!attr.isConstBuildable()) {
1382 PrintFatalError(Msg: "DefaultValuedAttr of type " + attr.getAttrDefName() +
1383 " must have a constBuilder");
1384 }
1385 std::string defaultValue =
1386 std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
1387 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
1388 body << " if (!attr)\n return "
1389 << tgfmt(fmt: attr.getConvertFromStorageCall(),
1390 ctx: &fctx.withSelf(subst: defaultValue))
1391 << ";\n";
1392 }
1393 body << " return "
1394 << tgfmt(fmt: attr.getConvertFromStorageCall(), ctx: &fctx.withSelf(subst: "attr"))
1395 << ";\n";
1396}
1397
1398void OpEmitter::genPropertiesSupport() {
1399 if (!emitHelper.hasProperties())
1400 return;
1401
1402 SmallVector<ConstArgument> attrOrProperties;
1403 for (const std::pair<StringRef, AttributeMetadata> &it :
1404 emitHelper.getAttrMetadata()) {
1405 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
1406 attrOrProperties.push_back(Elt: &it.second);
1407 }
1408 for (const NamedProperty &prop : op.getProperties())
1409 attrOrProperties.push_back(Elt: &prop);
1410 if (emitHelper.getOperandSegmentsSize())
1411 attrOrProperties.push_back(Elt: &emitHelper.getOperandSegmentsSize().value());
1412 if (emitHelper.getResultSegmentsSize())
1413 attrOrProperties.push_back(Elt: &emitHelper.getResultSegmentsSize().value());
1414 auto &setPropMethod =
1415 opClass
1416 .addStaticMethod(
1417 retType: "::llvm::LogicalResult", name: "setPropertiesFromAttr",
1418 args: MethodParameter("Properties &", "prop"),
1419 args: MethodParameter("::mlir::Attribute", "attr"),
1420 args: MethodParameter(
1421 "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1422 "emitError"))
1423 ->body();
1424 auto &getPropMethod =
1425 opClass
1426 .addStaticMethod(retType: "::mlir::Attribute", name: "getPropertiesAsAttr",
1427 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1428 args: MethodParameter("const Properties &", "prop"))
1429 ->body();
1430 auto &hashMethod =
1431 opClass
1432 .addStaticMethod(retType: "llvm::hash_code", name: "computePropertiesHash",
1433 args: MethodParameter("const Properties &", "prop"))
1434 ->body();
1435 auto &getInherentAttrMethod =
1436 opClass
1437 .addStaticMethod(retType: "std::optional<mlir::Attribute>", name: "getInherentAttr",
1438 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1439 args: MethodParameter("const Properties &", "prop"),
1440 args: MethodParameter("llvm::StringRef", "name"))
1441 ->body();
1442 auto &setInherentAttrMethod =
1443 opClass
1444 .addStaticMethod(retType: "void", name: "setInherentAttr",
1445 args: MethodParameter("Properties &", "prop"),
1446 args: MethodParameter("llvm::StringRef", "name"),
1447 args: MethodParameter("mlir::Attribute", "value"))
1448 ->body();
1449 auto &populateInherentAttrsMethod =
1450 opClass
1451 .addStaticMethod(retType: "void", name: "populateInherentAttrs",
1452 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1453 args: MethodParameter("const Properties &", "prop"),
1454 args: MethodParameter("::mlir::NamedAttrList &", "attrs"))
1455 ->body();
1456 auto &verifyInherentAttrsMethod =
1457 opClass
1458 .addStaticMethod(
1459 retType: "::llvm::LogicalResult", name: "verifyInherentAttrs",
1460 args: MethodParameter("::mlir::OperationName", "opName"),
1461 args: MethodParameter("::mlir::NamedAttrList &", "attrs"),
1462 args: MethodParameter(
1463 "llvm::function_ref<::mlir::InFlightDiagnostic()>",
1464 "emitError"))
1465 ->body();
1466
1467 opClass.declare<UsingDeclaration>(args: "Properties", args: "FoldAdaptor::Properties");
1468
1469 // Convert the property to the attribute form.
1470
1471 setPropMethod << R"decl(
1472 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1473 if (!dict) {
1474 emitError() << "expected DictionaryAttr to set properties";
1475 return ::mlir::failure();
1476 }
1477 )decl";
1478 const char *propFromAttrFmt = R"decl(
1479 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1480 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{
1481 {0}
1482 };
1483 {1};
1484)decl";
1485 const char *attrGetNoDefaultFmt = R"decl(;
1486 if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1487 return ::mlir::failure();
1488)decl";
1489 const char *attrGetDefaultFmt = R"decl(;
1490 if (attr) {{
1491 if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1492 return ::mlir::failure();
1493 } else {{
1494 prop.{0} = {1};
1495 }
1496)decl";
1497
1498 for (const auto &attrOrProp : attrOrProperties) {
1499 if (const auto *namedProperty =
1500 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1501 StringRef name = namedProperty->name;
1502 auto &prop = namedProperty->prop;
1503 FmtContext fctx;
1504
1505 std::string getAttr;
1506 llvm::raw_string_ostream os(getAttr);
1507 os << " auto attr = dict.get(\"" << name << "\");";
1508 if (name == operandSegmentAttrName) {
1509 // Backward compat for now, TODO: Remove at some point.
1510 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1511 }
1512 if (name == resultSegmentAttrName) {
1513 // Backward compat for now, TODO: Remove at some point.
1514 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1515 }
1516
1517 fctx.withBuilder(subst: odsBuilder);
1518 setPropMethod << "{\n"
1519 << formatv(Fmt: propFromAttrFmt,
1520 Vals: tgfmt(fmt: prop.getConvertFromAttributeCall(),
1521 ctx: &fctx.addSubst(placeholder: "_attr", subst: propertyAttr)
1522 .addSubst(placeholder: "_storage", subst: propertyStorage)
1523 .addSubst(placeholder: "_diag", subst: propertyDiag)),
1524 Vals&: getAttr);
1525 if (prop.hasStorageTypeValueOverride()) {
1526 setPropMethod << formatv(Fmt: attrGetDefaultFmt, Vals&: name,
1527 Vals: prop.getStorageTypeValueOverride());
1528 } else if (prop.hasDefaultValue()) {
1529 setPropMethod << formatv(Fmt: attrGetDefaultFmt, Vals&: name,
1530 Vals: tgfmt(fmt: prop.getDefaultValue(), ctx: &fctx));
1531 } else {
1532 setPropMethod << formatv(Fmt: attrGetNoDefaultFmt, Vals&: name);
1533 }
1534 setPropMethod << " }\n";
1535 } else {
1536 const auto *namedAttr =
1537 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1538 StringRef name = namedAttr->attrName;
1539 std::string getAttr;
1540 llvm::raw_string_ostream os(getAttr);
1541 os << " auto attr = dict.get(\"" << name << "\");";
1542 if (name == operandSegmentAttrName) {
1543 // Backward compat for now
1544 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1545 }
1546 if (name == resultSegmentAttrName) {
1547 // Backward compat for now
1548 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1549 }
1550
1551 setPropMethod << formatv(Fmt: R"decl(
1552 {{
1553 auto &propStorage = prop.{0};
1554 {1}
1555 if (attr) {{
1556 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1557 if (convertedAttr) {{
1558 propStorage = convertedAttr;
1559 } else {{
1560 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1561 return ::mlir::failure();
1562 }
1563 }
1564 }
1565)decl",
1566 Vals&: name, Vals&: getAttr);
1567 }
1568 }
1569 setPropMethod << " return ::mlir::success();\n";
1570
1571 // Convert the attribute form to the property.
1572
1573 getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
1574 << " ::mlir::Builder odsBuilder{ctx};\n";
1575 const char *propToAttrFmt = R"decl(
1576 {
1577 const auto &propStorage = prop.{0};
1578 auto attr = [&]() -> ::mlir::Attribute {{
1579 {1}
1580 }();
1581 attrs.push_back(odsBuilder.getNamedAttr("{0}", attr));
1582 }
1583)decl";
1584 for (const auto &attrOrProp : attrOrProperties) {
1585 if (const auto *namedProperty =
1586 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1587 StringRef name = namedProperty->name;
1588 auto &prop = namedProperty->prop;
1589 FmtContext fctx;
1590 getPropMethod << formatv(
1591 Fmt: propToAttrFmt, Vals&: name,
1592 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(),
1593 ctx: &fctx.addSubst(placeholder: "_ctxt", subst: "ctx")
1594 .addSubst(placeholder: "_storage", subst: propertyStorage)));
1595 continue;
1596 }
1597 const auto *namedAttr =
1598 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1599 StringRef name = namedAttr->attrName;
1600 getPropMethod << formatv(Fmt: R"decl(
1601 {{
1602 const auto &propStorage = prop.{0};
1603 if (propStorage)
1604 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1605 propStorage));
1606 }
1607)decl",
1608 Vals&: name);
1609 }
1610 getPropMethod << R"decl(
1611 if (!attrs.empty())
1612 return odsBuilder.getDictionaryAttr(attrs);
1613 return {};
1614)decl";
1615
1616 // Hashing for the property
1617
1618 const char *propHashFmt = R"decl(
1619 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
1620 using ::llvm::hash_value;
1621 return {1};
1622 };
1623)decl";
1624 for (const auto &attrOrProp : attrOrProperties) {
1625 if (const auto *namedProperty =
1626 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1627 StringRef name = namedProperty->name;
1628 auto &prop = namedProperty->prop;
1629 FmtContext fctx;
1630 if (!prop.getHashPropertyCall().empty()) {
1631 hashMethod << formatv(
1632 Fmt: propHashFmt, Vals&: name,
1633 Vals: tgfmt(fmt: prop.getHashPropertyCall(),
1634 ctx: &fctx.addSubst(placeholder: "_storage", subst: propertyStorage)));
1635 }
1636 }
1637 }
1638 hashMethod << " using llvm::hash_value;\n";
1639 hashMethod << " return llvm::hash_combine(";
1640 llvm::interleaveComma(
1641 c: attrOrProperties, os&: hashMethod, each_fn: [&](const ConstArgument &attrOrProp) {
1642 if (const auto *namedProperty =
1643 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1644 if (!namedProperty->prop.getHashPropertyCall().empty()) {
1645 hashMethod << "\n hash_" << namedProperty->name << "(prop."
1646 << namedProperty->name << ")";
1647 } else {
1648 hashMethod << "\n hash_value(prop." << namedProperty->name
1649 << ")";
1650 }
1651 return;
1652 }
1653 const auto *namedAttr =
1654 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1655 StringRef name = namedAttr->attrName;
1656 hashMethod << "\n llvm::hash_value(prop." << name
1657 << ".getAsOpaquePointer())";
1658 });
1659 hashMethod << ");\n";
1660
1661 const char *getInherentAttrMethodFmt = R"decl(
1662 if (name == "{0}")
1663 return prop.{0};
1664)decl";
1665 const char *setInherentAttrMethodFmt = R"decl(
1666 if (name == "{0}") {{
1667 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
1668 return;
1669 }
1670)decl";
1671 const char *populateInherentAttrsMethodFmt = R"decl(
1672 if (prop.{0}) attrs.append("{0}", prop.{0});
1673)decl";
1674 for (const auto &attrOrProp : attrOrProperties) {
1675 if (const auto *namedAttr =
1676 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp)) {
1677 StringRef name = namedAttr->attrName;
1678 getInherentAttrMethod << formatv(Fmt: getInherentAttrMethodFmt, Vals&: name);
1679 setInherentAttrMethod << formatv(Fmt: setInherentAttrMethodFmt, Vals&: name);
1680 populateInherentAttrsMethod
1681 << formatv(Fmt: populateInherentAttrsMethodFmt, Vals&: name);
1682 continue;
1683 }
1684 // The ODS segment size property is "special": we expose it as an attribute
1685 // even though it is a native property.
1686 const auto *namedProperty = cast<const NamedProperty *>(Val: attrOrProp);
1687 StringRef name = namedProperty->name;
1688 if (name != operandSegmentAttrName && name != resultSegmentAttrName)
1689 continue;
1690 auto &prop = namedProperty->prop;
1691 FmtContext fctx;
1692 fctx.addSubst(placeholder: "_ctxt", subst: "ctx");
1693 fctx.addSubst(placeholder: "_storage", subst: Twine("prop.") + name);
1694 if (name == operandSegmentAttrName) {
1695 getInherentAttrMethod
1696 << formatv(Fmt: " if (name == \"operand_segment_sizes\" || name == "
1697 "\"{0}\") return ",
1698 Vals: operandSegmentAttrName);
1699 } else {
1700 getInherentAttrMethod
1701 << formatv(Fmt: " if (name == \"result_segment_sizes\" || name == "
1702 "\"{0}\") return ",
1703 Vals: resultSegmentAttrName);
1704 }
1705 getInherentAttrMethod << "[&]() -> ::mlir::Attribute { "
1706 << tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx)
1707 << " }();\n";
1708
1709 if (name == operandSegmentAttrName) {
1710 setInherentAttrMethod
1711 << formatv(Fmt: " if (name == \"operand_segment_sizes\" || name == "
1712 "\"{0}\") {{",
1713 Vals: operandSegmentAttrName);
1714 } else {
1715 setInherentAttrMethod
1716 << formatv(Fmt: " if (name == \"result_segment_sizes\" || name == "
1717 "\"{0}\") {{",
1718 Vals: resultSegmentAttrName);
1719 }
1720 setInherentAttrMethod << formatv(Fmt: R"decl(
1721 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
1722 if (!arrAttr) return;
1723 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
1724 return;
1725 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
1726 return;
1727 }
1728)decl",
1729 Vals&: name);
1730 if (name == operandSegmentAttrName) {
1731 populateInherentAttrsMethod << formatv(
1732 Fmt: " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
1733 Vals: operandSegmentAttrName,
1734 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx));
1735 } else {
1736 populateInherentAttrsMethod << formatv(
1737 Fmt: " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
1738 Vals: resultSegmentAttrName,
1739 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx));
1740 }
1741 }
1742 getInherentAttrMethod << " return std::nullopt;\n";
1743
1744 // Emit the verifiers method for backward compatibility with the generic
1745 // syntax. This method verifies the constraint on the properties attributes
1746 // before they are set, since dyn_cast<> will silently omit failures.
1747 for (const auto &attrOrProp : attrOrProperties) {
1748 const auto *namedAttr =
1749 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1750 if (!namedAttr || !namedAttr->constraint)
1751 continue;
1752 Attribute attr = *namedAttr->constraint;
1753 std::optional<StringRef> constraintFn =
1754 staticVerifierEmitter.getAttrConstraintFn(constraint: attr);
1755 if (!constraintFn)
1756 continue;
1757 if (canEmitAttrVerifier(attr,
1758 /*isEmittingForOp=*/false)) {
1759 std::string name = op.getGetterName(name: namedAttr->attrName);
1760 verifyInherentAttrsMethod
1761 << formatv(Fmt: R"(
1762 {{
1763 ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
1764 if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
1765 return ::mlir::failure();
1766 }
1767)",
1768 Vals&: name, Vals&: constraintFn, Vals: namedAttr->attrName);
1769 }
1770 }
1771 verifyInherentAttrsMethod << " return ::mlir::success();";
1772
1773 // Generate methods to interact with bytecode.
1774 genPropertiesSupportForBytecode(attrOrProperties);
1775}
1776
// Emits bytecode (de)serialization support for the op's properties storage.
//  - Does nothing when `attrOrProperties` is empty.
//  - If the op uses a custom properties encoding, only declarations of the
//    static `readProperties` and of `writeProperties` are emitted (their
//    definitions presumably come from the op author).
//  - Otherwise both methods are generated. Entries of `attrOrProperties` are
//    read/written in list order: native properties through their ODS
//    read/write snippets, attributes through `readAttribute`/`writeAttribute`
//    when required and the `Optional` variants otherwise. The legacy
//    operand/result segment-size encoding is spliced in at the recorded legacy
//    index for backward compatibility and back deployment.
1777 void OpEmitter::genPropertiesSupportForBytecode(
1778 ArrayRef<ConstArgument> attrOrProperties) {
1779 if (attrOrProperties.empty())
1780 return;
1781
// Custom encoding: declare the two hooks and stop; no bodies are generated.
1782 if (op.useCustomPropertiesEncoding()) {
1783 opClass.declareStaticMethod(
1784 retType: "::llvm::LogicalResult", name: "readProperties",
1785 args: MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1786 args: MethodParameter("::mlir::OperationState &", "state"));
1787 opClass.declareMethod(
1788 retType: "void", name: "writeProperties",
1789 args: MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
1790 return;
1791 }
1792
// `readProperties` is static: it fills the properties of an OperationState
// that is still being built, so it has no `this`.
1793 auto &readPropertiesMethod =
1794 opClass
1795 .addStaticMethod(
1796 retType: "::llvm::LogicalResult", name: "readProperties",
1797 args: MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1798 args: MethodParameter("::mlir::OperationState &", "state"))
1799 ->body();
1800
1801 auto &writePropertiesMethod =
1802 opClass
1803 .addMethod(
1804 retType: "void", name: "writeProperties",
1805 args: MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
1806 ->body();
1807
1808 // Populate bytecode serialization logic.
// Both generated bodies start by binding `prop` to the properties storage.
1809 readPropertiesMethod
1810 << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
1811 writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n";
1812 for (const auto &item : llvm::enumerate(First&: attrOrProperties)) {
1813 auto &attrOrProp = item.value();
// Placeholders available to the ODS-provided read/write snippets.
// NOTE(review): `$_ctxt` expands to `this->getContext()`, but readProperties
// is emitted as a static method, so a read snippet that uses `$_ctxt` would not
// compile there. Confirm that no property reader relies on it.
1814 FmtContext fctx;
1815 fctx.addSubst(placeholder: "_reader", subst: "reader")
1816 .addSubst(placeholder: "_writer", subst: "writer")
1817 .addSubst(placeholder: "_storage", subst: propertyStorage)
1818 .addSubst(placeholder: "_ctxt", subst: "this->getContext()");
1819 // If the op emits operand/result segment sizes as a property, emit the
1820 // legacy reader/writer in the appropriate order to allow backward
1821 // compatibility and back deployment.
1822 if (emitHelper.getOperandSegmentsSize().has_value() &&
1823 item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) {
1824 FmtContext fmtCtxt(fctx);
1825 fmtCtxt.addSubst(placeholder: "_propName", subst: operandSegmentAttrName);
1826 readPropertiesMethod << tgfmt(fmt: readBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1827 writePropertiesMethod << tgfmt(fmt: writeBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1828 }
1829 if (emitHelper.getResultSegmentsSize().has_value() &&
1830 item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) {
1831 FmtContext fmtCtxt(fctx);
1832 fmtCtxt.addSubst(placeholder: "_propName", subst: resultSegmentAttrName);
1833 readPropertiesMethod << tgfmt(fmt: readBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1834 writePropertiesMethod << tgfmt(fmt: writeBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1835 }
// Native property: wrap the ODS read snippet in a lambda so it can
// `return failure()` early, and propagate any failure to the caller.
1836 if (const auto *namedProperty =
1837 dyn_cast<const NamedProperty *>(Val: attrOrProp)) {
1838 StringRef name = namedProperty->name;
1839 readPropertiesMethod << formatv(
1840 Fmt: R"(
1841 {{
1842 auto &propStorage = prop.{0};
1843 auto readProp = [&]() {
1844 {1};
1845 return ::mlir::success();
1846 };
1847 if (::mlir::failed(readProp()))
1848 return ::mlir::failure();
1849 }
1850)",
1851 Vals&: name,
1852 Vals: tgfmt(fmt: namedProperty->prop.getReadFromMlirBytecodeCall(), ctx: &fctx));
1853 writePropertiesMethod << formatv(
1854 Fmt: R"(
1855 {{
1856 auto &propStorage = prop.{0};
1857 {1};
1858 }
1859)",
1860 Vals&: name, Vals: tgfmt(fmt: namedProperty->prop.getWriteToMlirBytecodeCall(), ctx: &fctx));
1861 continue;
1862 }
// Everything else is treated as attribute metadata. The dyn_cast result is
// dereferenced without a null check, so this assumes each entry is either a
// NamedProperty or an AttributeMetadata.
1863 const auto *namedAttr = dyn_cast<const AttributeMetadata *>(Val: attrOrProp);
1864 StringRef name = namedAttr->attrName;
1865 if (namedAttr->isRequired) {
1866 readPropertiesMethod << formatv(Fmt: R"(
1867 if (::mlir::failed(reader.readAttribute(prop.{0})))
1868 return ::mlir::failure();
1869)",
1870 Vals&: name);
1871 writePropertiesMethod
1872 << formatv(Fmt: " writer.writeAttribute(prop.{0});\n", Vals&: name);
1873 } else {
1874 readPropertiesMethod << formatv(Fmt: R"(
1875 if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
1876 return ::mlir::failure();
1877)",
1878 Vals&: name);
1879 writePropertiesMethod << formatv(Fmt: R"(
1880 writer.writeOptionalAttribute(prop.{0});
1881)",
1882 Vals&: name);
1883 }
1884 }
// Every entry was read successfully.
1885 readPropertiesMethod << " return ::mlir::success();";
1886}
1887
1888void OpEmitter::genPropGetters() {
1889 for (const NamedProperty &prop : op.getProperties()) {
1890 std::string name = op.getGetterName(name: prop.name);
1891 emitPropGetter(opClass, op, name, prop: prop.prop);
1892 }
1893}
1894
1895void OpEmitter::genPropSetters() {
1896 for (const NamedProperty &prop : op.getProperties()) {
1897 std::string name = op.getSetterName(name: prop.name);
1898 std::string argName = "new" + convertToCamelFromSnakeCase(
1899 input: prop.name, /*capitalizeFirst=*/true);
1900 auto *method = opClass.addInlineMethod(
1901 retType: "void", name, args: MethodParameter(prop.prop.getInterfaceType(), argName));
1902 if (!method)
1903 return;
1904 method->body() << formatv(Fmt: " getProperties().{0}({1});", Vals&: name, Vals&: argName);
1905 }
1906}
1907
1908void OpEmitter::genAttrGetters() {
1909 FmtContext fctx;
1910 fctx.withBuilder(subst: "::mlir::Builder((*this)->getContext())");
1911
1912 // Emit the derived attribute body.
1913 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
1914 if (auto *method = opClass.addMethod(retType: attr.getReturnType(), name))
1915 method->body() << " " << attr.getDerivedCodeBody() << "\n";
1916 };
1917
1918 // Generate named accessor with Attribute return type. This is a wrapper
1919 // class that allows referring to the attributes via accessors instead of
1920 // having to use the string interface for better compile time verification.
1921 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
1922 Attribute attr) {
1923 // The method body for this getter is trivial. Emit it inline.
1924 auto *method =
1925 opClass.addInlineMethod(retType: attr.getStorageType(), name: name + "Attr");
1926 if (!method)
1927 return;
1928 method->body() << formatv(
1929 Fmt: " return ::llvm::{1}<{2}>({0});", Vals: emitHelper.getAttr(attrName),
1930 Vals: attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1931 : "cast",
1932 Vals: attr.getStorageType());
1933 };
1934
1935 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1936 std::string name = op.getGetterName(name: namedAttr.name);
1937 if (namedAttr.attr.isDerivedAttr()) {
1938 emitDerivedAttr(name, namedAttr.attr);
1939 } else {
1940 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1941 emitAttrGetterWithReturnType(fctx, opClass, op, name, attr: namedAttr.attr);
1942 }
1943 }
1944
1945 auto derivedAttrs = make_filter_range(Range: op.getAttributes(),
1946 Pred: [](const NamedAttribute &namedAttr) {
1947 return namedAttr.attr.isDerivedAttr();
1948 });
1949 if (derivedAttrs.empty())
1950 return;
1951
1952 opClass.addTrait(trait: "::mlir::DerivedAttributeOpInterface::Trait");
1953 // Generate helper method to query whether a named attribute is a derived
1954 // attribute. This enables, for example, avoiding adding an attribute that
1955 // overlaps with a derived attribute.
1956 {
1957 auto *method =
1958 opClass.addStaticMethod(retType: "bool", name: "isDerivedAttribute",
1959 args: MethodParameter("::llvm::StringRef", "name"));
1960 ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1961 auto &body = method->body();
1962 for (auto namedAttr : derivedAttrs)
1963 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
1964 body << " return false;";
1965 }
1966 // Generate method to materialize derived attributes as a DictionaryAttr.
1967 {
1968 auto *method = opClass.addMethod(retType: "::mlir::DictionaryAttr",
1969 name: "materializeDerivedAttributes");
1970 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1971 auto &body = method->body();
1972
1973 auto nonMaterializable =
1974 make_filter_range(Range&: derivedAttrs, Pred: [](const NamedAttribute &namedAttr) {
1975 return namedAttr.attr.getConvertFromStorageCall().empty();
1976 });
1977 if (!nonMaterializable.empty()) {
1978 std::string attrs;
1979 llvm::raw_string_ostream os(attrs);
1980 interleaveComma(c: nonMaterializable, os, each_fn: [&](const NamedAttribute &attr) {
1981 os << op.getGetterName(name: attr.name);
1982 });
1983 PrintWarning(
1984 WarningLoc: op.getLoc(),
1985 Msg: formatv(
1986 Fmt: "op has non-materializable derived attributes '{0}', skipping",
1987 Vals&: os.str()));
1988 body << formatv(Fmt: " emitOpError(\"op has non-materializable derived "
1989 "attributes '{0}'\");\n",
1990 Vals&: attrs);
1991 body << " return nullptr;";
1992 return;
1993 }
1994
1995 body << " ::mlir::MLIRContext* ctx = getContext();\n";
1996 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1997 body << " return ::mlir::DictionaryAttr::get(";
1998 body << " ctx, {\n";
1999 interleave(
2000 c: derivedAttrs, os&: body,
2001 each_fn: [&](const NamedAttribute &namedAttr) {
2002 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
2003 std::string name = op.getGetterName(name: namedAttr.name);
2004 body << " {" << name << "AttrName(),\n"
2005 << tgfmt(fmt: tmpl, ctx: &fctx.withSelf(subst: name + "()")
2006 .withBuilder(subst: "odsBuilder")
2007 .addSubst(placeholder: "_ctxt", subst: "ctx")
2008 .addSubst(placeholder: "_storage", subst: "ctx"))
2009 << "}";
2010 },
2011 separator: ",\n");
2012 body << "});";
2013 }
2014}
2015
// Emits setters for every non-derived attribute of the op:
//  - `<setter>Attr(<StorageType> attr)`, which stores the attribute as given;
//  - `<setter>(<value>)`, which builds the attribute from an unwrapped C++
//    value, but only when `canUseUnwrappedRawValue` accepts the base attribute.
// Whether the value goes into the op's properties struct or into the
// attribute dictionary (via setAttr) depends on the dialect's
// `usePropertiesForAttributes()` setting.
2016 void OpEmitter::genAttrSetters() {
2017 bool useProperties = op.getDialect().usePropertiesForAttributes();
2018
2019 // Generate the code to set an attribute.
// With properties the value is assigned to `getProperties().<attrName>`;
// otherwise it is stored under `<getterName>AttrName()` in the dictionary.
2020 auto emitSetAttr = [&](Method *method, StringRef getterName,
2021 StringRef attrName, StringRef attrVar) {
2022 if (useProperties) {
2023 method->body() << formatv(Fmt: " getProperties().{0} = {1};", Vals&: attrName,
2024 Vals&: attrVar);
2025 } else {
2026 method->body() << formatv(Fmt: " (*this)->setAttr({0}AttrName(), {1});",
2027 Vals&: getterName, Vals&: attrVar);
2028 }
2029 };
2030
2031 // Generate the raw named setter `<setter>Attr`, which takes the attribute in
2032 // its storage type. It lets users set attributes through a typed setter
2033 // instead of the string interface, for better compile time verification.
2034 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
2035 StringRef attrName, Attribute attr) {
2036 // This method body is trivial, so emit it inline.
2037 auto *method =
2038 opClass.addInlineMethod(retType: "void", name: setterName + "Attr",
2039 args: MethodParameter(attr.getStorageType(), "attr"));
2040 if (method)
2041 emitSetAttr(method, getterName, attrName, "attr");
2042 };
2043
2044 // Generate a setter that accepts the underlying C++ type as opposed to the
2045 // attribute type.
// Parameter type: `bool` for UnitAttr, `::std::optional<ReturnType>` for other
// optional attributes, and the attribute's return type otherwise.
2046 auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
2047 StringRef attrName, Attribute attr) {
2048 Attribute baseAttr = attr.getBaseAttr();
2049 if (!canUseUnwrappedRawValue(attr: baseAttr))
2050 return;
2051 FmtContext fctx;
2052 fctx.withBuilder(subst: "::mlir::Builder((*this)->getContext())");
2053 bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
2054 bool isOptional = attr.isOptional();
2055
2056 auto createMethod = [&](const Twine &paramType) {
2057 return opClass.addMethod(retType: "void", name&: setterName,
2058 args: MethodParameter(paramType.str(), "attrValue"));
2059 };
2060
2061 // Build the method using the correct parameter type depending on
2062 // optionality.
2063 Method *method = nullptr;
2064 if (isUnitAttr)
2065 method = createMethod("bool");
2066 else if (isOptional)
2067 method =
2068 createMethod("::std::optional<" + baseAttr.getReturnType() + ">");
2069 else
2070 method = createMethod(attr.getReturnType());
// The setter was pruned (e.g. it clashes with an existing method); emit nothing.
2071 if (!method)
2072 return;
2073
2074 // If the value isn't optional, just set it directly.
2075 if (!isOptional) {
2076 emitSetAttr(method, getterName, attrName,
2077 constBuildAttrFromParam(attr, fctx, paramName: "attrValue"));
2078 return;
2079 }
2080
2081 // Otherwise, we only set if the provided value is valid. If it isn't, we
2082 // remove the attribute.
// "Remove" means removeAttr on the dictionary, or resetting the properties
// field to nullptr.
2083
2084 // TODO: Handle unit attr parameters specially, given that it is treated as
2085 // optional but not in the same way as the others (i.e. it uses bool over
2086 // std::optional<>).
// For UnitAttr the bool itself is passed to the builder; for other optional
// attributes the std::optional is dereferenced.
2087 StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
2088 if (!useProperties) {
2089 const char *optionalCodeBody = R"(
2090 if (attrValue)
2091 return (*this)->setAttr({0}AttrName(), {1});
2092 (*this)->removeAttr({0}AttrName());)";
2093 method->body() << formatv(
2094 Fmt: optionalCodeBody, Vals&: getterName,
2095 Vals: constBuildAttrFromParam(attr: baseAttr, fctx, paramName: paramStr));
2096 } else {
2097 const char *optionalCodeBody = R"(
2098 auto &odsProp = getProperties().{0};
2099 if (attrValue)
2100 odsProp = {1};
2101 else
2102 odsProp = nullptr;)";
2103 method->body() << formatv(
2104 Fmt: optionalCodeBody, Vals&: attrName,
2105 Vals: constBuildAttrFromParam(attr: baseAttr, fctx, paramName: paramStr));
2106 }
2107 };
2108
// Derived attributes are computed rather than stored, so they get no setters.
2109 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2110 if (namedAttr.attr.isDerivedAttr())
2111 continue;
2112 std::string setterName = op.getSetterName(name: namedAttr.name);
2113 std::string getterName = op.getGetterName(name: namedAttr.name);
2114 emitAttrWithStorageType(setterName, getterName, namedAttr.name,
2115 namedAttr.attr);
2116 emitAttrWithReturnType(setterName, getterName, namedAttr.name,
2117 namedAttr.attr);
2118 }
2119}
2120
// Emits an inline `<remover>Attr()` method for every optional attribute.
// Each method returns an `::mlir::Attribute`:
//  - with properties, it returns the stored value and resets the field to a
//    default-constructed value;
//  - otherwise, it returns the result of `removeAttr(<getter>AttrName())`.
2121 void OpEmitter::genOptionalAttrRemovers() {
2122 // Generate methods for removing optional attributes, instead of having to
2123 // use the string interface. Enables better compile time verification.
2124 auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
2125 auto *method = opClass.addInlineMethod(retType: "::mlir::Attribute",
2126 name: op.getRemoverName(name) + "Attr");
// Pruned (e.g. clashes with an existing method): emit nothing for it.
2127 if (!method)
2128 return;
2129 if (useProperties) {
2130 method->body() << formatv(Fmt: R"(
2131 auto attr = getProperties().{0};
2132 getProperties().{0} = {{};
2133 return attr;
2134)",
2135 Vals&: name);
2136 return;
2137 }
2138 method->body() << formatv(Fmt: "return (*this)->removeAttr({0}AttrName());",
2139 Vals: op.getGetterName(name));
2140 };
2141
2142 for (const NamedAttribute &namedAttr : op.getAttributes())
2143 if (namedAttr.attr.isOptional())
2144 emitRemoveAttr(namedAttr.name,
2145 op.getDialect().usePropertiesForAttributes());
2146}
2147
2148// Generates the code to compute the start and end index of an operand or result
2149// range.
2150template <typename RangeT>
2151static void generateValueRangeStartAndEnd(
2152 Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
2153 int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
2154 bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
2155
2156 SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
2157 if (isGenericAdaptorBase) {
2158 parameters.emplace_back(Args: "unsigned", Args: "odsOperandsSize");
2159 // The range size is passed per parameter for generic adaptor bases as
2160 // using the rangeSizeCall would require the operands, which are not
2161 // accessible in the base class.
2162 rangeSizeCall = "odsOperandsSize";
2163 }
2164
2165 // The method is trivial if the operation does not have any variadic operands.
2166 // In that case, make sure to generate it in-line.
2167 auto *method = opClass.addMethod(retType: "std::pair<unsigned, unsigned>", name&: methodName,
2168 properties: numVariadic == 0 ? Method::Properties::Inline
2169 : Method::Properties::None,
2170 args&: parameters);
2171 if (!method)
2172 return;
2173 auto &body = method->body();
2174 if (numVariadic == 0) {
2175 body << " return {index, 1};\n";
2176 } else if (hasAttrSegmentSize) {
2177 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
2178 } else {
2179 // Because the op can have arbitrarily interleaved variadic and non-variadic
2180 // operands, we need to embed a list in the "sink" getter method for
2181 // calculation at run-time.
2182 SmallVector<StringRef, 4> isVariadic;
2183 isVariadic.reserve(N: llvm::size(odsValues));
2184 for (auto &it : odsValues)
2185 isVariadic.push_back(Elt: it.isVariableLength() ? "true" : "false");
2186 std::string isVariadicList = llvm::join(R&: isVariadic, Separator: ", ");
2187 body << formatv(Fmt: sameVariadicSizeValueRangeCalcCode, Vals&: isVariadicList,
2188 Vals&: numNonVariadic, Vals&: numVariadic, Vals&: rangeSizeCall, Vals: "operand");
2189 }
2190}
2191
2192static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
2193 return llvm::formatv(Fmt: "::mlir::TypedValue<{0}>", Vals: value.constraint.getCppType())
2194 .str();
2195}
2196
2197 // Generates the named operand getter methods for the given Operator `op` and
2198 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
2199 // return a range of operands. Single-operand getters return `rangeElementType`
2200 // when generating a generic adaptor, and `::mlir::TypedValue<...>` otherwise.
2201 // `rangeBeginCall` yields an iterator to the beginning of the operand range,
2202 // and `rangeSizeCall` yields the number of operands.
2203 // `getOperandCallPattern` describes how to obtain a single operand, with its
2204 // position substituted for the
2205 // "{0}" marker. NOTE(review): this parameter is not referenced in the body.
2206 // The pattern is meant to work for any kind of op, in particular one-operand
2207 // ops that may not have the `getOperand(unsigned)` method.
// When `genericAdaptorBase` is non-null, the index/length helper goes into that
// base class, and `opClass` receives a trampoline that forwards the operand
// count.
// Fatal errors: more than one variadic operand without
// SameVariadicOperandSize/AttrSizedOperandSegments; AttrSizedOperandSegments
// with fewer than two variadic operands; both traits at once.
2208 static void
2209 generateNamedOperandGetters(const Operator &op, Class &opClass,
2210 Class *genericAdaptorBase, StringRef sizeAttrInit,
2211 StringRef rangeType, StringRef rangeElementType,
2212 StringRef rangeBeginCall, StringRef rangeSizeCall,
2213 StringRef getOperandCallPattern) {
2214 const int numOperands = op.getNumOperands();
2215 const int numVariadicOperands = op.getNumVariableLengthOperands();
2216 const int numNormalOperands = numOperands - numVariadicOperands;
2217
2218 const auto *sameVariadicSize =
2219 op.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize");
2220 const auto *attrSizedOperands =
2221 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
2222
2223 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
2224 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op has multiple variadic operands but no "
2225 "specification over their sizes");
2226 }
2227
2228 if (numVariadicOperands < 2 && attrSizedOperands) {
2229 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op must have at least two variadic operands "
2230 "to use 'AttrSizedOperandSegments' trait");
2231 }
2232
2233 if (attrSizedOperands && sameVariadicSize) {
2234 PrintFatalError(ErrorLoc: op.getLoc(),
2235 Msg: "op cannot have both 'AttrSizedOperandSegments' and "
2236 "'SameVariadicOperandSize' traits");
2237 }
2238
2239 // Print the ods names so they don't need to be hardcoded in the source.
// Each named operand gets a `static constexpr int odsIndex_<name> = <i>;` field.
2240 for (int i = 0; i != numOperands; ++i) {
2241 const auto &operand = op.getOperand(index: i);
2242 if (operand.name.empty())
2243 continue;
2244
2245 opClass.declare<Field>(args: "static constexpr int", args: Twine("odsIndex_") +
2246 operand.name + " = " +
2247 Twine(i));
2248 }
2249
2250 // First emit a few "sink" getter methods upon which we layer all nicer named
2251 // getter methods.
2252 // If generating for an adaptor, the method is put into the non-templated
2253 // generic base class, to not require being defined in the header.
2254 // Since the operand size can't be determined from the base class however,
2255 // it has to be passed as an additional argument. The trampoline below
2256 // generates the function with the same signature as the Op in the generic
2257 // adaptor.
2258 bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
2259 generateValueRangeStartAndEnd(
2260 /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
2261 isGenericAdaptorBase,
2262 /*methodName=*/"getODSOperandIndexAndLength", numVariadic: numVariadicOperands,
2263 numNonVariadic: numNormalOperands, rangeSizeCall, hasAttrSegmentSize: attrSizedOperands, sizeAttrInit,
2264 odsValues: const_cast<Operator &>(op).getOperands());
2265 if (isGenericAdaptorBase) {
2266 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
2267 // the index. This just calls the implementation in the base class but
2268 // passes the operand size as parameter.
2269 Method *method = opClass.addInlineMethod(
2270 retType: "std::pair<unsigned, unsigned>", name: "getODSOperandIndexAndLength",
2271 args: MethodParameter("unsigned", "index"));
2272 ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
2273 MethodBody &body = method->body();
2274 body.indent() << formatv(
2275 Fmt: "return Base::getODSOperandIndexAndLength(index, {0});", Vals&: rangeSizeCall);
2276 }
2277
2278 // The implementation of this method is trivial and it is very load-bearing.
2279 // Generate it inline.
// `getODSOperands(index)` returns the `rangeType` slice selected by
// getODSOperandIndexAndLength(index).
2280 auto *m = opClass.addInlineMethod(retType&: rangeType, name: "getODSOperands",
2281 args: MethodParameter("unsigned", "index"));
2282 ERROR_IF_PRUNED(m, "getODSOperands", op);
2283 auto &body = m->body();
2284 body << formatv(Fmt: valueRangeReturnCode, Vals&: rangeBeginCall,
2285 Vals: "getODSOperandIndexAndLength(index)");
2286
2287 // Then we emit nicer named getter methods by redirecting to the "sink" getter
2288 // method.
// Per named operand kind:
//  - optional: a default-constructed value when absent, else the first element;
//  - variadic-of-variadic: a SmallVector of ranges (adaptor) or an
//    OperandRangeRange split by the segment-size attribute (op);
//  - variadic: the whole `rangeType` slice;
//  - single: the first element, cast to its TypedValue type for ops.
2289 for (int i = 0; i != numOperands; ++i) {
2290 const auto &operand = op.getOperand(index: i);
2291 if (operand.name.empty())
2292 continue;
2293 std::string name = op.getGetterName(name: operand.name);
2294 if (operand.isOptional()) {
2295 m = opClass.addInlineMethod(retType: isGenericAdaptorBase
2296 ? rangeElementType
2297 : generateTypeForGetter(value: operand),
2298 name);
2299 ERROR_IF_PRUNED(m, name, op);
2300 m->body().indent() << formatv(Fmt: "auto operands = getODSOperands({0});\n"
2301 "return operands.empty() ? {1}{{} : ",
2302 Vals&: i, Vals: m->getReturnType());
2303 if (!isGenericAdaptorBase)
2304 m->body() << llvm::formatv(Fmt: "::llvm::cast<{0}>", Vals: m->getReturnType());
2305 m->body() << "(*operands.begin());";
2306 } else if (operand.isVariadicOfVariadic()) {
2307 std::string segmentAttr = op.getGetterName(
2308 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
2309 if (genericAdaptorBase) {
2310 m = opClass.addMethod(retType: "::llvm::SmallVector<" + rangeType + ">", name);
2311 ERROR_IF_PRUNED(m, name, op);
2312 m->body() << llvm::formatv(Fmt: variadicOfVariadicAdaptorCalcCode,
2313 Vals&: segmentAttr, Vals&: i, Vals&: rangeType);
2314 continue;
2315 }
2316
2317 m = opClass.addInlineMethod(retType: "::mlir::OperandRangeRange", name);
2318 ERROR_IF_PRUNED(m, name, op);
2319 m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
2320 << "Attr());";
2321 } else if (operand.isVariadic()) {
2322 m = opClass.addInlineMethod(retType&: rangeType, name);
2323 ERROR_IF_PRUNED(m, name, op);
2324 m->body() << " return getODSOperands(" << i << ");";
2325 } else {
2326 m = opClass.addInlineMethod(retType: isGenericAdaptorBase
2327 ? rangeElementType
2328 : generateTypeForGetter(value: operand),
2329 name);
2330 ERROR_IF_PRUNED(m, name, op);
2331 m->body().indent() << "return ";
2332 if (!isGenericAdaptorBase)
2333 m->body() << llvm::formatv(Fmt: "::llvm::cast<{0}>", Vals: m->getReturnType());
2334 m->body() << llvm::formatv(Fmt: "(*getODSOperands({0}).begin());", Vals&: i);
2335 }
2336 }
2337}
2338
2339void OpEmitter::genNamedOperandGetters() {
2340 // Build the code snippet used for initializing the operand_segment_size)s
2341 // array.
2342 std::string attrSizeInitCode;
2343 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
2344 if (op.getDialect().usePropertiesForAttributes())
2345 attrSizeInitCode = formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
2346 Vals: "getProperties().operandSegmentSizes");
2347
2348 else
2349 attrSizeInitCode = formatv(Fmt: opSegmentSizeAttrInitCode,
2350 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName));
2351 }
2352
2353 generateNamedOperandGetters(
2354 op, opClass,
2355 /*genericAdaptorBase=*/nullptr,
2356 /*sizeAttrInit=*/attrSizeInitCode,
2357 /*rangeType=*/"::mlir::Operation::operand_range",
2358 /*rangeElementType=*/"::mlir::Value",
2359 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2360 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2361 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2362}
2363
2364void OpEmitter::genNamedOperandSetters() {
2365 auto *attrSizedOperands =
2366 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
2367 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
2368 const auto &operand = op.getOperand(index: i);
2369 if (operand.name.empty())
2370 continue;
2371 std::string name = op.getGetterName(name: operand.name);
2372
2373 StringRef returnType;
2374 if (operand.isVariadicOfVariadic()) {
2375 returnType = "::mlir::MutableOperandRangeRange";
2376 } else if (operand.isVariableLength()) {
2377 returnType = "::mlir::MutableOperandRange";
2378 } else {
2379 returnType = "::mlir::OpOperand &";
2380 }
2381 bool isVariadicOperand =
2382 operand.isVariadicOfVariadic() || operand.isVariableLength();
2383 auto *m = opClass.addMethod(retType&: returnType, name: name + "Mutable",
2384 properties: isVariadicOperand ? Method::Properties::None
2385 : Method::Properties::Inline);
2386 ERROR_IF_PRUNED(m, name, op);
2387 auto &body = m->body();
2388 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
2389
2390 if (!isVariadicOperand) {
2391 // In case of a single operand, return a single OpOperand.
2392 body << " return getOperation()->getOpOperand(range.first);\n";
2393 continue;
2394 }
2395
2396 body << " auto mutableRange = "
2397 "::mlir::MutableOperandRange(getOperation(), "
2398 "range.first, range.second";
2399 if (attrSizedOperands) {
2400 if (emitHelper.hasProperties())
2401 body << formatv(Fmt: ", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2402 "{{getOperandSegmentSizesAttrName(), "
2403 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2404 "getProperties().operandSegmentSizes)})",
2405 Vals&: i);
2406 else
2407 body << formatv(
2408 Fmt: ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", Vals&: i,
2409 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName, /*isNamed=*/true));
2410 }
2411 body << ");\n";
2412
2413 // If this operand is a nested variadic, we split the range into a
2414 // MutableOperandRangeRange that provides a range over all of the
2415 // sub-ranges.
2416 if (operand.isVariadicOfVariadic()) {
2417 body << " return "
2418 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2419 << op.getGetterName(
2420 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2421 << "AttrName()));\n";
2422 } else {
2423 // Otherwise, we use the full range directly.
2424 body << " return mutableRange;\n";
2425 }
2426 }
2427}
2428
2429void OpEmitter::genNamedResultGetters() {
2430 const int numResults = op.getNumResults();
2431 const int numVariadicResults = op.getNumVariableLengthResults();
2432 const int numNormalResults = numResults - numVariadicResults;
2433
2434 // If we have more than one variadic results, we need more complicated logic
2435 // to calculate the value range for each result.
2436
2437 const auto *sameVariadicSize =
2438 op.getTrait(trait: "::mlir::OpTrait::SameVariadicResultSize");
2439 const auto *attrSizedResults =
2440 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments");
2441
2442 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
2443 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op has multiple variadic results but no "
2444 "specification over their sizes");
2445 }
2446
2447 if (numVariadicResults < 2 && attrSizedResults) {
2448 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op must have at least two variadic results "
2449 "to use 'AttrSizedResultSegments' trait");
2450 }
2451
2452 if (attrSizedResults && sameVariadicSize) {
2453 PrintFatalError(ErrorLoc: op.getLoc(),
2454 Msg: "op cannot have both 'AttrSizedResultSegments' and "
2455 "'SameVariadicResultSize' traits");
2456 }
2457
2458 // Build the initializer string for the result segment size attribute.
2459 std::string attrSizeInitCode;
2460 if (attrSizedResults) {
2461 if (op.getDialect().usePropertiesForAttributes())
2462 attrSizeInitCode = formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
2463 Vals: "getProperties().resultSegmentSizes");
2464
2465 else
2466 attrSizeInitCode = formatv(Fmt: opSegmentSizeAttrInitCode,
2467 Vals: emitHelper.getAttr(attrName: resultSegmentAttrName));
2468 }
2469
2470 generateValueRangeStartAndEnd(
2471 opClass, /*isGenericAdaptorBase=*/false, methodName: "getODSResultIndexAndLength",
2472 numVariadic: numVariadicResults, numNonVariadic: numNormalResults, rangeSizeCall: "getOperation()->getNumResults()",
2473 hasAttrSegmentSize: attrSizedResults, sizeAttrInit: attrSizeInitCode, odsValues: op.getResults());
2474
2475 // The implementation of this method is trivial and it is very load-bearing.
2476 // Generate it inline.
2477 auto *m = opClass.addInlineMethod(retType: "::mlir::Operation::result_range",
2478 name: "getODSResults",
2479 args: MethodParameter("unsigned", "index"));
2480 ERROR_IF_PRUNED(m, "getODSResults", op);
2481 m->body() << formatv(Fmt: valueRangeReturnCode, Vals: "getOperation()->result_begin()",
2482 Vals: "getODSResultIndexAndLength(index)");
2483
2484 for (int i = 0; i != numResults; ++i) {
2485 const auto &result = op.getResult(index: i);
2486 if (result.name.empty())
2487 continue;
2488 std::string name = op.getGetterName(name: result.name);
2489 if (result.isOptional()) {
2490 m = opClass.addInlineMethod(retType: generateTypeForGetter(value: result), name);
2491 ERROR_IF_PRUNED(m, name, op);
2492 m->body() << " auto results = getODSResults(" << i << ");\n"
2493 << llvm::formatv(Fmt: " return results.empty()"
2494 " ? {0}()"
2495 " : ::llvm::cast<{0}>(*results.begin());",
2496 Vals: m->getReturnType());
2497 } else if (result.isVariadic()) {
2498 m = opClass.addInlineMethod(retType: "::mlir::Operation::result_range", name);
2499 ERROR_IF_PRUNED(m, name, op);
2500 m->body() << " return getODSResults(" << i << ");";
2501 } else {
2502 m = opClass.addInlineMethod(retType: generateTypeForGetter(value: result), name);
2503 ERROR_IF_PRUNED(m, name, op);
2504 m->body() << llvm::formatv(
2505 Fmt: " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2506 Vals: m->getReturnType(), Vals&: i);
2507 }
2508 }
2509}
2510
2511void OpEmitter::genNamedRegionGetters() {
2512 unsigned numRegions = op.getNumRegions();
2513 for (unsigned i = 0; i < numRegions; ++i) {
2514 const auto &region = op.getRegion(index: i);
2515 if (region.name.empty())
2516 continue;
2517 std::string name = op.getGetterName(name: region.name);
2518
2519 // Generate the accessors for a variadic region.
2520 if (region.isVariadic()) {
2521 auto *m = opClass.addInlineMethod(
2522 retType: "::mlir::MutableArrayRef<::mlir::Region>", name);
2523 ERROR_IF_PRUNED(m, name, op);
2524 m->body() << formatv(Fmt: " return (*this)->getRegions().drop_front({0});",
2525 Vals&: i);
2526 continue;
2527 }
2528
2529 auto *m = opClass.addInlineMethod(retType: "::mlir::Region &", name);
2530 ERROR_IF_PRUNED(m, name, op);
2531 m->body() << formatv(Fmt: " return (*this)->getRegion({0});", Vals&: i);
2532 }
2533}
2534
2535void OpEmitter::genNamedSuccessorGetters() {
2536 unsigned numSuccessors = op.getNumSuccessors();
2537 for (unsigned i = 0; i < numSuccessors; ++i) {
2538 const NamedSuccessor &successor = op.getSuccessor(index: i);
2539 if (successor.name.empty())
2540 continue;
2541 std::string name = op.getGetterName(name: successor.name);
2542 // Generate the accessors for a variadic successor list.
2543 if (successor.isVariadic()) {
2544 auto *m = opClass.addInlineMethod(retType: "::mlir::SuccessorRange", name);
2545 ERROR_IF_PRUNED(m, name, op);
2546 m->body() << formatv(
2547 Fmt: " return {std::next((*this)->successor_begin(), {0}), "
2548 "(*this)->successor_end()};",
2549 Vals&: i);
2550 continue;
2551 }
2552
2553 auto *m = opClass.addInlineMethod(retType: "::mlir::Block *", name);
2554 ERROR_IF_PRUNED(m, name, op);
2555 m->body() << formatv(Fmt: " return (*this)->getSuccessor({0});", Vals&: i);
2556 }
2557}
2558
2559static bool canGenerateUnwrappedBuilder(const Operator &op) {
2560 // If this op does not have native attributes at all, return directly to avoid
2561 // redefining builders.
2562 if (op.getNumNativeAttributes() == 0)
2563 return false;
2564
2565 bool canGenerate = false;
2566 // We are generating builders that take raw values for attributes. We need to
2567 // make sure the native attributes have a meaningful "unwrapped" value type
2568 // different from the wrapped mlir::Attribute type to avoid redefining
2569 // builders. This checks for the op has at least one such native attribute.
2570 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
2571 const NamedAttribute &namedAttr = op.getAttribute(index: i);
2572 if (canUseUnwrappedRawValue(attr: namedAttr.attr)) {
2573 canGenerate = true;
2574 break;
2575 }
2576 }
2577 return canGenerate;
2578}
2579
2580static bool canInferType(const Operator &op) {
2581 return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait");
2582}
2583
2584void OpEmitter::genInlineCreateBody(
2585 const SmallVector<MethodParameter> &paramList) {
2586 SmallVector<MethodParameter> createParamListOpBuilder;
2587 SmallVector<MethodParameter> createParamListImplicitLocOpBuilder;
2588 SmallVector<llvm::StringRef, 4> nonBuilderStateArgsList;
2589 createParamListOpBuilder.emplace_back(Args: "::mlir::OpBuilder &", Args: "builder");
2590 createParamListImplicitLocOpBuilder.emplace_back(
2591 Args: "::mlir::ImplicitLocOpBuilder &", Args: "builder");
2592 std::string locParamName = "location";
2593 while (llvm::find_if(Range: paramList, P: [&locParamName](const MethodParameter &p) {
2594 return p.getName() == locParamName;
2595 }) != paramList.end()) {
2596 locParamName += "_";
2597 }
2598 createParamListOpBuilder.emplace_back(Args: "::mlir::Location", Args&: locParamName);
2599
2600 for (auto &param : paramList) {
2601 if (param.getType() == "::mlir::OpBuilder &" ||
2602 param.getType() == "::mlir::OperationState &")
2603 continue;
2604 createParamListOpBuilder.emplace_back(Args: param.getType(), Args: param.getName(),
2605 Args: param.getDefaultValue(),
2606 Args: param.isOptional());
2607 createParamListImplicitLocOpBuilder.emplace_back(
2608 Args: param.getType(), Args: param.getName(), Args: param.getDefaultValue(),
2609 Args: param.isOptional());
2610 nonBuilderStateArgsList.push_back(Elt: param.getName());
2611 }
2612 auto *cWithLoc = opClass.addStaticMethod(retType: opClass.getClassName(), name: "create",
2613 args&: createParamListOpBuilder);
2614 auto *cImplicitLoc = opClass.addStaticMethod(
2615 retType: opClass.getClassName(), name: "create", args&: createParamListImplicitLocOpBuilder);
2616 std::string nonBuilderStateArgs = "";
2617 if (!nonBuilderStateArgsList.empty()) {
2618 llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
2619 interleaveComma(c: nonBuilderStateArgsList, os&: nonBuilderStateArgsOS);
2620 nonBuilderStateArgs = ", " + nonBuilderStateArgs;
2621 }
2622 cWithLoc->body() << llvm::formatv(Fmt: inlineCreateBody, Vals&: locParamName,
2623 Vals&: nonBuilderStateArgs,
2624 Vals: opClass.getClassName());
2625 cImplicitLoc->body() << llvm::formatv(Fmt: inlineCreateBodyImplicitLoc,
2626 Vals&: nonBuilderStateArgs);
2627}
2628
/// Emits `build` methods that take one parameter per operand and attribute.
/// Each attribute flavor in `attrBuilderType` gets these variants:
///   - result types passed as separate parameters,
///   - result types inferred (only when `canInferType(op)`),
///   - result types passed as a single collective `resultTypes` parameter.
/// Redundant signatures are pruned by `addStaticMethod` and skipped here.
void OpEmitter::genSeparateArgParamBuilder() {
  // Wrapped-attribute builders are always generated; raw-value builders only
  // when they would not collide with the wrapped ones.
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(Elt: AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(Elt: AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  //
  // `attrType` selects wrapped vs. raw attribute parameters. `paramKind`
  // selects how result types are passed (not at all, one parameter per result,
  // or one collective `resultTypes` parameter). `inferType` requests a body
  // that obtains the result types from `inferReturnTypes` instead.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    SmallVector<MethodParameter> paramList;
    SmallVector<std::string, 4> resultNames;
    llvm::StringSet<> inferredAttributes;
    // Fills `paramList` and reports the result-type parameter names in
    // `resultNames`. NOTE(review): `inferredAttributes` presumably collects
    // attributes the builder derives instead of taking as parameters; confirm
    // in buildParamList.
    buildParamList(paramList, inferredAttributes, resultTypeNames&: resultNames, typeParamKind: paramKind,
                   attrParamKind: attrType);

    auto *m = opClass.addStaticMethod(retType: "void", name: "build", args&: paramList);
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    // Emit the matching `create` overloads for this `build` signature.
    genInlineCreateBody(paramList);

    auto &body = m->body();
    genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                           /*isRawValueAttr=*/attrType ==
                                               AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // The emitted code calls `<Op>::inferReturnTypes` on the state built so
      // far and reports a fatal error if inference fails.
      body << formatv(Fmt: R"(
 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
 {1}.location, {1}.operands,
 {1}.attributes.getDictionary({1}.getContext()),
 {1}.getRawProperties(),
 {1}.regions, inferredReturnTypes)))
 {1}.addTypes(inferredReturnTypes);
 else
 ::mlir::detail::reportFatalInferReturnTypesError({1});
 )",
                      Vals: opClass.getClassName(), Vals: builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      // No result-type parameters, so there is nothing to add.
      return;
    case TypeParamKind::Separate:
      // One parameter per result. An optional result's type is only added
      // when the emitted `if (<name>)` check passes.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(index: i).isOptional())
          body << " if (" << resultNames[i] << ")\n ";
        body << " " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }

      // Automatically create the 'resultSegmentSizes' attribute using
      // the length of the type ranges.
      // With properties, the sizes are copied into
      // `Properties::resultSegmentSizes`; otherwise they are added as a
      // `DenseI32ArrayAttr` attribute.
      if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
        if (op.getDialect().usePropertiesForAttributes()) {
          body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
        } else {
          std::string getterName = op.getGetterName(name: resultSegmentAttrName);
          body << " " << builderOpState << ".addAttribute(" << getterName
               << "AttrName(" << builderOpState << ".name), "
               << "odsBuilder.getDenseI32ArrayAttr({";
        }
        // One segment size per result: 1 for a non-variable-length result,
        // 0 or 1 for an optional one, and the range size for a variadic one.
        interleaveComma(
            c: llvm::seq<int>(Begin: 0, End: op.getNumResults()), os&: body, each_fn: [&](int i) {
              const NamedTypeConstraint &result = op.getResult(index: i);
              if (!result.isVariableLength()) {
                body << "1";
              } else if (result.isOptional()) {
                body << "(" << resultNames[i] << " ? 1 : 0)";
              } else {
                // VariadicOfVariadic of results are currently unsupported in
                // MLIR, hence it can only be a simple variadic.
                // TODO: Add implementation for VariadicOfVariadic results here
                // once supported.
                assert(result.isVariadic());
                body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
              }
            });
        if (op.getDialect().usePropertiesForAttributes()) {
          body << "}), " << builderOpState
               << ".getOrAddProperties<Properties>()."
                  "resultSegmentSizes.begin());\n";
        } else {
          body << "}));\n";
        }
      }

      return;
    case TypeParamKind::Collective: {
      // A single `resultTypes` parameter: assert its size, then add all of it.
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!hasVariadicResult || numNonVariadicResults != 0)
        body << " " << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << " " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op))
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
2753
2754void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
2755 CollectiveBuilderKind kind) {
2756 int numResults = op.getNumResults();
2757
2758 // Signature
2759 SmallVector<MethodParameter> paramList;
2760 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2761 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2762 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2763 if (kind == CollectiveBuilderKind::PropStruct)
2764 paramList.emplace_back(Args: "const Properties &", Args: "properties");
2765 // Provide default value for `attributes` when its the last parameter
2766 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2767 StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
2768 ? "discardableAttributes"
2769 : "attributes";
2770 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2771 Args&: attributesName, Args&: attributesDefaultValue);
2772 if (op.getNumVariadicRegions())
2773 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
2774
2775 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args&: paramList);
2776 // If the builder is redundant, skip generating the method
2777 if (!m)
2778 return;
2779 genInlineCreateBody(paramList);
2780 auto &body = m->body();
2781
2782 // Operands
2783 body << " " << builderOpState << ".addOperands(operands);\n";
2784
2785 if (kind == CollectiveBuilderKind::PropStruct)
2786 body << " " << builderOpState
2787 << ".useProperties(const_cast<Properties&>(properties));\n";
2788 // Attributes
2789 body << " " << builderOpState << ".addAttributes(" << attributesName
2790 << ");\n";
2791
2792 // Create the correct number of regions
2793 if (int numRegions = op.getNumRegions()) {
2794 body << llvm::formatv(
2795 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
2796 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2797 body << " (void)" << builderOpState << ".addRegion();\n";
2798 }
2799
2800 // Result types
2801 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
2802 body << " " << builderOpState << ".addTypes({"
2803 << llvm::join(R&: resultTypes, Separator: ", ") << "});\n\n";
2804}
2805
2806void OpEmitter::genPopulateDefaultAttributes() {
2807 // All done if no attributes, except optional ones, have default values.
2808 if (llvm::all_of(Range: op.getAttributes(), P: [](const NamedAttribute &named) {
2809 return !named.attr.hasDefaultValue() || named.attr.isOptional();
2810 }))
2811 return;
2812
2813 if (emitHelper.hasProperties()) {
2814 SmallVector<MethodParameter> paramList;
2815 paramList.emplace_back(Args: "::mlir::OperationName", Args: "opName");
2816 paramList.emplace_back(Args: "Properties &", Args: "properties");
2817 auto *m =
2818 opClass.addStaticMethod(retType: "void", name: "populateDefaultProperties", args&: paramList);
2819 ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
2820 auto &body = m->body();
2821 body.indent();
2822 body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
2823 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2824 auto &attr = namedAttr.attr;
2825 if (!attr.hasDefaultValue() || attr.isOptional())
2826 continue;
2827 StringRef name = namedAttr.name;
2828 FmtContext fctx;
2829 fctx.withBuilder(subst: odsBuilder);
2830 body << "if (!properties." << name << ")\n"
2831 << " properties." << name << " = "
2832 << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
2833 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)))
2834 << ";\n";
2835 }
2836 return;
2837 }
2838
2839 SmallVector<MethodParameter> paramList;
2840 paramList.emplace_back(Args: "const ::mlir::OperationName &", Args: "opName");
2841 paramList.emplace_back(Args: "::mlir::NamedAttrList &", Args: "attributes");
2842 auto *m = opClass.addStaticMethod(retType: "void", name: "populateDefaultAttrs", args&: paramList);
2843 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
2844 auto &body = m->body();
2845 body.indent();
2846
2847 // Set default attributes that are unset.
2848 body << "auto attrNames = opName.getAttributeNames();\n";
2849 body << "::mlir::Builder " << odsBuilder
2850 << "(attrNames.front().getContext());\n";
2851 StringMap<int> attrIndex;
2852 for (const auto &it : llvm::enumerate(First: emitHelper.getAttrMetadata())) {
2853 attrIndex[it.value().first] = it.index();
2854 }
2855 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2856 auto &attr = namedAttr.attr;
2857 if (!attr.hasDefaultValue() || attr.isOptional())
2858 continue;
2859 auto index = attrIndex[namedAttr.name];
2860 body << "if (!attributes.get(attrNames[" << index << "])) {\n";
2861 FmtContext fctx;
2862 fctx.withBuilder(subst: odsBuilder);
2863
2864 std::string defaultValue =
2865 std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
2866 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
2867 body.indent() << formatv(Fmt: "attributes.append(attrNames[{0}], {1});\n", Vals&: index,
2868 Vals&: defaultValue);
2869 body.unindent() << "}\n";
2870 }
2871}
2872
2873void OpEmitter::genInferredTypeCollectiveParamBuilder(
2874 CollectiveBuilderKind kind) {
2875 SmallVector<MethodParameter> paramList;
2876 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2877 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2878 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2879 if (kind == CollectiveBuilderKind::PropStruct)
2880 paramList.emplace_back(Args: "const Properties &", Args: "properties");
2881 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2882 StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
2883 ? "discardableAttributes"
2884 : "attributes";
2885 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2886 Args&: attributesName, Args&: attributesDefaultValue);
2887 if (op.getNumVariadicRegions())
2888 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
2889
2890 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args&: paramList);
2891 // If the builder is redundant, skip generating the method
2892 if (!m)
2893 return;
2894 genInlineCreateBody(paramList);
2895 auto &body = m->body();
2896
2897 int numResults = op.getNumResults();
2898 int numVariadicResults = op.getNumVariableLengthResults();
2899 int numNonVariadicResults = numResults - numVariadicResults;
2900
2901 int numOperands = op.getNumOperands();
2902 int numVariadicOperands = op.getNumVariableLengthOperands();
2903 int numNonVariadicOperands = numOperands - numVariadicOperands;
2904
2905 // Operands
2906 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2907 body << " assert(operands.size()"
2908 << (numVariadicOperands != 0 ? " >= " : " == ")
2909 << numNonVariadicOperands
2910 << "u && \"mismatched number of parameters\");\n";
2911 body << " " << builderOpState << ".addOperands(operands);\n";
2912 if (kind == CollectiveBuilderKind::PropStruct)
2913 body << " " << builderOpState
2914 << ".useProperties(const_cast<Properties &>(properties));\n";
2915 body << " " << builderOpState << ".addAttributes(" << attributesName
2916 << ");\n";
2917
2918 // Create the correct number of regions
2919 if (int numRegions = op.getNumRegions()) {
2920 body << llvm::formatv(
2921 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
2922 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2923 body << " (void)" << builderOpState << ".addRegion();\n";
2924 }
2925
2926 // Result types
2927 if (emitHelper.hasNonEmptyPropertiesStruct() &&
2928 kind == CollectiveBuilderKind::AttrDict) {
2929 // Initialize the properties from Attributes before invoking the infer
2930 // function.
2931 body << formatv(Fmt: R"(
2932 if (!attributes.empty()) {
2933 ::mlir::OpaqueProperties properties =
2934 &{1}.getOrAddProperties<{0}::Properties>();
2935 std::optional<::mlir::RegisteredOperationName> info =
2936 {1}.name.getRegisteredInfo();
2937 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2938 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2939 ::llvm::report_fatal_error("Property conversion failed.");
2940 })",
2941 Vals: opClass.getClassName(), Vals: builderOpState);
2942 }
2943 body << formatv(Fmt: R"(
2944 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2945 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2946 {1}.location, operands,
2947 {1}.attributes.getDictionary({1}.getContext()),
2948 {1}.getRawProperties(),
2949 {1}.regions, inferredReturnTypes))) {{)",
2950 Vals: opClass.getClassName(), Vals: builderOpState);
2951 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2952 body << "\n assert(inferredReturnTypes.size()"
2953 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2954 << "u && \"mismatched number of return types\");";
2955 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
2956
2957 body << R"(
2958 } else {
2959 ::llvm::report_fatal_error("Failed to infer result type(s).");
2960 })";
2961}
2962
2963void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2964 auto emit = [&](AttrParamKind attrType) {
2965 SmallVector<MethodParameter> paramList;
2966 SmallVector<std::string, 4> resultNames;
2967 llvm::StringSet<> inferredAttributes;
2968 buildParamList(paramList, inferredAttributes, resultTypeNames&: resultNames,
2969 typeParamKind: TypeParamKind::None, attrParamKind: attrType);
2970
2971 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args&: paramList);
2972 // If the builder is redundant, skip generating the method
2973 if (!m)
2974 return;
2975 genInlineCreateBody(paramList);
2976 auto &body = m->body();
2977 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2978 /*isRawValueAttr=*/attrType ==
2979 AttrParamKind::UnwrappedValue);
2980
2981 auto numResults = op.getNumResults();
2982 if (numResults == 0)
2983 return;
2984
2985 // Push all result types to the operation state
2986 const char *index = op.getOperand(index: 0).isVariadic() ? ".front()" : "";
2987 std::string resultType =
2988 formatv(Fmt: "{0}{1}.getType()", Vals: getArgumentName(op, index: 0), Vals&: index).str();
2989 body << " " << builderOpState << ".addTypes({" << resultType;
2990 for (int i = 1; i != numResults; ++i)
2991 body << ", " << resultType;
2992 body << "});\n\n";
2993 };
2994
2995 emit(AttrParamKind::WrappedAttr);
2996 // Generate additional builder(s) if any attributes can be "unwrapped"
2997 if (canGenerateUnwrappedBuilder(op))
2998 emit(AttrParamKind::UnwrappedValue);
2999}
3000
3001void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
3002 CollectiveBuilderKind kind) {
3003 SmallVector<MethodParameter> paramList;
3004 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
3005 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
3006 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
3007 if (kind == CollectiveBuilderKind::PropStruct)
3008 paramList.emplace_back(Args: "const Properties &", Args: "properties");
3009 StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
3010 ? "discardableAttributes"
3011 : "attributes";
3012 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
3013 Args&: attributesName, Args: "{}");
3014 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args&: paramList);
3015 // If the builder is redundant, skip generating the method
3016 if (!m)
3017 return;
3018 genInlineCreateBody(paramList);
3019
3020 auto &body = m->body();
3021
3022 // Push all result types to the operation state
3023 std::string resultType;
3024 const auto &namedAttr = op.getAttribute(index: 0);
3025
3026 if (namedAttr.attr.isTypeAttr()) {
3027 resultType = "::llvm::cast<::mlir::TypeAttr>(typeAttr).getValue()";
3028 } else {
3029 resultType = "::llvm::cast<::mlir::TypedAttr>(typeAttr).getType()";
3030 }
3031
3032 if (kind == CollectiveBuilderKind::PropStruct) {
3033 body << " ::mlir::Attribute typeAttr = properties."
3034 << op.getGetterName(name: namedAttr.name) << "();\n";
3035 } else {
3036 body << " ::mlir::Attribute typeAttr;\n"
3037 << " auto attrName = " << op.getGetterName(name: namedAttr.name)
3038 << "AttrName(" << builderOpState
3039 << ".name);\n"
3040 " for (auto attr : attributes) {\n"
3041 " if (attr.getName() == attrName) {\n"
3042 " typeAttr = attr.getValue();\n"
3043 " break;\n"
3044 " }\n"
3045 " }\n";
3046 }
3047
3048 // Operands
3049 body << " " << builderOpState << ".addOperands(operands);\n";
3050
3051 // Properties
3052 if (kind == CollectiveBuilderKind::PropStruct)
3053 body << " " << builderOpState
3054 << ".useProperties(const_cast<Properties&>(properties));\n";
3055
3056 // Attributes
3057 body << " " << builderOpState << ".addAttributes(" << attributesName
3058 << ");\n";
3059
3060 // Result types
3061 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
3062 body << " " << builderOpState << ".addTypes({"
3063 << llvm::join(R&: resultTypes, Separator: ", ") << "});\n";
3064}
3065
3066/// Returns a signature of the builder. Updates the context `fctx` to enable
3067/// replacement of $_builder and $_state in the body.
3068static SmallVector<MethodParameter>
3069getBuilderSignature(const Builder &builder) {
3070 ArrayRef<Builder::Parameter> params(builder.getParameters());
3071
3072 // Inject builder and state arguments.
3073 SmallVector<MethodParameter> arguments;
3074 arguments.reserve(N: params.size() + 2);
3075 arguments.emplace_back(Args: "::mlir::OpBuilder &", Args: odsBuilder);
3076 arguments.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
3077
3078 FmtContext fctx;
3079 fctx.withBuilder(subst: odsBuilder);
3080
3081 for (unsigned i = 0, e = params.size(); i < e; ++i) {
3082 // If no name is provided, generate one.
3083 std::optional<StringRef> paramName = params[i].getName();
3084 std::string name =
3085 paramName ? paramName->str() : "odsArg" + std::to_string(val: i);
3086
3087 StringRef defaultValue;
3088 if (std::optional<StringRef> defaultParamValue =
3089 params[i].getDefaultValue())
3090 defaultValue = *defaultParamValue;
3091
3092 arguments.emplace_back(Args: params[i].getCppType(), Args: std::move(name),
3093 Args: tgfmt(fmt: defaultValue, ctx: &fctx));
3094 }
3095
3096 return arguments;
3097}
3098
3099void OpEmitter::genBuilder() {
3100 // Handle custom builders if provided.
3101 for (const Builder &builder : op.getBuilders()) {
3102 SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
3103
3104 std::optional<StringRef> body = builder.getBody();
3105 auto properties = body ? Method::Static : Method::StaticDeclaration;
3106 auto *method = opClass.addMethod(retType: "void", name: "build", properties, args&: arguments);
3107 if (body)
3108 ERROR_IF_PRUNED(method, "build", op);
3109
3110 if (method)
3111 method->setDeprecated(builder.getDeprecatedMessage());
3112
3113 FmtContext fctx;
3114 fctx.withBuilder(subst: odsBuilder);
3115 fctx.addSubst(placeholder: "_state", subst: builderOpState);
3116 if (body)
3117 method->body() << tgfmt(fmt: *body, ctx: &fctx);
3118 genInlineCreateBody(paramList: arguments);
3119 }
3120
3121 // Generate default builders that requires all result type, operands, and
3122 // attributes as parameters.
3123 if (op.skipDefaultBuilders())
3124 return;
3125
3126 // We generate three classes of builders here:
3127 // 1. one having a stand-alone parameter for each operand / attribute, and
3128 genSeparateArgParamBuilder();
3129 // 2. one having an aggregated parameter for all result types / operands /
3130 // [properties / discardable] attributes, and
3131 genCollectiveParamBuilder(kind: CollectiveBuilderKind::AttrDict);
3132 if (emitHelper.hasProperties())
3133 genCollectiveParamBuilder(kind: CollectiveBuilderKind::PropStruct);
3134 // 3. one having a stand-alone parameter for each operand and attribute,
3135 // use the first operand or attribute's type as all result types
3136 // to facilitate different call patterns.
3137 if (op.getNumVariableLengthResults() == 0) {
3138 if (op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) {
3139 genUseOperandAsResultTypeSeparateParamBuilder();
3140 genUseOperandAsResultTypeCollectiveParamBuilder(
3141 kind: CollectiveBuilderKind::AttrDict);
3142 if (emitHelper.hasProperties())
3143 genUseOperandAsResultTypeCollectiveParamBuilder(
3144 kind: CollectiveBuilderKind::PropStruct);
3145 }
3146 if (op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType")) {
3147 genUseAttrAsResultTypeCollectiveParamBuilder(
3148 kind: CollectiveBuilderKind::AttrDict);
3149 genUseAttrAsResultTypeCollectiveParamBuilder(
3150 kind: CollectiveBuilderKind::PropStruct);
3151 }
3152 }
3153}
3154
3155void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
3156 int numResults = op.getNumResults();
3157 int numVariadicResults = op.getNumVariableLengthResults();
3158 int numNonVariadicResults = numResults - numVariadicResults;
3159
3160 int numOperands = op.getNumOperands();
3161 int numVariadicOperands = op.getNumVariableLengthOperands();
3162 int numNonVariadicOperands = numOperands - numVariadicOperands;
3163
3164 SmallVector<MethodParameter> paramList;
3165 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "");
3166 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
3167 paramList.emplace_back(Args: "::mlir::TypeRange", Args: "resultTypes");
3168 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
3169 if (kind == CollectiveBuilderKind::PropStruct)
3170 paramList.emplace_back(Args: "const Properties &", Args: "properties");
3171 // Provide default value for `attributes` when its the last parameter
3172 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
3173 StringRef attributesName = kind == CollectiveBuilderKind::PropStruct
3174 ? "discardableAttributes"
3175 : "attributes";
3176 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
3177 Args&: attributesName, Args&: attributesDefaultValue);
3178 if (op.getNumVariadicRegions())
3179 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
3180
3181 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args&: paramList);
3182 // If the builder is redundant, skip generating the method
3183 if (!m)
3184 return;
3185 genInlineCreateBody(paramList);
3186 auto &body = m->body();
3187
3188 // Operands
3189 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
3190 body << " assert(operands.size()"
3191 << (numVariadicOperands != 0 ? " >= " : " == ")
3192 << numNonVariadicOperands
3193 << "u && \"mismatched number of parameters\");\n";
3194 body << " " << builderOpState << ".addOperands(operands);\n";
3195
3196 // Properties
3197 if (kind == CollectiveBuilderKind::PropStruct)
3198 body << " " << builderOpState
3199 << ".useProperties(const_cast<Properties&>(properties));\n";
3200
3201 // Attributes
3202 body << " " << builderOpState << ".addAttributes(" << attributesName
3203 << ");\n";
3204
3205 // Create the correct number of regions
3206 if (int numRegions = op.getNumRegions()) {
3207 body << llvm::formatv(
3208 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
3209 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
3210 body << " (void)" << builderOpState << ".addRegion();\n";
3211 }
3212
3213 // Result types
3214 if (numVariadicResults == 0 || numNonVariadicResults != 0)
3215 body << " assert(resultTypes.size()"
3216 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
3217 << "u && \"mismatched number of return types\");\n";
3218 body << " " << builderOpState << ".addTypes(resultTypes);\n";
3219
3220 if (emitHelper.hasNonEmptyPropertiesStruct() &&
3221 kind == CollectiveBuilderKind::AttrDict) {
3222 // Initialize the properties from Attributes before invoking the infer
3223 // function.
3224 body << formatv(Fmt: R"(
3225 if (!attributes.empty()) {
3226 ::mlir::OpaqueProperties properties =
3227 &{1}.getOrAddProperties<{0}::Properties>();
3228 std::optional<::mlir::RegisteredOperationName> info =
3229 {1}.name.getRegisteredInfo();
3230 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
3231 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
3232 ::llvm::report_fatal_error("Property conversion failed.");
3233 })",
3234 Vals: opClass.getClassName(), Vals: builderOpState);
3235 }
3236
3237 // Generate builder that infers type too.
3238 // TODO: Expand to handle successors.
3239 if (canInferType(op) && op.getNumSuccessors() == 0)
3240 genInferredTypeCollectiveParamBuilder(kind);
3241}
3242
3243void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
3244 llvm::StringSet<> &inferredAttributes,
3245 SmallVectorImpl<std::string> &resultTypeNames,
3246 TypeParamKind typeParamKind,
3247 AttrParamKind attrParamKind) {
3248 resultTypeNames.clear();
3249 auto numResults = op.getNumResults();
3250 resultTypeNames.reserve(N: numResults);
3251
3252 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: odsBuilder);
3253 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
3254
3255 switch (typeParamKind) {
3256 case TypeParamKind::None:
3257 break;
3258 case TypeParamKind::Separate: {
3259 // Add parameters for all return types
3260 for (int i = 0; i < numResults; ++i) {
3261 const auto &result = op.getResult(index: i);
3262 std::string resultName = std::string(result.name);
3263 if (resultName.empty())
3264 resultName = std::string(formatv(Fmt: "resultType{0}", Vals&: i));
3265
3266 StringRef type =
3267 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
3268
3269 paramList.emplace_back(Args&: type, Args&: resultName, Args: result.isOptional());
3270 resultTypeNames.emplace_back(Args: std::move(resultName));
3271 }
3272 } break;
3273 case TypeParamKind::Collective: {
3274 paramList.emplace_back(Args: "::mlir::TypeRange", Args: "resultTypes");
3275 resultTypeNames.push_back(Elt: "resultTypes");
3276 } break;
3277 }
3278
3279 // Add parameters for all arguments (operands and attributes).
3280 // Track "attr-like" (property and attribute) optional values separate from
3281 // attributes themselves so that the disambiguation code can look at the first
3282 // attribute specifically when determining where to trim the optional-value
3283 // list to avoid ambiguity while preserving the ability of all-property ops to
3284 // use default parameters.
3285 int defaultValuedAttrLikeStartIndex = op.getNumArgs();
3286 int defaultValuedAttrStartIndex = op.getNumArgs();
3287 // Successors and variadic regions go at the end of the parameter list, so no
3288 // default arguments are possible.
3289 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
3290 if (!hasTrailingParams) {
3291 // Calculate the start index from which we can attach default values in the
3292 // builder declaration.
3293 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
3294 auto *namedAttr =
3295 llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(Val: op.getArg(index: i));
3296 auto *namedProperty =
3297 llvm::dyn_cast_if_present<tblgen::NamedProperty *>(Val: op.getArg(index: i));
3298 if (namedProperty) {
3299 Property prop = namedProperty->prop;
3300 if (!prop.hasDefaultValue())
3301 break;
3302 defaultValuedAttrLikeStartIndex = i;
3303 continue;
3304 }
3305 if (!namedAttr)
3306 break;
3307
3308 Attribute attr = namedAttr->attr;
3309 // TODO: Currently we can't differentiate between optional meaning do not
3310 // verify/not always error if missing or optional meaning need not be
3311 // specified in builder. Expand isOptional once we can differentiate.
3312 if (!attr.hasDefaultValue() && !attr.isDerivedAttr())
3313 break;
3314
3315 // Creating an APInt requires us to provide bitwidth, value, and
3316 // signedness, which is complicated compared to others. Similarly
3317 // for APFloat.
3318 // TODO: Adjust the 'returnType' field of such attributes
3319 // to support them.
3320 StringRef retType = namedAttr->attr.getReturnType();
3321 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
3322 break;
3323
3324 defaultValuedAttrLikeStartIndex = i;
3325 defaultValuedAttrStartIndex = i;
3326 }
3327 }
3328
3329 // Check if parameters besides default valued one are enough to distinguish
3330 // between builders with wrapped and unwrapped arguments.
3331 bool hasBuilderAmbiguity = true;
3332 for (const auto &arg : op.getArgs()) {
3333 auto *namedAttr = dyn_cast<NamedAttribute *>(Val: arg);
3334 if (!namedAttr)
3335 continue;
3336 Attribute attr = namedAttr->attr;
3337 if (attr.hasDefaultValue() || attr.isDerivedAttr())
3338 continue;
3339
3340 if (attrParamKind != AttrParamKind::WrappedAttr ||
3341 !canUseUnwrappedRawValue(attr))
3342 continue;
3343
3344 hasBuilderAmbiguity = false;
3345 break;
3346 }
3347
3348 // Avoid generating build methods that are ambiguous due to default values by
3349 // requiring at least one attribute.
3350 if (defaultValuedAttrStartIndex < op.getNumArgs()) {
3351 // TODO: This should have been possible as a cast<NamedAttribute> but
3352 // required template instantiations is not yet defined for the tblgen helper
3353 // classes.
3354 auto *namedAttr =
3355 cast<NamedAttribute *>(Val: op.getArg(index: defaultValuedAttrStartIndex));
3356 Attribute attr = namedAttr->attr;
3357 if ((attrParamKind == AttrParamKind::WrappedAttr &&
3358 canUseUnwrappedRawValue(attr) && hasBuilderAmbiguity) ||
3359 (attrParamKind == AttrParamKind::UnwrappedValue &&
3360 !canUseUnwrappedRawValue(attr) && hasBuilderAmbiguity)) {
3361 ++defaultValuedAttrStartIndex;
3362 defaultValuedAttrLikeStartIndex = defaultValuedAttrStartIndex;
3363 }
3364 }
3365
3366 /// Collect any inferred attributes.
3367 for (const NamedTypeConstraint &operand : op.getOperands()) {
3368 if (operand.isVariadicOfVariadic()) {
3369 inferredAttributes.insert(
3370 key: operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
3371 }
3372 }
3373
3374 FmtContext fctx;
3375 fctx.withBuilder(subst: odsBuilder);
3376
3377 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
3378 Argument arg = op.getArg(index: i);
3379 if (const auto *operand =
3380 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: arg)) {
3381 StringRef type;
3382 if (operand->isVariadicOfVariadic())
3383 type = "::llvm::ArrayRef<::mlir::ValueRange>";
3384 else if (operand->isVariadic())
3385 type = "::mlir::ValueRange";
3386 else
3387 type = "::mlir::Value";
3388
3389 paramList.emplace_back(Args&: type, Args: getArgumentName(op, index: numOperands++),
3390 Args: operand->isOptional());
3391 continue;
3392 }
3393 if (auto *propArg = llvm::dyn_cast_if_present<NamedProperty *>(Val&: arg)) {
3394 const Property &prop = propArg->prop;
3395 StringRef type = prop.getInterfaceType();
3396 std::string defaultValue;
3397 if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) {
3398 defaultValue = tgfmt(fmt: prop.getDefaultValue(), ctx: &fctx);
3399 }
3400 bool isOptional = prop.hasDefaultValue();
3401 paramList.emplace_back(Args&: type, Args&: propArg->name, Args: StringRef(defaultValue),
3402 Args&: isOptional);
3403 continue;
3404 }
3405 const NamedAttribute &namedAttr = *cast<NamedAttribute *>(Val&: arg);
3406 const Attribute &attr = namedAttr.attr;
3407
3408 // Inferred attributes don't need to be added to the param list.
3409 if (inferredAttributes.contains(key: namedAttr.name))
3410 continue;
3411
3412 StringRef type;
3413 switch (attrParamKind) {
3414 case AttrParamKind::WrappedAttr:
3415 type = attr.getStorageType();
3416 break;
3417 case AttrParamKind::UnwrappedValue:
3418 if (canUseUnwrappedRawValue(attr))
3419 type = attr.getReturnType();
3420 else
3421 type = attr.getStorageType();
3422 break;
3423 }
3424
3425 // Attach default value if requested and possible.
3426 std::string defaultValue;
3427 if (i >= defaultValuedAttrStartIndex) {
3428 if (attrParamKind == AttrParamKind::UnwrappedValue &&
3429 canUseUnwrappedRawValue(attr))
3430 defaultValue += tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx);
3431 else
3432 defaultValue += "nullptr";
3433 }
3434 paramList.emplace_back(Args&: type, Args: namedAttr.name, Args: StringRef(defaultValue),
3435 Args: attr.isOptional());
3436 }
3437
3438 /// Insert parameters for each successor.
3439 for (const NamedSuccessor &succ : op.getSuccessors()) {
3440 StringRef type =
3441 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
3442 paramList.emplace_back(Args&: type, Args: succ.name);
3443 }
3444
3445 /// Insert parameters for variadic regions.
3446 for (const NamedRegion &region : op.getRegions())
3447 if (region.isVariadic())
3448 paramList.emplace_back(Args: "unsigned",
3449 Args: llvm::formatv(Fmt: "{0}Count", Vals: region.name).str());
3450}
3451
3452void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
3453 MethodBody &body, llvm::StringSet<> &inferredAttributes,
3454 bool isRawValueAttr) {
3455 // Push all operands to the result.
3456 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
3457 std::string argName = getArgumentName(op, index: i);
3458 const NamedTypeConstraint &operand = op.getOperand(index: i);
3459 if (operand.constraint.isVariadicOfVariadic()) {
3460 body << " for (::mlir::ValueRange range : " << argName << ")\n "
3461 << builderOpState << ".addOperands(range);\n";
3462
3463 // Add the segment attribute.
3464 body << " {\n"
3465 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
3466 << " for (::mlir::ValueRange range : " << argName << ")\n"
3467 << " rangeSegments.push_back(range.size());\n"
3468 << " auto rangeAttr = " << odsBuilder
3469 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3470 if (op.getDialect().usePropertiesForAttributes()) {
3471 body << " " << builderOpState << ".getOrAddProperties<Properties>()."
3472 << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
3473 << " = rangeAttr;";
3474 } else {
3475 body << " " << builderOpState << ".addAttribute("
3476 << op.getGetterName(
3477 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
3478 << "AttrName(" << builderOpState << ".name), rangeAttr);";
3479 }
3480 body << " }\n";
3481 continue;
3482 }
3483
3484 if (operand.isOptional())
3485 body << " if (" << argName << ")\n ";
3486 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
3487 }
3488
3489 // If the operation has the operand segment size attribute, add it here.
3490 auto emitSegment = [&]() {
3491 interleaveComma(c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) {
3492 const NamedTypeConstraint &operand = op.getOperand(index: i);
3493 if (!operand.isVariableLength()) {
3494 body << "1";
3495 return;
3496 }
3497
3498 std::string operandName = getArgumentName(op, index: i);
3499 if (operand.isOptional()) {
3500 body << "(" << operandName << " ? 1 : 0)";
3501 } else if (operand.isVariadicOfVariadic()) {
3502 body << llvm::formatv(
3503 Fmt: "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3504 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3505 "static_cast<int32_t>(range.size()); }))",
3506 Vals&: operandName);
3507 } else {
3508 body << "static_cast<int32_t>(" << getArgumentName(op, index: i) << ".size())";
3509 }
3510 });
3511 };
3512 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
3513 std::string sizes = op.getGetterName(name: operandSegmentAttrName);
3514 if (op.getDialect().usePropertiesForAttributes()) {
3515 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3516 emitSegment();
3517 body << "}), " << builderOpState
3518 << ".getOrAddProperties<Properties>()."
3519 "operandSegmentSizes.begin());\n";
3520 } else {
3521 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
3522 << builderOpState << ".name), "
3523 << "odsBuilder.getDenseI32ArrayAttr({";
3524 emitSegment();
3525 body << "}));\n";
3526 }
3527 }
3528
3529 // Push all properties to the result.
3530 for (const auto &namedProp : op.getProperties()) {
3531 // Use the setter from the Properties struct since the conversion from the
3532 // interface type (used in the builder argument) to the storage type (used
3533 // in the state) is not necessarily trivial.
3534 std::string setterName = op.getSetterName(name: namedProp.name);
3535 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1}({2});\n",
3536 Vals: builderOpState, Vals&: setterName, Vals: namedProp.name);
3537 }
3538 // Push all attributes to the result.
3539 for (const auto &namedAttr : op.getAttributes()) {
3540 auto &attr = namedAttr.attr;
3541 if (attr.isDerivedAttr() || inferredAttributes.contains(key: namedAttr.name))
3542 continue;
3543
3544 // TODO: The wrapping of optional is different for default or not, so don't
3545 // unwrap for default ones that would fail below.
3546 bool emitNotNullCheck =
3547 (attr.isOptional() && !attr.hasDefaultValue()) ||
3548 (attr.hasDefaultValue() && !isRawValueAttr) ||
3549 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3550 // the constant materialization is only for true case.
3551 (isRawValueAttr && attr.getAttrDefName() == "UnitAttr");
3552 if (emitNotNullCheck)
3553 body.indent() << formatv(Fmt: "if ({0}) ", Vals: namedAttr.name) << "{\n";
3554
3555 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
3556 // If this is a raw value, then we need to wrap it in an Attribute
3557 // instance.
3558 FmtContext fctx;
3559 fctx.withBuilder(subst: "odsBuilder");
3560 if (op.getDialect().usePropertiesForAttributes()) {
3561 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3562 Vals: builderOpState, Vals: namedAttr.name,
3563 Vals: constBuildAttrFromParam(attr, fctx, paramName: namedAttr.name));
3564 } else {
3565 body << formatv(Fmt: " {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3566 Vals: builderOpState, Vals: op.getGetterName(name: namedAttr.name),
3567 Vals: constBuildAttrFromParam(attr, fctx, paramName: namedAttr.name));
3568 }
3569 } else {
3570 if (op.getDialect().usePropertiesForAttributes()) {
3571 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3572 Vals: builderOpState, Vals: namedAttr.name);
3573 } else {
3574 body << formatv(Fmt: " {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3575 Vals: builderOpState, Vals: op.getGetterName(name: namedAttr.name),
3576 Vals: namedAttr.name);
3577 }
3578 }
3579 if (emitNotNullCheck)
3580 body.unindent() << " }\n";
3581 }
3582
3583 // Create the correct number of regions.
3584 for (const NamedRegion &region : op.getRegions()) {
3585 if (region.isVariadic())
3586 body << formatv(Fmt: " for (unsigned i = 0; i < {0}Count; ++i)\n ",
3587 Vals: region.name);
3588
3589 body << " (void)" << builderOpState << ".addRegion();\n";
3590 }
3591
3592 // Push all successors to the result.
3593 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
3594 body << formatv(Fmt: " {0}.addSuccessors({1});\n", Vals: builderOpState,
3595 Vals: namedSuccessor.name);
3596 }
3597}
3598
3599void OpEmitter::genCanonicalizerDecls() {
3600 bool hasCanonicalizeMethod = def.getValueAsBit(FieldName: "hasCanonicalizeMethod");
3601 if (hasCanonicalizeMethod) {
3602 // static LogicResult FooOp::
3603 // canonicalize(FooOp op, PatternRewriter &rewriter);
3604 SmallVector<MethodParameter> paramList;
3605 paramList.emplace_back(Args: op.getCppClassName(), Args: "op");
3606 paramList.emplace_back(Args: "::mlir::PatternRewriter &", Args: "rewriter");
3607 auto *m = opClass.declareStaticMethod(retType: "::llvm::LogicalResult",
3608 name: "canonicalize", args: std::move(paramList));
3609 ERROR_IF_PRUNED(m, "canonicalize", op);
3610 }
3611
3612 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3613 // or if using a 'canonicalize' method.
3614 bool hasCanonicalizer = def.getValueAsBit(FieldName: "hasCanonicalizer");
3615 if (!hasCanonicalizeMethod && !hasCanonicalizer)
3616 return;
3617
3618 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3619 // method, but not implementing 'getCanonicalizationPatterns' manually.
3620 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
3621
3622 // Add a signature for getCanonicalizationPatterns if implemented by the
3623 // dialect or if synthesized to call 'canonicalize'.
3624 SmallVector<MethodParameter> paramList;
3625 paramList.emplace_back(Args: "::mlir::RewritePatternSet &", Args: "results");
3626 paramList.emplace_back(Args: "::mlir::MLIRContext *", Args: "context");
3627 auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
3628 auto *method = opClass.addMethod(retType: "void", name: "getCanonicalizationPatterns", properties: kind,
3629 args: std::move(paramList));
3630
3631 // If synthesizing the method, fill it.
3632 if (hasBody) {
3633 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
3634 method->body() << " results.add(canonicalize);\n";
3635 }
3636}
3637
3638void OpEmitter::genFolderDecls() {
3639 if (!op.hasFolder())
3640 return;
3641
3642 SmallVector<MethodParameter> paramList;
3643 paramList.emplace_back(Args: "FoldAdaptor", Args: "adaptor");
3644
3645 StringRef retType;
3646 bool hasSingleResult =
3647 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
3648 if (hasSingleResult) {
3649 retType = "::mlir::OpFoldResult";
3650 } else {
3651 paramList.emplace_back(Args: "::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3652 Args: "results");
3653 retType = "::llvm::LogicalResult";
3654 }
3655
3656 auto *m = opClass.declareMethod(retType, name: "fold", args: std::move(paramList));
3657 ERROR_IF_PRUNED(m, "fold", op);
3658}
3659
3660void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
3661 Interface interface = opTrait->getInterface();
3662
3663 // Get the set of methods that should always be declared.
3664 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
3665 llvm::StringSet<> alwaysDeclaredMethods;
3666 alwaysDeclaredMethods.insert_range(R&: alwaysDeclaredMethodsVec);
3667
3668 for (const InterfaceMethod &method : interface.getMethods()) {
3669 // Don't declare if the method has a body.
3670 if (method.getBody())
3671 continue;
3672 // Don't declare if the method has a default implementation and the op
3673 // didn't request that it always be declared.
3674 if (method.getDefaultImplementation() &&
3675 !alwaysDeclaredMethods.count(Key: method.getName()))
3676 continue;
3677 // Interface methods are allowed to overlap with existing methods, so don't
3678 // check if pruned.
3679 (void)genOpInterfaceMethod(method);
3680 }
3681}
3682
3683Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
3684 bool declaration) {
3685 SmallVector<MethodParameter> paramList;
3686 for (const InterfaceMethod::Argument &arg : method.getArguments())
3687 paramList.emplace_back(Args: arg.type, Args: arg.name);
3688
3689 auto props = (method.isStatic() ? Method::Static : Method::None) |
3690 (declaration ? Method::Declaration : Method::None);
3691 return opClass.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props,
3692 args: std::move(paramList));
3693}
3694
3695void OpEmitter::genOpInterfaceMethods() {
3696 for (const auto &trait : op.getTraits()) {
3697 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(Val: &trait))
3698 if (opTrait->shouldDeclareMethods())
3699 genOpInterfaceMethods(opTrait);
3700 }
3701}
3702
3703void OpEmitter::genSideEffectInterfaceMethods() {
3704 enum EffectKind { Operand, Result, Symbol, Static };
3705 struct EffectLocation {
3706 /// The effect applied.
3707 SideEffect effect;
3708
3709 /// The index if the kind is not static.
3710 unsigned index;
3711
3712 /// The kind of the location.
3713 unsigned kind;
3714 };
3715
3716 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
3717 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
3718 unsigned index, unsigned kind) {
3719 for (auto decorator : decorators)
3720 if (SideEffect *effect = dyn_cast<SideEffect>(Val: &decorator)) {
3721 opClass.addTrait(trait: effect->getInterfaceTrait());
3722 interfaceEffects[effect->getBaseEffectName()].push_back(
3723 Elt: EffectLocation{.effect: *effect, .index: index, .kind: kind});
3724 }
3725 };
3726
3727 // Collect effects that were specified via:
3728 /// Traits.
3729 for (const auto &trait : op.getTraits()) {
3730 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(Val: &trait);
3731 if (!opTrait)
3732 continue;
3733 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
3734 for (auto decorator : opTrait->getEffects())
3735 effects.push_back(Elt: EffectLocation{.effect: cast<SideEffect>(Val&: decorator),
3736 /*index=*/0, .kind: EffectKind::Static});
3737 }
3738 /// Attributes and Operands.
3739 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
3740 Argument arg = op.getArg(index: i);
3741 if (isa<NamedTypeConstraint *>(Val: arg)) {
3742 resolveDecorators(op.getArgDecorators(index: i), operandIt, EffectKind::Operand);
3743 ++operandIt;
3744 continue;
3745 }
3746 if (isa<NamedProperty *>(Val: arg))
3747 continue;
3748 const NamedAttribute *attr = cast<NamedAttribute *>(Val&: arg);
3749 if (attr->attr.getBaseAttr().isSymbolRefAttr())
3750 resolveDecorators(op.getArgDecorators(index: i), i, EffectKind::Symbol);
3751 }
3752 /// Results.
3753 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3754 resolveDecorators(op.getResultDecorators(index: i), i, EffectKind::Result);
3755
3756 // The code used to add an effect instance.
3757 // {0}: The effect class.
3758 // {1}: Optional value or symbol reference.
3759 // {2}: The side effect stage.
3760 // {3}: Does this side effect act on every single value of resource.
3761 // {4}: The resource class.
3762 const char *addEffectCode =
3763 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3764
3765 for (auto &it : interfaceEffects) {
3766 // Generate the 'getEffects' method.
3767 std::string type = llvm::formatv(Fmt: "::llvm::SmallVectorImpl<::mlir::"
3768 "SideEffects::EffectInstance<{0}>> &",
3769 Vals: it.first())
3770 .str();
3771 auto *getEffects = opClass.addMethod(retType: "void", name: "getEffects",
3772 args: MethodParameter(type, "effects"));
3773 ERROR_IF_PRUNED(getEffects, "getEffects", op);
3774 auto &body = getEffects->body();
3775
3776 // Add effect instances for each of the locations marked on the operation.
3777 for (auto &location : it.second) {
3778 StringRef effect = location.effect.getName();
3779 StringRef resource = location.effect.getResource();
3780 int stage = (int)location.effect.getStage();
3781 bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
3782 if (location.kind == EffectKind::Static) {
3783 // A static instance has no attached value.
3784 body << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "", Vals&: stage,
3785 Vals&: effectOnFullRegion, Vals&: resource)
3786 .str();
3787 } else if (location.kind == EffectKind::Symbol) {
3788 // A symbol reference requires adding the proper attribute.
3789 const auto *attr = cast<NamedAttribute *>(Val: op.getArg(index: location.index));
3790 std::string argName = op.getGetterName(name: attr->name);
3791 if (attr->attr.isOptional()) {
3792 body << " if (auto symbolRef = " << argName << "Attr())\n "
3793 << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "symbolRef, ", Vals&: stage,
3794 Vals&: effectOnFullRegion, Vals&: resource)
3795 .str();
3796 } else {
3797 body << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: argName + "Attr(), ",
3798 Vals&: stage, Vals&: effectOnFullRegion, Vals&: resource)
3799 .str();
3800 }
3801 } else {
3802 // Otherwise this is an operand/result, so we need to attach the Value.
3803 body << " {\n auto valueRange = getODS"
3804 << (location.kind == EffectKind::Operand ? "Operand" : "Result")
3805 << "IndexAndLength(" << location.index << ");\n"
3806 << " for (unsigned idx = valueRange.first; idx < "
3807 "valueRange.first"
3808 << " + valueRange.second; idx++) {\n "
3809 << llvm::formatv(Fmt: addEffectCode, Vals&: effect,
3810 Vals: (location.kind == EffectKind::Operand
3811 ? "&getOperation()->getOpOperand(idx), "
3812 : "getOperation()->getOpResult(idx), "),
3813 Vals&: stage, Vals&: effectOnFullRegion, Vals&: resource)
3814 << " }\n }\n";
3815 }
3816 }
3817 }
3818}
3819
3820void OpEmitter::genTypeInterfaceMethods() {
3821 if (!op.allResultTypesKnown())
3822 return;
3823 // Generate 'inferReturnTypes' method declaration using the interface method
3824 // declared in 'InferTypeOpInterface' op interface.
3825 const auto *trait =
3826 cast<InterfaceTrait>(Val: op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait"));
3827 Interface interface = trait->getInterface();
3828 Method *method = [&]() -> Method * {
3829 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
3830 if (interfaceMethod.getName() == "inferReturnTypes") {
3831 return genOpInterfaceMethod(method: interfaceMethod, /*declaration=*/false);
3832 }
3833 }
3834 assert(0 && "unable to find inferReturnTypes interface method");
3835 return nullptr;
3836 }();
3837 ERROR_IF_PRUNED(method, "inferReturnTypes", op);
3838 auto &body = method->body();
3839 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
3840
3841 FmtContext fctx;
3842 fctx.withBuilder(subst: "odsBuilder");
3843 fctx.addSubst(placeholder: "_ctxt", subst: "context");
3844 body << " ::mlir::Builder odsBuilder(context);\n";
3845
3846 // Preprocessing stage to verify all accesses to operands are valid.
3847 int maxAccessedIndex = -1;
3848 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3849 const InferredResultType &infer = op.getInferredResultType(index: i);
3850 if (!infer.isArg())
3851 continue;
3852 Operator::OperandOrAttribute arg =
3853 op.getArgToOperandOrAttribute(index: infer.getIndex());
3854 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3855 maxAccessedIndex =
3856 std::max(a: maxAccessedIndex, b: arg.operandOrAttributeIndex());
3857 }
3858 }
3859 if (maxAccessedIndex != -1) {
3860 body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
3861 body << " return ::mlir::failure();\n";
3862 }
3863
3864 // Process the type inference graph in topological order, starting from types
3865 // that are always fully-inferred: operands and results with constructible
3866 // types. The type inference graph here will always be a DAG, so this gives
3867 // us the correct order for generating the types. -1 is a placeholder to
3868 // indicate the type for a result has not been generated.
3869 SmallVector<int> constructedIndices(op.getNumResults(), -1);
3870 int inferredTypeIdx = 0;
3871 for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
3872 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3873 if (constructedIndices[i] >= 0)
3874 continue;
3875 const InferredResultType &infer = op.getInferredResultType(index: i);
3876 std::string typeStr;
3877 if (infer.isArg()) {
3878 // If this is an operand, just index into operand list to access the
3879 // type.
3880 Operator::OperandOrAttribute arg =
3881 op.getArgToOperandOrAttribute(index: infer.getIndex());
3882 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3883 typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
3884 "].getType()")
3885 .str();
3886
3887 // If this is an attribute, index into the attribute dictionary.
3888 } else {
3889 auto *attr =
3890 cast<NamedAttribute *>(Val: op.getArg(index: arg.operandOrAttributeIndex()));
3891 body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3892 << " = ";
3893 if (op.getDialect().usePropertiesForAttributes()) {
3894 body << "(properties ? properties.as<Properties *>()->"
3895 << attr->name
3896 << " : "
3897 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3898 "get(\"" +
3899 attr->name + "\")));\n";
3900 } else {
3901 body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3902 "get(\"" +
3903 attr->name + "\"));\n";
3904 }
3905 body << " if (!odsInferredTypeAttr" << inferredTypeIdx
3906 << ") return ::mlir::failure();\n";
3907 typeStr =
3908 ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
3909 .str();
3910 }
3911 } else if (std::optional<StringRef> builder =
3912 op.getResult(index: infer.getResultIndex())
3913 .constraint.getBuilderCall()) {
3914 typeStr = tgfmt(fmt: *builder, ctx: &fctx).str();
3915 } else if (int index = constructedIndices[infer.getResultIndex()];
3916 index >= 0) {
3917 typeStr = ("odsInferredType" + Twine(index)).str();
3918 } else {
3919 continue;
3920 }
3921 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
3922 << tgfmt(fmt: infer.getTransformer(), ctx: &fctx.withSelf(subst: typeStr)) << ";\n";
3923 constructedIndices[i] = inferredTypeIdx - 1;
3924 }
3925 }
3926 for (auto [i, index] : llvm::enumerate(First&: constructedIndices))
3927 body << " inferredReturnTypes[" << i << "] = odsInferredType" << index
3928 << ";\n";
3929 body << " return ::mlir::success();";
3930}
3931
3932void OpEmitter::genParser() {
3933 if (hasStringAttribute(record: def, fieldName: "assemblyFormat"))
3934 return;
3935
3936 if (!def.getValueAsBit(FieldName: "hasCustomAssemblyFormat"))
3937 return;
3938
3939 SmallVector<MethodParameter> paramList;
3940 paramList.emplace_back(Args: "::mlir::OpAsmParser &", Args: "parser");
3941 paramList.emplace_back(Args: "::mlir::OperationState &", Args: "result");
3942
3943 auto *method = opClass.declareStaticMethod(retType: "::mlir::ParseResult", name: "parse",
3944 args: std::move(paramList));
3945 ERROR_IF_PRUNED(method, "parse", op);
3946}
3947
3948void OpEmitter::genPrinter() {
3949 if (hasStringAttribute(record: def, fieldName: "assemblyFormat"))
3950 return;
3951
3952 // Check to see if this op uses a c++ format.
3953 if (!def.getValueAsBit(FieldName: "hasCustomAssemblyFormat"))
3954 return;
3955 auto *method = opClass.declareMethod(
3956 retType: "void", name: "print", args: MethodParameter("::mlir::OpAsmPrinter &", "p"));
3957 ERROR_IF_PRUNED(method, "print", op);
3958}
3959
3960void OpEmitter::genVerifier() {
3961 auto *implMethod =
3962 opClass.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariantsImpl");
3963 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
3964 auto &implBody = implMethod->body();
3965 bool useProperties = emitHelper.hasProperties();
3966
3967 populateSubstitutions(emitHelper, ctx&: verifyCtx);
3968 genPropertyVerifier(emitHelper, ctx&: verifyCtx, body&: implBody, staticVerifierEmitter);
3969 genAttributeVerifier(emitHelper, ctx&: verifyCtx, body&: implBody, staticVerifierEmitter,
3970 useProperties);
3971 genOperandResultVerifier(body&: implBody, values: op.getOperands(), valueKind: "operand");
3972 genOperandResultVerifier(body&: implBody, values: op.getResults(), valueKind: "result");
3973
3974 for (auto &trait : op.getTraits()) {
3975 if (auto *t = dyn_cast<tblgen::PredTrait>(Val: &trait)) {
3976 implBody << tgfmt(fmt: " if (!($0))\n "
3977 "return emitOpError(\"failed to verify that $1\");\n",
3978 ctx: &verifyCtx, vals: tgfmt(fmt: t->getPredTemplate(), ctx: &verifyCtx),
3979 vals: t->getSummary());
3980 }
3981 }
3982
3983 genRegionVerifier(body&: implBody);
3984 genSuccessorVerifier(body&: implBody);
3985
3986 implBody << " return ::mlir::success();\n";
3987
3988 // TODO: Some places use the `verifyInvariants` to do operation verification.
3989 // This may not act as their expectation because this doesn't call any
3990 // verifiers of native/interface traits. Needs to review those use cases and
3991 // see if we should use the mlir::verify() instead.
3992 auto *method = opClass.addMethod(retType: "::llvm::LogicalResult", name: "verifyInvariants");
3993 ERROR_IF_PRUNED(method, "verifyInvariants", op);
3994 auto &body = method->body();
3995 if (def.getValueAsBit(FieldName: "hasVerifier")) {
3996 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3997 "::mlir::succeeded(verify()))\n";
3998 body << " return ::mlir::success();\n";
3999 body << " return ::mlir::failure();";
4000 } else {
4001 body << " return verifyInvariantsImpl();";
4002 }
4003}
4004
4005void OpEmitter::genCustomVerifier() {
4006 if (def.getValueAsBit(FieldName: "hasVerifier")) {
4007 auto *method = opClass.declareMethod(retType: "::llvm::LogicalResult", name: "verify");
4008 ERROR_IF_PRUNED(method, "verify", op);
4009 }
4010
4011 if (def.getValueAsBit(FieldName: "hasRegionVerifier")) {
4012 auto *method =
4013 opClass.declareMethod(retType: "::llvm::LogicalResult", name: "verifyRegions");
4014 ERROR_IF_PRUNED(method, "verifyRegions", op);
4015 }
4016}
4017
4018void OpEmitter::genOperandResultVerifier(MethodBody &body,
4019 Operator::const_value_range values,
4020 StringRef valueKind) {
4021 // Check that an optional value is at most 1 element.
4022 //
4023 // {0}: Value index.
4024 // {1}: "operand" or "result"
4025 const char *const verifyOptional = R"(
4026 if (valueGroup{0}.size() > 1) {
4027 return emitOpError("{1} group starting at #") << index
4028 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
4029 }
4030)";
4031 // Check the types of a range of values.
4032 //
4033 // {0}: Value index.
4034 // {1}: Type constraint function.
4035 // {2}: "operand" or "result"
4036 const char *const verifyValues = R"(
4037 for (auto v : valueGroup{0}) {
4038 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
4039 return ::mlir::failure();
4040 }
4041)";
4042
4043 const auto canSkip = [](const NamedTypeConstraint &value) {
4044 return !value.hasPredicate() && !value.isOptional() &&
4045 !value.isVariadicOfVariadic();
4046 };
4047 if (values.empty() || llvm::all_of(Range&: values, P: canSkip))
4048 return;
4049
4050 FmtContext fctx;
4051
4052 body << " {\n unsigned index = 0; (void)index;\n";
4053
4054 for (const auto &staticValue : llvm::enumerate(First&: values)) {
4055 const NamedTypeConstraint &value = staticValue.value();
4056
4057 bool hasPredicate = value.hasPredicate();
4058 bool isOptional = value.isOptional();
4059 bool isVariadicOfVariadic = value.isVariadicOfVariadic();
4060 if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
4061 continue;
4062 body << formatv(Fmt: " auto valueGroup{2} = getODS{0}{1}s({2});\n",
4063 // Capitalize the first letter to match the function name
4064 Vals: valueKind.substr(Start: 0, N: 1).upper(), Vals: valueKind.substr(Start: 1),
4065 Vals: staticValue.index());
4066
4067 // If the constraint is optional check that the value group has at most 1
4068 // value.
4069 if (isOptional) {
4070 body << formatv(Fmt: verifyOptional, Vals: staticValue.index(), Vals&: valueKind);
4071 } else if (isVariadicOfVariadic) {
4072 body << formatv(
4073 Fmt: " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
4074 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
4075 " return ::mlir::failure();\n",
4076 Vals: value.constraint.getVariadicOfVariadicSegmentSizeAttr(), Vals: value.name,
4077 Vals: staticValue.index());
4078 }
4079
4080 // Otherwise, if there is no predicate there is nothing left to do.
4081 if (!hasPredicate)
4082 continue;
4083 // Emit a loop to check all the dynamic values in the pack.
4084 StringRef constraintFn =
4085 staticVerifierEmitter.getTypeConstraintFn(constraint: value.constraint);
4086 body << formatv(Fmt: verifyValues, Vals: staticValue.index(), Vals&: constraintFn, Vals&: valueKind);
4087 }
4088
4089 body << " }\n";
4090}
4091
4092void OpEmitter::genRegionVerifier(MethodBody &body) {
4093 /// Code to verify a region.
4094 ///
4095 /// {0}: Getter for the regions.
4096 /// {1}: The region constraint.
4097 /// {2}: The region's name.
4098 /// {3}: The region description.
4099 const char *const verifyRegion = R"(
4100 for (auto &region : {0})
4101 if (::mlir::failed({1}(*this, region, "{2}", index++)))
4102 return ::mlir::failure();
4103)";
4104 /// Get a single region.
4105 ///
4106 /// {0}: The region's index.
4107 const char *const getSingleRegion =
4108 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
4109
4110 // If we have no regions, there is nothing more to do.
4111 const auto canSkip = [](const NamedRegion &region) {
4112 return region.constraint.getPredicate().isNull();
4113 };
4114 auto regions = op.getRegions();
4115 if (regions.empty() && llvm::all_of(Range&: regions, P: canSkip))
4116 return;
4117
4118 body << " {\n unsigned index = 0; (void)index;\n";
4119 for (const auto &it : llvm::enumerate(First&: regions)) {
4120 const auto &region = it.value();
4121 if (canSkip(region))
4122 continue;
4123
4124 auto getRegion = region.isVariadic()
4125 ? formatv(Fmt: "{0}()", Vals: op.getGetterName(name: region.name)).str()
4126 : formatv(Fmt: getSingleRegion, Vals: it.index()).str();
4127 auto constraintFn =
4128 staticVerifierEmitter.getRegionConstraintFn(constraint: region.constraint);
4129 body << formatv(Fmt: verifyRegion, Vals&: getRegion, Vals&: constraintFn, Vals: region.name);
4130 }
4131 body << " }\n";
4132}
4133
4134void OpEmitter::genSuccessorVerifier(MethodBody &body) {
4135 const char *const verifySuccessor = R"(
4136 for (auto *successor : {0})
4137 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
4138 return ::mlir::failure();
4139)";
4140 /// Get a single successor.
4141 ///
4142 /// {0}: The successor's name.
4143 const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())";
4144
4145 // If we have no successors, there is nothing more to do.
4146 const auto canSkip = [](const NamedSuccessor &successor) {
4147 return successor.constraint.getPredicate().isNull();
4148 };
4149 auto successors = op.getSuccessors();
4150 if (successors.empty() && llvm::all_of(Range&: successors, P: canSkip))
4151 return;
4152
4153 body << " {\n unsigned index = 0; (void)index;\n";
4154
4155 for (auto it : llvm::enumerate(First&: successors)) {
4156 const auto &successor = it.value();
4157 if (canSkip(successor))
4158 continue;
4159
4160 auto getSuccessor =
4161 formatv(Fmt: successor.isVariadic() ? "{0}()" : getSingleSuccessor,
4162 Vals: successor.name)
4163 .str();
4164 auto constraintFn =
4165 staticVerifierEmitter.getSuccessorConstraintFn(constraint: successor.constraint);
4166 body << formatv(Fmt: verifySuccessor, Vals&: getSuccessor, Vals&: constraintFn,
4167 Vals: successor.name);
4168 }
4169 body << " }\n";
4170}
4171
4172/// Add a size count trait to the given operation class.
4173static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
4174 int numTotal, int numVariadic) {
4175 if (numVariadic != 0) {
4176 if (numTotal == numVariadic)
4177 opClass.addTrait(trait: "::mlir::OpTrait::Variadic" + traitKind + "s");
4178 else
4179 opClass.addTrait(trait: "::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
4180 Twine(numTotal - numVariadic) + ">::Impl");
4181 return;
4182 }
4183 switch (numTotal) {
4184 case 0:
4185 opClass.addTrait(trait: "::mlir::OpTrait::Zero" + traitKind + "s");
4186 break;
4187 case 1:
4188 opClass.addTrait(trait: "::mlir::OpTrait::One" + traitKind);
4189 break;
4190 default:
4191 opClass.addTrait(trait: "::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
4192 ">::Impl");
4193 break;
4194 }
4195}
4196
// Attaches all traits to the generated op class. The order of addTrait calls
// is deliberate:
//   1. Size-count traits for regions, results, successors and operands, plus
//      OneTypedResult for an op with a single non-variadic result.
//   2. Native traits marked as structural op traits.
//   3. OpInvariants.
//   4. BytecodeOpInterface::Trait, when the properties struct is non-empty.
//   5. The remaining native traits and the interface traits.
void OpEmitter::genTraits() {
 // Add region size trait.
 unsigned numRegions = op.getNumRegions();
 unsigned numVariadicRegions = op.getNumVariadicRegions();
 addSizeCountTrait(opClass, traitKind: "Region", numTotal: numRegions, numVariadic: numVariadicRegions);

 // Add result size traits.
 int numResults = op.getNumResults();
 int numVariadicResults = op.getNumVariableLengthResults();
 addSizeCountTrait(opClass, traitKind: "Result", numTotal: numResults, numVariadic: numVariadicResults);

 // For ops with exactly one non-variadic result, add a OneTypedResult trait
 // parameterized on the C++ type of that result's constraint.
 if (numResults == 1 && numVariadicResults == 0) {
 auto cppName = op.getResults().begin()->constraint.getCppType();
 opClass.addTrait(trait: "::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
 }

 // Add successor size trait.
 unsigned numSuccessors = op.getNumSuccessors();
 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
 addSizeCountTrait(opClass, traitKind: "Successor", numTotal: numSuccessors, numVariadic: numVariadicSuccessors);

 // Gather the operand counts (total and variable-length).
 int numOperands = op.getNumOperands();
 int numVariadicOperands = op.getNumVariableLengthOperands();

 // Add operand size trait.
 addSizeCountTrait(opClass, traitKind: "Operand", numTotal: numOperands, numVariadic: numVariadicOperands);

 // Add native traits marked as structural op traits first, ahead of
 // OpInvariants below, so that (per the comment on OpInvariants) they are
 // verified earlier.
 for (const auto &trait : op.getTraits()) {
 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
 if (opTrait->isStructuralOpTrait())
 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
 }
 }

 // OpInvariants wraps verifyInvariants, which needs to run before the
 // native/interface traits and after all traits marked `StructuralOpTrait`.
 opClass.addTrait(trait: "::mlir::OpTrait::OpInvariants");

 // Only ops with a non-empty properties struct get the bytecode interface.
 if (emitHelper.hasNonEmptyPropertiesStruct())
 opClass.addTrait(trait: "::mlir::BytecodeOpInterface::Trait");

 // Add the remaining (non-structural) native traits and all interface traits.
 for (const auto &trait : op.getTraits()) {
 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
 if (!opTrait->isStructuralOpTrait())
 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(Val: &trait)) {
 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
 }
 }
}
4253
4254void OpEmitter::genOpNameGetter() {
4255 auto *method = opClass.addStaticMethod<Method::Constexpr>(
4256 retType: "::llvm::StringLiteral", name: "getOperationName");
4257 ERROR_IF_PRUNED(method, "getOperationName", op);
4258 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
4259 << "\");";
4260}
4261
4262void OpEmitter::genOpAsmInterface() {
4263 // If the user only has one results or specifically added the Asm trait,
4264 // then don't generate it for them. We specifically only handle multi result
4265 // operations, because the name of a single result in the common case is not
4266 // interesting(generally 'result'/'output'/etc.).
4267 // TODO: We could also add a flag to allow operations to opt in to this
4268 // generation, even if they only have a single operation.
4269 int numResults = op.getNumResults();
4270 if (numResults <= 1 || op.getTrait(trait: "::mlir::OpAsmOpInterface::Trait"))
4271 return;
4272
4273 SmallVector<StringRef, 4> resultNames(numResults);
4274 for (int i = 0; i != numResults; ++i)
4275 resultNames[i] = op.getResultName(index: i);
4276
4277 // Don't add the trait if none of the results have a valid name.
4278 if (llvm::all_of(Range&: resultNames, P: [](StringRef name) { return name.empty(); }))
4279 return;
4280 opClass.addTrait(trait: "::mlir::OpAsmOpInterface::Trait");
4281
4282 // Generate the right accessor for the number of results.
4283 auto *method = opClass.addMethod(
4284 retType: "void", name: "getAsmResultNames",
4285 args: MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
4286 ERROR_IF_PRUNED(method, "getAsmResultNames", op);
4287 auto &body = method->body();
4288 for (int i = 0; i != numResults; ++i) {
4289 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
4290 << " if (!resultGroup" << i << ".empty())\n"
4291 << " setNameFn(*resultGroup" << i << ".begin(), \""
4292 << resultNames[i] << "\");\n";
4293 }
4294}
4295
4296//===----------------------------------------------------------------------===//
4297// OpOperandAdaptor emitter
4298//===----------------------------------------------------------------------===//
4299
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around random access ranges that provide named operand
// getters identical to those defined in the Op.
// This currently generates 3 classes per Op:
// * A Base class ('{GenericAdaptorName}Base') within the 'detail' namespace,
//   which contains all logic and members independent of the random access
//   range that is indexed into. In other words, it contains the properties,
//   attribute and region getters.
// * A templated class named '{OpName}GenericAdaptor' with a template parameter
//   'RangeT' that is indexed into by the getters to access the operands.
//   It contains all getters to access operands and inherits from the previous
//   class.
// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
//   with 'mlir::ValueRange' as template parameter. It adds a constructor from
//   an instance of the op type and a verify function.
class OpOperandAdaptorEmitter {
public:
 // Entry points that emit the adaptor classes for `op` to `os`. Presumably
 // emitDecl writes the declarations and emitDef the definitions; their
 // bodies are not shown here.
 static void
 emitDecl(const Operator &op,
 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
 raw_ostream &os);
 static void
 emitDef(const Operator &op,
 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
 raw_ostream &os);

private:
 // Builds and finalizes the three adaptor classes for `op`.
 explicit OpOperandAdaptorEmitter(
 const Operator &op,
 const StaticVerifierFunctionEmitter &staticVerifierEmitter);

 // Add verification function. This generates a verify method for the adaptor
 // which verifies all the op-independent attribute constraints.
 void addVerification();

 // NOTE: members are initialized in declaration order. The constructor's
 // initializer list follows this order, so keep the two in sync.

 // The operation for which to emit an adaptor.
 const Operator &op;

 // The generated adaptor classes.
 Class genericAdaptorBase;
 Class genericAdaptor;
 Class adaptor;

 // The emitter containing all of the locally emitted verification functions.
 const StaticVerifierFunctionEmitter &staticVerifierEmitter;

 // Helper for emitting adaptor code (constructed with emitForOp=false).
 OpOrAdaptorHelper emitHelper;
};
} // namespace
4350
// Builds the three adaptor classes for `op`:
// - `genericAdaptorBase` ('<GenericAdaptorName>Base', referenced as
//   'detail::...'): the optional `Properties` struct, the attrs/opName/
//   properties/regions fields, its constructors, and the property,
//   attribute and region getters.
// - `genericAdaptor` ('<GenericAdaptorName>', templated on `RangeT`): the
//   `odsOperands` range, its constructors and the named operand getters.
// - `adaptor` ('<AdaptorName>'): the `::mlir::ValueRange` instantiation, with a
//   constructor from the op and a verify method (see addVerification()).
// All three classes are finalized before returning.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
 const Operator &op,
 const StaticVerifierFunctionEmitter &staticVerifierEmitter)
 : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
 genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
 staticVerifierEmitter(staticVerifierEmitter),
 emitHelper(op, /*emitForOp=*/false) {

 // Substitution context for default values. `$_builder` is bound to
 // `odsBuilder` here and rebound further below, before the attribute getters
 // are emitted.
 FmtContext fctx;
 fctx.withBuilder(subst: odsBuilder);

 genericAdaptorBase.declare<VisibilityDeclaration>(args: Visibility::Public);
 bool useProperties = emitHelper.hasProperties();
 if (useProperties) {
 // Define the properties struct with multiple members.
 // It gets one member per non-derived attribute, one per property, and one
 // each for the operand/result segment sizes when present. The struct also
 // gets accessors and ==/!= operators.
 using ConstArgument =
 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
 SmallVector<ConstArgument> attrOrProperties;
 for (const std::pair<StringRef, AttributeMetadata> &it :
 emitHelper.getAttrMetadata()) {
 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
 attrOrProperties.push_back(Elt: &it.second);
 }
 for (const NamedProperty &prop : op.getProperties())
 attrOrProperties.push_back(Elt: &prop);
 if (emitHelper.getOperandSegmentsSize())
 attrOrProperties.push_back(Elt: &emitHelper.getOperandSegmentsSize().value());
 if (emitHelper.getResultSegmentsSize())
 attrOrProperties.push_back(Elt: &emitHelper.getResultSegmentsSize().value());
 // `declarations` accumulates the struct text. `comparator` accumulates the
 // body of operator==, which is a chain of `rhs.x == this->x &&` terms.
 std::string declarations = " struct Properties {\n";
 llvm::raw_string_ostream os(declarations);
 std::string comparator =
 " bool operator==(const Properties &rhs) const {\n"
 " return \n";
 llvm::raw_string_ostream comparatorOs(comparator);
 for (const auto &attrOrProp : attrOrProperties) {
 if (const auto *namedProperty =
 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
 StringRef name = namedProperty->name;
 if (name.empty())
 report_fatal_error(reason: "missing name for property");
 std::string camelName =
 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
 auto &prop = namedProperty->prop;
 // Generate the data member using the storage type.
 // The member declaration is left unterminated; the ';' that starts
 // `accessorFmt` below closes it after the optional initializer.
 os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
 << " " << name << "Ty " << name;
 if (prop.hasStorageTypeValueOverride())
 os << " = " << prop.getStorageTypeValueOverride();
 else if (prop.hasDefaultValue())
 os << " = " << tgfmt(fmt: prop.getDefaultValue(), ctx: &fctx);
 comparatorOs << " rhs." << name << " == this->" << name
 << " &&\n";
 // Emit accessors using the interface type.
 const char *accessorFmt = R"decl(;
 {0} get{1}() const {
 auto &propStorage = this->{2};
 return {3};
 }
 void set{1}({0} propValue) {
 auto &propStorage = this->{2};
 {4};
 }
)decl";
 // NOTE(review): this shadows the outer `fctx`, which carries the
 // `$_builder` substitution. It is presumably intentional, since these
 // accessors only substitute `_storage`/`_value`.
 FmtContext fctx;
 os << formatv(Fmt: accessorFmt, Vals: prop.getInterfaceType(), Vals&: camelName, Vals&: name,
 Vals: tgfmt(fmt: prop.getConvertFromStorageCall(),
 ctx: &fctx.addSubst(placeholder: "_storage", subst: propertyStorage)),
 Vals: tgfmt(fmt: prop.getAssignToStorageCall(),
 ctx: &fctx.addSubst(placeholder: "_value", subst: propertyValue)
 .addSubst(placeholder: "_storage", subst: propertyStorage)));
 continue;
 }
 // Otherwise the entry is attribute metadata. A missing constraint is only
 // accepted for the synthesized operand/result segment-size entries.
 const auto *namedAttr =
 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
 const Attribute *attr = nullptr;
 if (namedAttr->constraint)
 attr = &*namedAttr->constraint;
 StringRef name = namedAttr->attrName;
 if (name.empty())
 report_fatal_error(reason: "missing name for property attr");
 std::string camelName =
 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
 // Generate the data member using the storage type.
 StringRef storageType;
 if (attr) {
 storageType = attr->getStorageType();
 } else {
 if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
 report_fatal_error(reason: "unexpected AttributeMetadata");
 }
 // TODO: update to use native integers.
 storageType = "::mlir::DenseI32ArrayAttr";
 }
 os << " using " << name << "Ty = " << storageType << ";\n"
 << " " << name << "Ty " << name << ";\n";
 comparatorOs << " rhs." << name << " == this->" << name << " &&\n";

 // Emit accessors using the interface type.
 // Only constrained attributes get accessors. The getter uses
 // dyn_cast_or_null when the attribute may be absent (optional or
 // default-valued), and cast otherwise.
 if (attr) {
 const char *accessorFmt = R"decl(
 auto get{0}() const {
 auto &propStorage = this->{1};
 return ::llvm::{2}<{3}>(propStorage);
 }
 void set{0}(const {3} &propValue) {
 this->{1} = propValue;
 }
)decl";
 os << formatv(Fmt: accessorFmt, Vals&: camelName, Vals&: name,
 Vals: attr->isOptional() || attr->hasDefaultValue()
 ? "dyn_cast_or_null"
 : "cast",
 Vals&: storageType);
 }
 }
 // Terminate the `&&` chain with `true` and add operator!=.
 comparatorOs << " true;\n }\n"
 " bool operator!=(const Properties &rhs) const {\n"
 " return !(*this == rhs);\n"
 " }\n";
 os << comparator;
 os << " };\n";

 // With nothing to store, `Properties` aliases ::mlir::EmptyProperties and the
 // struct text built above is discarded.
 if (attrOrProperties.empty())
 genericAdaptorBase.declare<UsingDeclaration>(args: "Properties",
 args: "::mlir::EmptyProperties");
 else
 genericAdaptorBase.declare<ExtraClassDeclaration>(
 args: std::move(declarations));
 }
 // Protected state shared by every instantiation of the generic adaptor.
 genericAdaptorBase.declare<VisibilityDeclaration>(args: Visibility::Protected);
 genericAdaptorBase.declare<Field>(args: "::mlir::DictionaryAttr", args: "odsAttrs");
 genericAdaptorBase.declare<Field>(args: "::std::optional<::mlir::OperationName>",
 args: "odsOpName");
 if (useProperties)
 genericAdaptorBase.declare<Field>(args: "Properties", args: "properties");
 genericAdaptorBase.declare<Field>(args: "::mlir::RegionRange", args: "odsRegions");

 // The generic adaptor holds the operand range and derives from the base.
 genericAdaptor.addTemplateParam(param: "RangeT");
 genericAdaptor.addField(type: "RangeT", name: "odsOperands");
 genericAdaptor.addParent(
 parent: ParentClass("detail::" + genericAdaptorBase.getClassName()));
 genericAdaptor.declare<UsingDeclaration>(
 args: "ValueT", args: "::llvm::detail::ValueOfRange<RangeT>");
 genericAdaptor.declare<UsingDeclaration>(
 args: "Base", args: "detail::" + genericAdaptorBase.getClassName());

 const auto *attrSizedOperands =
 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
 // Constructors that take attributes, properties and regions explicitly.
 {
 SmallVector<MethodParameter> paramList;
 if (useProperties) {
 // Properties can't be given a default constructor here due to Properties
 // struct being defined in the enclosing class which isn't complete by
 // here.
 paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs");
 paramList.emplace_back(Args: "const Properties &", Args: "properties");
 } else {
 paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs", Args: "{}");
 paramList.emplace_back(Args: "const ::mlir::EmptyProperties &", Args: "properties",
 Args: "{}");
 }
 paramList.emplace_back(Args: "::mlir::RegionRange", Args: "regions", Args: "{}");
 auto *baseConstructor =
 genericAdaptorBase.addConstructor<Method::Inline>(args&: paramList);
 baseConstructor->addMemberInitializer(name: "odsAttrs", value: "attrs");
 if (useProperties)
 baseConstructor->addMemberInitializer(name: "properties", value: "properties");
 baseConstructor->addMemberInitializer(name: "odsRegions", value: "regions");

 // The op name is only recorded when an attribute dictionary is available,
 // since its context is needed to build the OperationName.
 MethodBody &body = baseConstructor->body();
 body.indent() << "if (odsAttrs)\n";
 body.indent() << formatv(
 Fmt: "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
 Vals: op.getOperationName());

 // Same parameters, prefixed with the operand range: forwards to Base.
 paramList.insert(I: paramList.begin(), Elt: MethodParameter("RangeT", "values"));
 auto *constructor = genericAdaptor.addConstructor(args&: paramList);
 constructor->addMemberInitializer(name: "Base", value: "attrs, properties, regions");
 constructor->addMemberInitializer(name: "odsOperands", value: "values");

 // Add a forwarding constructor to the previous one that accepts
 // OpaqueProperties instead and check for null and perform the cast to the
 // actual properties type.
 // `attrs` (index 1) is also reset without a default: C++ default arguments
 // must be trailing, and `properties` (index 2) no longer has one.
 paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
 paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
 auto *opaquePropertiesConstructor =
 genericAdaptor.addConstructor(args: std::move(paramList));
 if (useProperties) {
 opaquePropertiesConstructor->addMemberInitializer(
 name: genericAdaptor.getClassName(),
 value: "values, "
 "attrs, "
 "(properties ? *properties.as<Properties *>() : Properties{}), "
 "regions");
 } else {
 opaquePropertiesConstructor->addMemberInitializer(
 name: genericAdaptor.getClassName(),
 value: "values, "
 "attrs, "
 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
 "::mlir::EmptyProperties{}), "
 "regions");
 }

 // Add forwarding constructor that constructs Properties.
 // `attrs` defaults to nullptr only when the op lacks
 // AttrSizedOperandSegments. Presumably the segment sizes must be provided
 // to index operands; confirm.
 if (useProperties) {
 SmallVector<MethodParameter> paramList;
 paramList.emplace_back(Args: "RangeT", Args: "values");
 paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs",
 Args: attrSizedOperands ? "" : "nullptr");
 auto *noPropertiesConstructor =
 genericAdaptor.addConstructor(args: std::move(paramList));
 noPropertiesConstructor->addMemberInitializer(
 name: genericAdaptor.getClassName(), value: "values, "
 "attrs, "
 "Properties{}, "
 "{}");
 }
 }

 // Create a constructor that creates a new generic adaptor by copying
 // everything from another adaptor, except for the values.
 {
 SmallVector<MethodParameter> paramList;
 paramList.emplace_back(Args: "RangeT", Args: "values");
 paramList.emplace_back(Args: "const " + op.getGenericAdaptorName() + "Base &",
 Args: "base");
 auto *constructor =
 genericAdaptor.addConstructor<Method::Inline>(args&: paramList);
 constructor->addMemberInitializer(name: "Base", value: "base");
 constructor->addMemberInitializer(name: "odsOperands", value: "values");
 }

 // Create constructors constructing the adaptor from an instance of the op.
 // This takes the attributes, properties and regions from the op instance
 // and the value range from the parameter.
 {
 // Base class is in the cpp file and can simply access the members of the op
 // class to initialize the template independent fields. If the op doesn't
 // have properties, we can emit a generic constructor inline. Otherwise,
 // emit it out-of-line because we need the op to be defined.
 Constructor *constructor;
 if (useProperties) {
 constructor = genericAdaptorBase.addConstructor(
 args: MethodParameter(op.getCppClassName(), "op"));
 } else {
 constructor = genericAdaptorBase.addConstructor<Method::Inline>(
 args: MethodParameter("::mlir::Operation *", "op"));
 }
 constructor->addMemberInitializer(name: "odsAttrs",
 value: "op->getRawDictionaryAttrs()");
 // Retrieve the operation name from the op directly.
 constructor->addMemberInitializer(name: "odsOpName", value: "op->getName()");
 if (useProperties)
 constructor->addMemberInitializer(name: "properties", value: "op.getProperties()");
 constructor->addMemberInitializer(name: "odsRegions", value: "op->getRegions()");

 // Generic adaptor is templated and therefore defined inline in the header.
 // We cannot use the Op class here as it is an incomplete type (we have a
 // circular reference between the two).
 // Use a template trick to make the constructor be instantiated at call site
 // when the op class is complete.
 constructor = genericAdaptor.addConstructor(
 args: MethodParameter("RangeT", "values"), args: MethodParameter("LateInst", "op"));
 constructor->addTemplateParam(param: "LateInst = " + op.getCppClassName());
 constructor->addTemplateParam(
 param: "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
 ">>");
 constructor->addMemberInitializer(name: "Base", value: "op");
 constructor->addMemberInitializer(name: "odsOperands", value: "values");
 }

 // Code that reads the operand segment sizes, used by the operand getters
 // when the op has AttrSizedOperandSegments.
 // NOTE(review): this repeats the trait lookup already stored in
 // `attrSizedOperands` above.
 std::string sizeAttrInit;
 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
 if (op.getDialect().usePropertiesForAttributes())
 sizeAttrInit =
 formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
 Vals: llvm::formatv(Fmt: "getProperties().operandSegmentSizes"));
 else
 sizeAttrInit = formatv(Fmt: adapterSegmentSizeAttrInitCode,
 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName));
 }
 generateNamedOperandGetters(op, opClass&: genericAdaptor,
 /*genericAdaptorBase=*/&genericAdaptorBase,
 /*sizeAttrInit=*/sizeAttrInit,
 /*rangeType=*/"RangeT",
 /*rangeElementType=*/"ValueT",
 /*rangeBeginCall=*/"odsOperands.begin()",
 /*rangeSizeCall=*/"odsOperands.size()",
 /*getOperandCallPattern=*/"odsOperands[{0}]");

 // Any invalid overlap for `getOperands` will have been diagnosed before
 // here already.
 if (auto *m = genericAdaptor.addMethod(retType: "RangeT", name: "getOperands"))
 m->body() << " return odsOperands;";

 // From here on, `$_builder` expands to a Builder constructed from the
 // attribute dictionary's context.
 fctx.withBuilder(subst: "::mlir::Builder(odsAttrs.getContext())");

 // Generate named accessor with Attribute return type.
 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
 Attribute attr) {
 // The getter is inline only when properties are used and the attribute
 // has no default value. A default value may expand to arbitrary code, so
 // those getters (and all getters when properties are not used) are
 // emitted out-of-line.
 auto *method = genericAdaptorBase.addMethod(
 retType: attr.getStorageType(), name: emitName + "Attr",
 properties: attr.hasDefaultValue() || !useProperties ? Method::Properties::None
 : Method::Properties::Inline);
 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
 auto &body = method->body().indent();
 if (!useProperties)
 body << "assert(odsAttrs && \"no attributes when constructing "
 "adapter\");\n";
 body << formatv(
 Fmt: "auto attr = ::llvm::{1}<{2}>({0});\n", Vals: emitHelper.getAttr(attrName: name),
 Vals: attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null"
 : "cast",
 Vals: attr.getStorageType());

 if (attr.hasDefaultValue() && attr.isOptional()) {
 // Use the default value if attribute is not set.
 // TODO: this is inefficient, we are recreating the attribute for every
 // call. This should be set instead.
 std::string defaultValue =
 std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
 body << "if (!attr)\n attr = " << defaultValue << ";\n";
 }
 body << "return attr;\n";
 };

 if (useProperties) {
 auto *m = genericAdaptorBase.addInlineMethod(retType: "const Properties &",
 name: "getProperties");
 ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
 m->body() << " return properties;";
 }
 {
 auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::DictionaryAttr",
 name: "getAttributes");
 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
 m->body() << " return odsAttrs;";
 }
 // Property getters.
 for (auto &namedProp : op.getProperties()) {
 std::string name = op.getGetterName(name: namedProp.name);
 emitPropGetter(opClass&: genericAdaptorBase, op, name, prop: namedProp.prop);
 }

 // Attribute getters: `<getter>Attr()` returning the storage type, plus a
 // getter returning the attribute's return type. Derived attributes get
 // neither.
 for (auto &namedAttr : op.getAttributes()) {
 const auto &name = namedAttr.name;
 const auto &attr = namedAttr.attr;
 if (attr.isDerivedAttr())
 continue;
 std::string emitName = op.getGetterName(name);
 emitAttrWithStorageType(name, emitName, attr);
 emitAttrGetterWithReturnType(fctx, opClass&: genericAdaptorBase, op, name: emitName, attr);
 }

 // Named region getters. A variadic region returns the remaining regions
 // from its index onward.
 unsigned numRegions = op.getNumRegions();
 for (unsigned i = 0; i < numRegions; ++i) {
 const auto &region = op.getRegion(index: i);
 if (region.name.empty())
 continue;

 // Generate the accessors for a variadic region.
 std::string name = op.getGetterName(name: region.name);
 if (region.isVariadic()) {
 auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::RegionRange", name);
 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
 m->body() << formatv(Fmt: " return odsRegions.drop_front({0});", Vals&: i);
 continue;
 }

 auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::Region &", name);
 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
 m->body() << formatv(Fmt: " return *odsRegions[{0}];", Vals&: i);
 }
 if (numRegions > 0) {
 // Any invalid overlap for `getRegions` will have been diagnosed before
 // here already.
 if (auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::RegionRange",
 name: "getRegions"))
 m->body() << " return odsRegions;";
 }

 // The concrete adaptor is the generic adaptor instantiated on ValueRange.
 // It inherits the generic constructors through a using-declaration.
 StringRef genericAdaptorClassName = genericAdaptor.getClassName();
 adaptor.addParent(parent: ParentClass(genericAdaptorClassName))
 .addTemplateParam(param: "::mlir::ValueRange");
 adaptor.declare<VisibilityDeclaration>(args: Visibility::Public);
 adaptor.declare<UsingDeclaration>(args: genericAdaptorClassName +
 "::" + genericAdaptorClassName);
 {
 // Constructor taking the Op as single parameter.
 auto *constructor =
 adaptor.addConstructor(args: MethodParameter(op.getCppClassName(), "op"));
 constructor->addMemberInitializer(name&: genericAdaptorClassName,
 value: "op->getOperands(), op");
 }

 // Add verification function.
 addVerification();

 genericAdaptorBase.finalize();
 genericAdaptor.finalize();
 adaptor.finalize();
}
4757
4758void OpOperandAdaptorEmitter::addVerification() {
4759 auto *method = adaptor.addMethod(retType: "::llvm::LogicalResult", name: "verify",
4760 args: MethodParameter("::mlir::Location", "loc"));
4761 ERROR_IF_PRUNED(method, "verify", op);
4762 auto &body = method->body();
4763 bool useProperties = emitHelper.hasProperties();
4764
4765 FmtContext verifyCtx;
4766 populateSubstitutions(emitHelper, ctx&: verifyCtx);
4767 genAttributeVerifier(emitHelper, ctx&: verifyCtx, body, staticVerifierEmitter,
4768 useProperties);
4769
4770 body << " return ::mlir::success();";
4771}
4772
4773void OpOperandAdaptorEmitter::emitDecl(
4774 const Operator &op,
4775 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4776 raw_ostream &os) {
4777 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4778 {
4779 NamespaceEmitter ns(os, "detail");
4780 emitter.genericAdaptorBase.writeDeclTo(rawOs&: os);
4781 }
4782 emitter.genericAdaptor.writeDeclTo(rawOs&: os);
4783 emitter.adaptor.writeDeclTo(rawOs&: os);
4784}
4785
4786void OpOperandAdaptorEmitter::emitDef(
4787 const Operator &op,
4788 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4789 raw_ostream &os) {
4790 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4791 {
4792 NamespaceEmitter ns(os, "detail");
4793 emitter.genericAdaptorBase.writeDefTo(rawOs&: os);
4794 }
4795 emitter.genericAdaptor.writeDefTo(rawOs&: os);
4796 emitter.adaptor.writeDefTo(rawOs&: os);
4797}
4798
4799/// Emit the class declarations or definitions for the given op defs.
4800static void
4801emitOpClasses(const RecordKeeper &records,
4802 const std::vector<const Record *> &defs, raw_ostream &os,
4803 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4804 bool emitDecl) {
4805 if (defs.empty())
4806 return;
4807
4808 for (auto *def : defs) {
4809 Operator op(*def);
4810 if (emitDecl) {
4811 {
4812 NamespaceEmitter emitter(os, op.getCppNamespace());
4813 os << formatv(Fmt: opCommentHeader, Vals: op.getQualCppClassName(),
4814 Vals: "declarations");
4815 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
4816 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
4817 }
4818 // Emit the TypeID explicit specialization to have a single definition.
4819 if (!op.getCppNamespace().empty())
4820 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4821 << "::" << op.getCppClassName() << ")\n\n";
4822 } else {
4823 {
4824 NamespaceEmitter emitter(os, op.getCppNamespace());
4825 os << formatv(Fmt: opCommentHeader, Vals: op.getQualCppClassName(), Vals: "definitions");
4826 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
4827 OpEmitter::emitDef(op, os, staticVerifierEmitter);
4828 }
4829 // Emit the TypeID explicit specialization to have a single definition.
4830 if (!op.getCppNamespace().empty())
4831 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4832 << "::" << op.getCppClassName() << ")\n\n";
4833 }
4834 }
4835}
4836
4837/// Emit the declarations for the provided op classes.
4838static void emitOpClassDecls(const RecordKeeper &records,
4839 const std::vector<const Record *> &defs,
4840 raw_ostream &os) {
4841 // First emit forward declaration for each class, this allows them to refer
4842 // to each others in traits for example.
4843 for (auto *def : defs) {
4844 Operator op(*def);
4845 NamespaceEmitter emitter(os, op.getCppNamespace());
4846 std::string comments = tblgen::emitSummaryAndDescComments(
4847 summary: op.getSummary(), description: op.getDescription());
4848 if (!comments.empty()) {
4849 os << comments << "\n";
4850 }
4851 os << "class " << op.getCppClassName() << ";\n";
4852 }
4853
4854 // Emit the op class declarations.
4855 IfDefScope scope("GET_OP_CLASSES", os);
4856 if (defs.empty())
4857 return;
4858 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records);
4859 staticVerifierEmitter.collectOpConstraints(opDefs: defs);
4860 emitOpClasses(records, defs, os, staticVerifierEmitter,
4861 /*emitDecl=*/true);
4862}
4863
4864/// Emit the definitions for the provided op classes.
4865static void emitOpClassDefs(const RecordKeeper &records,
4866 ArrayRef<const Record *> defs, raw_ostream &os,
4867 StringRef constraintPrefix = "") {
4868 if (defs.empty())
4869 return;
4870
4871 // Generate all of the locally instantiated methods first.
4872 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records,
4873 constraintPrefix);
4874 os << formatv(Fmt: opCommentHeader, Vals: "Local Utility Method", Vals: "Definitions");
4875 staticVerifierEmitter.collectOpConstraints(opDefs: defs);
4876 staticVerifierEmitter.emitOpConstraints(opDefs: defs);
4877
4878 // Emit the classes.
4879 emitOpClasses(records, defs, os, staticVerifierEmitter,
4880 /*emitDecl=*/false);
4881}
4882
4883/// Emit op declarations for all op records.
4884static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) {
4885 emitSourceFileHeader(Desc: "Op Declarations", OS&: os, Record: records);
4886
4887 std::vector<const Record *> defs = getRequestedOpDefinitions(records);
4888 emitOpClassDecls(records, defs, os);
4889
4890 // If we are generating sharded op definitions, emit the sharded op
4891 // registration hooks.
4892 SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
4893 shardOpDefinitions(defs, shardedDefs);
4894 if (defs.empty() || shardedDefs.size() <= 1)
4895 return false;
4896
4897 Dialect dialect = Operator(defs.front()).getDialect();
4898 NamespaceEmitter ns(os, dialect);
4899
4900 const char *const opRegistrationHook =
4901 "void register{0}Operations{1}({2}::{0} *dialect);\n";
4902 os << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppClassName(), Vals: "",
4903 Vals: dialect.getCppNamespace());
4904 for (unsigned i = 0; i < shardedDefs.size(); ++i) {
4905 os << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppClassName(), Vals&: i,
4906 Vals: dialect.getCppNamespace());
4907 }
4908
4909 return false;
4910}
4911
4912/// Generate the dialect op registration hook and the op class definitions for a
4913/// shard of ops.
4914static void emitOpDefShard(const RecordKeeper &records,
4915 ArrayRef<const Record *> defs,
4916 const Dialect &dialect, unsigned shardIndex,
4917 unsigned shardCount, raw_ostream &os) {
4918 std::string shardGuard = "GET_OP_DEFS_";
4919 std::string indexStr = std::to_string(val: shardIndex);
4920 shardGuard += indexStr;
4921 IfDefScope scope(shardGuard, os);
4922
4923 // Emit the op registration hook in the first shard.
4924 const char *const opRegistrationHook =
4925 "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
4926 if (shardIndex == 0) {
4927 os << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppNamespace(),
4928 Vals: dialect.getCppClassName(), Vals: "");
4929 for (unsigned i = 0; i < shardCount; ++i) {
4930 os << formatv(Fmt: " {0}::register{1}Operations{2}(dialect);\n",
4931 Vals: dialect.getCppNamespace(), Vals: dialect.getCppClassName(), Vals&: i);
4932 }
4933 os << "}\n";
4934 }
4935
4936 // Generate the per-shard op registration hook.
4937 os << formatv(Fmt: opCommentHeader, Vals: dialect.getCppClassName(),
4938 Vals: "Op Registration Hook")
4939 << formatv(Fmt: opRegistrationHook, Vals: dialect.getCppNamespace(),
4940 Vals: dialect.getCppClassName(), Vals&: shardIndex);
4941 for (const Record *def : defs) {
4942 os << formatv(Fmt: " ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
4943 Vals: Operator(def).getQualCppClassName());
4944 }
4945 os << "}\n";
4946
4947 // Generate the per-shard op definitions.
4948 emitOpClassDefs(records, defs, os, constraintPrefix: indexStr);
4949}
4950
4951/// Emit op definitions for all op records.
4952static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) {
4953 emitSourceFileHeader(Desc: "Op Definitions", OS&: os, Record: records);
4954
4955 std::vector<const Record *> defs = getRequestedOpDefinitions(records);
4956 SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
4957 shardOpDefinitions(defs, shardedDefs);
4958
4959 // If no shard was requested, emit the regular op list and class definitions.
4960 if (shardedDefs.size() == 1) {
4961 {
4962 IfDefScope scope("GET_OP_LIST", os);
4963 interleave(
4964 c: defs, os,
4965 each_fn: [&](const Record *def) { os << Operator(def).getQualCppClassName(); },
4966 separator: ",\n");
4967 }
4968 {
4969 IfDefScope scope("GET_OP_CLASSES", os);
4970 emitOpClassDefs(records, defs, os);
4971 }
4972 return false;
4973 }
4974
4975 if (defs.empty())
4976 return false;
4977 Dialect dialect = Operator(defs.front()).getDialect();
4978 for (auto [idx, value] : llvm::enumerate(First&: shardedDefs)) {
4979 emitOpDefShard(records, defs: value, dialect, shardIndex: idx, shardCount: shardedDefs.size(), os);
4980 }
4981 return false;
4982}
4983
4984static mlir::GenRegistration
4985 genOpDecls("gen-op-decls", "Generate op declarations",
4986 [](const RecordKeeper &records, raw_ostream &os) {
4987 return emitOpDecls(records, os);
4988 });
4989
4990static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
4991 [](const RecordKeeper &records,
4992 raw_ostream &os) {
4993 return emitOpDefs(records, os);
4994 });
4995

source code of mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp