//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//
13
14#include "OpClass.h"
15#include "OpFormatGen.h"
16#include "OpGenHelpers.h"
17#include "mlir/TableGen/Argument.h"
18#include "mlir/TableGen/Attribute.h"
19#include "mlir/TableGen/Class.h"
20#include "mlir/TableGen/CodeGenHelpers.h"
21#include "mlir/TableGen/Format.h"
22#include "mlir/TableGen/GenInfo.h"
23#include "mlir/TableGen/Interfaces.h"
24#include "mlir/TableGen/Operator.h"
25#include "mlir/TableGen/Property.h"
26#include "mlir/TableGen/SideEffects.h"
27#include "mlir/TableGen/Trait.h"
28#include "llvm/ADT/BitVector.h"
29#include "llvm/ADT/MapVector.h"
30#include "llvm/ADT/Sequence.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/StringSet.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/ErrorHandling.h"
36#include "llvm/Support/Signals.h"
37#include "llvm/Support/raw_ostream.h"
38#include "llvm/TableGen/Error.h"
39#include "llvm/TableGen/Record.h"
40#include "llvm/TableGen/TableGenBackend.h"
41
42#define DEBUG_TYPE "mlir-tblgen-opdefgen"
43
44using namespace llvm;
45using namespace mlir;
46using namespace mlir::tblgen;
47
/// Identifier strings shared by the emitters in this file. They are spliced
/// verbatim into generated C++ code, so they must stay valid identifiers.
static const char *const tblgenNamePrefix = "tblgen_";
static const char *const generatedArgName = "odsArg";
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
static const char *const propertyStorage = "propStorage";
static const char *const propertyValue = "propValue";
static const char *const propertyAttr = "propAttr";
static const char *const propertyDiag = "emitError";

/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
static const char *const operandSegmentAttrName = "operandSegmentSizes";
static const char *const resultSegmentAttrName = "resultSegmentSizes";
61
/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange.
/// {2}: The upper bound on the sorted subrange.
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute.
static const char *const subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";
73
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
///
/// Literal opening braces are written `{{` because this is a formatv template.
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";
101
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes the
/// op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The snippet expects `index` and an indexable `sizeAttr` to be in scope at
/// the point where it is emitted; it is emitted verbatim, without formatv
/// substitution placeholders.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttr[i];
  return {start, sizeAttr[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation
/// inside an adaptor, when the sizes are stored as an attribute.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert({0} && "missing segment size attribute for op");
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";

/// Same as above, for the case where the sizes are read as an
/// `ArrayRef<int32_t>` rather than as an attribute.
///
/// {0}: The code to get the sizes.
static const char *const adapterSegmentSizeAttrInitCodeProperties = R"(
  ::llvm::ArrayRef<int32_t> sizeAttr = {0};
)";

/// The code snippet to initialize the sizes for the value range calculation
/// inside an op. Unlike the adaptor variant it emits no presence assertion.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
129
/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
/// {2}: The range type of adaptor.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizes = {0}();

  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
  for (int i = 0, e = sizes.size(); i < e; ++i) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
  }
  return tblgenTmpOperandGroups;
)";
147
/// The logic to build a range of either operand or result values.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
157
/// Read operand/result segment_size from bytecode, for bytecode versions that
/// store the sizes natively (version >= 6).
static const char *const readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Read operand/result segment_size from bytecode, for older bytecode versions
/// (version < 6) where the sizes were encoded as a DenseI32ArrayAttr. The
/// emitted code rejects an attribute with more elements than the storage holds.
static const char *const readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";
176
/// Write operand/result segment_size to bytecode, for bytecode versions that
/// store the sizes natively (version >= 6).
static const char *const writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Write operand/result segment_size to bytecode, for older bytecode versions
/// (version < 6) where the sizes are encoded as a DenseI32ArrayAttr.
static const char *const writeBytecodeSegmentSizeLegacy = R"(
if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
  auto &$_storage = prop.$_propName;
  $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
}
)";
190
/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
201
202//===----------------------------------------------------------------------===//
203// Utility structs and functions
204//===----------------------------------------------------------------------===//
205
// Replaces all occurrences of `match` in `str` with `substitute`.
//
// The scan resumes after each inserted `substitute`, so text introduced by a
// replacement is never rescanned (e.g. replacing "a" with "aa" terminates).
// An empty `match` is treated as "nothing to replace" and returns `str`
// unchanged; without this guard the search would match at every position and
// never finish.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0, matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // `replace` edits the string in place; no reassignment is needed.
    str.replace(matchLoc, match.size(), substitute);
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
216
217// Returns whether the record has a value of the given name that can be returned
218// via getValueAsString.
219static inline bool hasStringAttribute(const Record &record,
220 StringRef fieldName) {
221 auto *valueInit = record.getValueInit(FieldName: fieldName);
222 return isa<StringInit>(Val: valueInit);
223}
224
225static std::string getArgumentName(const Operator &op, int index) {
226 const auto &operand = op.getOperand(index);
227 if (!operand.name.empty())
228 return std::string(operand.name);
229 return std::string(formatv(Fmt: "{0}_{1}", Vals: generatedArgName, Vals&: index));
230}
231
232// Returns true if we can use unwrapped value for the given `attr` in builders.
233static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
234 return attr.getReturnType() != attr.getStorageType() &&
235 // We need to wrap the raw value into an attribute in the builder impl
236 // so we need to make sure that the attribute specifies how to do that.
237 !attr.getConstBuilderTemplate().empty();
238}
239
240/// Build an attribute from a parameter value using the constant builder.
241static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
242 FmtContext &fctx,
243 StringRef paramName) {
244 std::string builderTemplate = attr.getConstBuilderTemplate().str();
245
246 // For StringAttr, its constant builder call will wrap the input in
247 // quotes, which is correct for normal string literals, but incorrect
248 // here given we use function arguments. So we need to strip the
249 // wrapping quotes.
250 if (StringRef(builderTemplate).contains(Other: "\"$0\""))
251 builderTemplate = replaceAllSubstrs(str: builderTemplate, match: "\"$0\"", substitute: "$0");
252
253 return tgfmt(fmt: builderTemplate, ctx: &fctx, vals&: paramName).str();
254}
255
256namespace {
257/// Metadata on a registered attribute. Given that attributes are stored in
258/// sorted order on operations, we can use information from ODS to deduce the
259/// number of required attributes less and and greater than each attribute,
260/// allowing us to search only a subrange of the attributes in ODS-generated
261/// getters.
262struct AttributeMetadata {
263 /// The attribute name.
264 StringRef attrName;
265 /// Whether the attribute is required.
266 bool isRequired;
267 /// The ODS attribute constraint. Not present for implicit attributes.
268 std::optional<Attribute> constraint;
269 /// The number of required attributes less than this attribute.
270 unsigned lowerBound = 0;
271 /// The number of required attributes greater than this attribute.
272 unsigned upperBound = 0;
273};
274
275/// Helper class to select between OpAdaptor and Op code templates.
276class OpOrAdaptorHelper {
277public:
278 OpOrAdaptorHelper(const Operator &op, bool emitForOp)
279 : op(op), emitForOp(emitForOp) {
280 computeAttrMetadata();
281 }
282
283 /// Object that wraps a functor in a stream operator for interop with
284 /// llvm::formatv.
285 class Formatter {
286 public:
287 template <typename Functor>
288 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
289
290 std::string str() const {
291 std::string result;
292 llvm::raw_string_ostream os(result);
293 os << *this;
294 return os.str();
295 }
296
297 private:
298 std::function<raw_ostream &(raw_ostream &)> func;
299
300 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
301 return fmt.func(os);
302 }
303 };
304
305 // Generate code for getting an attribute.
306 Formatter getAttr(StringRef attrName, bool isNamed = false) const {
307 assert(attrMetadata.count(attrName) && "expected attribute metadata");
308 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
309 const AttributeMetadata &attr = attrMetadata.find(Key: attrName)->second;
310 if (hasProperties()) {
311 assert(!isNamed);
312 return os << "getProperties()." << attrName;
313 }
314 return os << formatv(Fmt: subrangeGetAttr, Vals: getAttrName(attrName),
315 Vals: attr.lowerBound, Vals: attr.upperBound, Vals: getAttrRange(),
316 Vals: isNamed ? "Named" : "");
317 };
318 }
319
320 // Generate code for getting the name of an attribute.
321 Formatter getAttrName(StringRef attrName) const {
322 return [this, attrName](raw_ostream &os) -> raw_ostream & {
323 if (emitForOp)
324 return os << op.getGetterName(name: attrName) << "AttrName()";
325 return os << formatv(Fmt: "{0}::{1}AttrName(*odsOpName)", Vals: op.getCppClassName(),
326 Vals: op.getGetterName(name: attrName));
327 };
328 }
329
330 // Get the code snippet for getting the named attribute range.
331 StringRef getAttrRange() const {
332 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
333 }
334
335 // Get the prefix code for emitting an error.
336 Formatter emitErrorPrefix() const {
337 return [this](raw_ostream &os) -> raw_ostream & {
338 if (emitForOp)
339 return os << "emitOpError(";
340 return os << formatv(Fmt: "emitError(loc, \"'{0}' op \"",
341 Vals: op.getOperationName());
342 };
343 }
344
345 // Get the call to get an operand or segment of operands.
346 Formatter getOperand(unsigned index) const {
347 return [this, index](raw_ostream &os) -> raw_ostream & {
348 return os << formatv(Fmt: op.getOperand(index).isVariadic()
349 ? "this->getODSOperands({0})"
350 : "(*this->getODSOperands({0}).begin())",
351 Vals: index);
352 };
353 }
354
355 // Get the call to get a result of segment of results.
356 Formatter getResult(unsigned index) const {
357 return [this, index](raw_ostream &os) -> raw_ostream & {
358 if (!emitForOp)
359 return os << "<no results should be generated>";
360 return os << formatv(Fmt: op.getResult(index).isVariadic()
361 ? "this->getODSResults({0})"
362 : "(*this->getODSResults({0}).begin())",
363 Vals: index);
364 };
365 }
366
367 // Return whether an op instance is available.
368 bool isEmittingForOp() const { return emitForOp; }
369
370 // Return the ODS operation wrapper.
371 const Operator &getOp() const { return op; }
372
373 // Get the attribute metadata sorted by name.
374 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
375 return attrMetadata;
376 }
377
378 /// Returns whether to emit a `Properties` struct for this operation or not.
379 bool hasProperties() const {
380 if (!op.getProperties().empty())
381 return true;
382 if (!op.getDialect().usePropertiesForAttributes())
383 return false;
384 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments") ||
385 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments"))
386 return true;
387 return llvm::any_of(Range: getAttrMetadata(),
388 P: [](const std::pair<StringRef, AttributeMetadata> &it) {
389 return !it.second.constraint ||
390 !it.second.constraint->isDerivedAttr();
391 });
392 }
393
394 std::optional<NamedProperty> &getOperandSegmentsSize() {
395 return operandSegmentsSize;
396 }
397
398 std::optional<NamedProperty> &getResultSegmentsSize() {
399 return resultSegmentsSize;
400 }
401
402 uint32_t getOperandSegmentSizesLegacyIndex() {
403 return operandSegmentSizesLegacyIndex;
404 }
405
406 uint32_t getResultSegmentSizesLegacyIndex() {
407 return resultSegmentSizesLegacyIndex;
408 }
409
410private:
411 // Compute the attribute metadata.
412 void computeAttrMetadata();
413
414 // The operation ODS wrapper.
415 const Operator &op;
416 // True if code is being generate for an op. False for an adaptor.
417 const bool emitForOp;
418
419 // The attribute metadata, mapped by name.
420 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
421
422 // Property
423 std::optional<NamedProperty> operandSegmentsSize;
424 std::string operandSegmentsSizeStorage;
425 std::optional<NamedProperty> resultSegmentsSize;
426 std::string resultSegmentsSizeStorage;
427
428 // Indices to store the position in the emission order of the operand/result
429 // segment sizes attribute if emitted as part of the properties for legacy
430 // bytecode encodings, i.e. versions less than 6.
431 uint32_t operandSegmentSizesLegacyIndex = 0;
432 uint32_t resultSegmentSizesLegacyIndex = 0;
433
434 // The number of required attributes.
435 unsigned numRequired;
436};
437
438} // namespace
439
440void OpOrAdaptorHelper::computeAttrMetadata() {
441 // Enumerate the attribute names of this op, ensuring the attribute names are
442 // unique in case implicit attributes are explicitly registered.
443 for (const NamedAttribute &namedAttr : op.getAttributes()) {
444 Attribute attr = namedAttr.attr;
445 bool isOptional =
446 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
447 attrMetadata.insert(
448 KV: {namedAttr.name, AttributeMetadata{.attrName: namedAttr.name, .isRequired: !isOptional, .constraint: attr}});
449 }
450
451 auto makeProperty = [&](StringRef storageType) {
452 return Property(
453 /*storageType=*/storageType,
454 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
455 /*convertFromStorageCall=*/"$_storage",
456 /*assignToStorageCall=*/
457 "::llvm::copy($_value, $_storage.begin())",
458 /*convertToAttributeCall=*/
459 "::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage)",
460 /*convertFromAttributeCall=*/
461 "return convertFromAttribute($_storage, $_attr, $_diag);",
462 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative,
463 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative,
464 /*hashPropertyCall=*/
465 "::llvm::hash_combine_range(std::begin($_storage), "
466 "std::end($_storage));",
467 /*StringRef defaultValue=*/"");
468 };
469 // Include key attributes from several traits as implicitly registered.
470 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
471 if (op.getDialect().usePropertiesForAttributes()) {
472 operandSegmentsSizeStorage =
473 llvm::formatv(Fmt: "std::array<int32_t, {0}>", Vals: op.getNumOperands());
474 operandSegmentsSize = {.name: "operandSegmentSizes",
475 .prop: makeProperty(operandSegmentsSizeStorage)};
476 } else {
477 attrMetadata.insert(
478 KV: {operandSegmentAttrName, AttributeMetadata{.attrName: operandSegmentAttrName,
479 /*isRequired=*/true,
480 /*attr=*/.constraint: std::nullopt}});
481 }
482 }
483 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
484 if (op.getDialect().usePropertiesForAttributes()) {
485 resultSegmentsSizeStorage =
486 llvm::formatv(Fmt: "std::array<int32_t, {0}>", Vals: op.getNumResults());
487 resultSegmentsSize = {.name: "resultSegmentSizes",
488 .prop: makeProperty(resultSegmentsSizeStorage)};
489 } else {
490 attrMetadata.insert(
491 KV: {resultSegmentAttrName,
492 AttributeMetadata{.attrName: resultSegmentAttrName, /*isRequired=*/true,
493 /*attr=*/.constraint: std::nullopt}});
494 }
495 }
496
497 // Store the metadata in sorted order.
498 SmallVector<AttributeMetadata> sortedAttrMetadata =
499 llvm::to_vector(Range: llvm::make_second_range(c: attrMetadata.takeVector()));
500 llvm::sort(C&: sortedAttrMetadata,
501 Comp: [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
502 return lhs.attrName < rhs.attrName;
503 });
504
505 // Store the position of the legacy operand_segment_sizes /
506 // result_segment_sizes so we can emit a backward compatible property readers
507 // and writers.
508 StringRef legacyOperandSegmentSizeName =
509 StringLiteral("operand_segment_sizes");
510 StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes");
511 operandSegmentSizesLegacyIndex = 0;
512 resultSegmentSizesLegacyIndex = 0;
513 for (auto item : sortedAttrMetadata) {
514 if (item.attrName < legacyOperandSegmentSizeName)
515 ++operandSegmentSizesLegacyIndex;
516 if (item.attrName < legacyResultSegmentSizeName)
517 ++resultSegmentSizesLegacyIndex;
518 }
519
520 // Compute the subrange bounds for each attribute.
521 numRequired = 0;
522 for (AttributeMetadata &attr : sortedAttrMetadata) {
523 attr.lowerBound = numRequired;
524 numRequired += attr.isRequired;
525 };
526 for (AttributeMetadata &attr : sortedAttrMetadata)
527 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
528
529 // Store the results back into the map.
530 for (const AttributeMetadata &attr : sortedAttrMetadata)
531 attrMetadata.insert(KV: {attr.attrName, attr});
532}
533
534//===----------------------------------------------------------------------===//
535// Op emitter
536//===----------------------------------------------------------------------===//
537
538namespace {
539// Helper class to emit a record into the given output stream.
540class OpEmitter {
541 using ConstArgument =
542 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
543
544public:
545 static void
546 emitDecl(const Operator &op, raw_ostream &os,
547 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
548 static void
549 emitDef(const Operator &op, raw_ostream &os,
550 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
551
552private:
553 OpEmitter(const Operator &op,
554 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
555
556 void emitDecl(raw_ostream &os);
557 void emitDef(raw_ostream &os);
558
559 // Generate methods for accessing the attribute names of this operation.
560 void genAttrNameGetters();
561
562 // Generates the OpAsmOpInterface for this operation if possible.
563 void genOpAsmInterface();
564
565 // Generates the `getOperationName` method for this op.
566 void genOpNameGetter();
567
568 // Generates code to manage the properties, if any!
569 void genPropertiesSupport();
570
571 // Generates code to manage the encoding of properties to bytecode.
572 void
573 genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties);
574
575 // Generates getters for the attributes.
576 void genAttrGetters();
577
578 // Generates setter for the attributes.
579 void genAttrSetters();
580
581 // Generates removers for optional attributes.
582 void genOptionalAttrRemovers();
583
584 // Generates getters for named operands.
585 void genNamedOperandGetters();
586
587 // Generates setters for named operands.
588 void genNamedOperandSetters();
589
590 // Generates getters for named results.
591 void genNamedResultGetters();
592
593 // Generates getters for named regions.
594 void genNamedRegionGetters();
595
596 // Generates getters for named successors.
597 void genNamedSuccessorGetters();
598
599 // Generates the method to populate default attributes.
600 void genPopulateDefaultAttributes();
601
602 // Generates builder methods for the operation.
603 void genBuilder();
604
605 // Generates the build() method that takes each operand/attribute
606 // as a stand-alone parameter.
607 void genSeparateArgParamBuilder();
608
609 // Generates the build() method that takes each operand/attribute as a
610 // stand-alone parameter. The generated build() method uses first operand's
611 // type as all results' types.
612 void genUseOperandAsResultTypeSeparateParamBuilder();
613
614 // Generates the build() method that takes all operands/attributes
615 // collectively as one parameter. The generated build() method uses first
616 // operand's type as all results' types.
617 void genUseOperandAsResultTypeCollectiveParamBuilder();
618
619 // Generates the build() method that takes aggregate operands/attributes
620 // parameters. This build() method uses inferred types as result types.
621 // Requires: The type needs to be inferable via InferTypeOpInterface.
622 void genInferredTypeCollectiveParamBuilder();
623
624 // Generates the build() method that takes each operand/attribute as a
625 // stand-alone parameter. The generated build() method uses first attribute's
626 // type as all result's types.
627 void genUseAttrAsResultTypeBuilder();
628
629 // Generates the build() method that takes all result types collectively as
630 // one parameter. Similarly for operands and attributes.
631 void genCollectiveParamBuilder();
632
633 // The kind of parameter to generate for result types in builders.
634 enum class TypeParamKind {
635 None, // No result type in parameter list.
636 Separate, // A separate parameter for each result type.
637 Collective, // An ArrayRef<Type> for all result types.
638 };
639
640 // The kind of parameter to generate for attributes in builders.
641 enum class AttrParamKind {
642 WrappedAttr, // A wrapped MLIR Attribute instance.
643 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
644 };
645
646 // Builds the parameter list for build() method of this op. This method writes
647 // to `paramList` the comma-separated parameter list and updates
648 // `resultTypeNames` with the names for parameters for specifying result
649 // types. `inferredAttributes` is populated with any attributes that are
650 // elided from the build list. The given `typeParamKind` and `attrParamKind`
651 // controls how result types and attributes are placed in the parameter list.
652 void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
653 llvm::StringSet<> &inferredAttributes,
654 SmallVectorImpl<std::string> &resultTypeNames,
655 TypeParamKind typeParamKind,
656 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
657
658 // Adds op arguments and regions into operation state for build() methods.
659 void
660 genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
661 llvm::StringSet<> &inferredAttributes,
662 bool isRawValueAttr = false);
663
664 // Generates canonicalizer declaration for the operation.
665 void genCanonicalizerDecls();
666
667 // Generates the folder declaration for the operation.
668 void genFolderDecls();
669
670 // Generates the parser for the operation.
671 void genParser();
672
673 // Generates the printer for the operation.
674 void genPrinter();
675
676 // Generates verify method for the operation.
677 void genVerifier();
678
679 // Generates custom verify methods for the operation.
680 void genCustomVerifier();
681
682 // Generates verify statements for operands and results in the operation.
683 // The generated code will be attached to `body`.
684 void genOperandResultVerifier(MethodBody &body,
685 Operator::const_value_range values,
686 StringRef valueKind);
687
688 // Generates verify statements for regions in the operation.
689 // The generated code will be attached to `body`.
690 void genRegionVerifier(MethodBody &body);
691
692 // Generates verify statements for successors in the operation.
693 // The generated code will be attached to `body`.
694 void genSuccessorVerifier(MethodBody &body);
695
696 // Generates the traits used by the object.
697 void genTraits();
698
699 // Generate the OpInterface methods for all interfaces.
700 void genOpInterfaceMethods();
701
702 // Generate op interface methods for the given interface.
703 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
704
705 // Generate op interface method for the given interface method. If
706 // 'declaration' is true, generates a declaration, else a definition.
707 Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
708 bool declaration = true);
709
710 // Generate the side effect interface methods.
711 void genSideEffectInterfaceMethods();
712
713 // Generate the type inference interface methods.
714 void genTypeInterfaceMethods();
715
716private:
717 // The TableGen record for this op.
718 // TODO: OpEmitter should not have a Record directly,
719 // it should rather go through the Operator for better abstraction.
720 const Record &def;
721
722 // The wrapper operator class for querying information from this op.
723 const Operator &op;
724
725 // The C++ code builder for this op
726 OpClass opClass;
727
728 // The format context for verification code generation.
729 FmtContext verifyCtx;
730
731 // The emitter containing all of the locally emitted verification functions.
732 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
733
734 // Helper for emitting op code.
735 OpOrAdaptorHelper emitHelper;
736};
737
738} // namespace
739
740// Populate the format context `ctx` with substitutions of attributes, operands
741// and results.
742static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
743 FmtContext &ctx) {
744 // Populate substitutions for attributes.
745 auto &op = emitHelper.getOp();
746 for (const auto &namedAttr : op.getAttributes())
747 ctx.addSubst(placeholder: namedAttr.name,
748 subst: emitHelper.getOp().getGetterName(name: namedAttr.name) + "()");
749
750 // Populate substitutions for named operands.
751 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
752 auto &value = op.getOperand(index: i);
753 if (!value.name.empty())
754 ctx.addSubst(placeholder: value.name, subst: emitHelper.getOperand(index: i).str());
755 }
756
757 // Populate substitutions for results.
758 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
759 auto &value = op.getResult(index: i);
760 if (!value.name.empty())
761 ctx.addSubst(placeholder: value.name, subst: emitHelper.getResult(index: i).str());
762 }
763}
764
765/// Generate verification on native traits requiring attributes.
766static void genNativeTraitAttrVerifier(MethodBody &body,
767 const OpOrAdaptorHelper &emitHelper) {
768 // Check that the variadic segment sizes attribute exists and contains the
769 // expected number of elements.
770 //
771 // {0}: Attribute name.
772 // {1}: Expected number of elements.
773 // {2}: "operand" or "result".
774 // {3}: Emit error prefix.
775 const char *const checkAttrSizedValueSegmentsCode = R"(
776 {
777 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
778 auto numElements = sizeAttr.asArrayRef().size();
779 if (numElements != {1})
780 return {3}"'{0}' attribute for specifying {2} segments must have {1} "
781 "elements, but got ") << numElements;
782 }
783 )";
784
785 // Verify a few traits first so that we can use getODSOperands() and
786 // getODSResults() in the rest of the verifier.
787 auto &op = emitHelper.getOp();
788 if (!op.getDialect().usePropertiesForAttributes()) {
789 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
790 body << formatv(Fmt: checkAttrSizedValueSegmentsCode, Vals: operandSegmentAttrName,
791 Vals: op.getNumOperands(), Vals: "operand",
792 Vals: emitHelper.emitErrorPrefix());
793 }
794 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
795 body << formatv(Fmt: checkAttrSizedValueSegmentsCode, Vals: resultSegmentAttrName,
796 Vals: op.getNumResults(), Vals: "result",
797 Vals: emitHelper.emitErrorPrefix());
798 }
799 }
800}
801
802// Return true if a verifier can be emitted for the attribute: it is not a
803// derived attribute, it has a predicate, its condition is not empty, and, for
804// adaptors, the condition does not reference the op.
805static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
806 if (attr.isDerivedAttr())
807 return false;
808 Pred pred = attr.getPredicate();
809 if (pred.isNull())
810 return false;
811 std::string condition = pred.getCondition();
812 return !condition.empty() &&
813 (!StringRef(condition).contains(Other: "$_op") || isEmittingForOp);
814}
815
816// Generate attribute verification. If an op instance is not available, then
817// attribute checks that require one will not be emitted.
818//
819// Attribute verification is performed as follows:
820//
821// 1. Verify that all required attributes are present in sorted order. This
822// ensures that we can use subrange lookup even with potentially missing
823// attributes.
824// 2. Verify native trait attributes so that other attributes may call methods
825// that depend on the validity of these attributes, e.g. segment size attributes
826// and operand or result getters.
827// 3. Verify the constraints on all present attributes.
828static void
829genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
830 MethodBody &body,
831 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
832 bool useProperties) {
833 if (emitHelper.getAttrMetadata().empty())
834 return;
835
836 // Verify the attribute if it is present. This assumes that default values
837 // are valid. This code snippet pastes the condition inline.
838 //
839 // TODO: verify the default value is valid (perhaps in debug mode only).
840 //
841 // {0}: Attribute variable name.
842 // {1}: Attribute condition code.
843 // {2}: Emit error prefix.
844 // {3}: Attribute name.
845 // {4}: Attribute/constraint description.
846 const char *const verifyAttrInline = R"(
847 if ({0} && !({1}))
848 return {2}"attribute '{3}' failed to satisfy constraint: {4}");
849)";
850 // Verify the attribute using a uniqued constraint. Can only be used within
851 // the context of an op.
852 //
853 // {0}: Unique constraint name.
854 // {1}: Attribute variable name.
855 // {2}: Attribute name.
856 const char *const verifyAttrUnique = R"(
857 if (::mlir::failed({0}(*this, {1}, "{2}")))
858 return ::mlir::failure();
859)";
860
861 // Traverse the array until the required attribute is found. Return an error
862 // if the traversal reached the end.
863 //
864 // {0}: Code to get the name of the attribute.
865 // {1}: The emit error prefix.
866 // {2}: The name of the attribute.
867 const char *const findRequiredAttr = R"(
868while (true) {{
869 if (namedAttrIt == namedAttrRange.end())
870 return {1}"requires attribute '{2}'");
871 if (namedAttrIt->getName() == {0}) {{
872 tblgen_{2} = namedAttrIt->getValue();
873 break;
874 })";
875
876 // Emit a check to see if the iteration has encountered an optional attribute.
877 //
878 // {0}: Code to get the name of the attribute.
879 // {1}: The name of the attribute.
880 const char *const checkOptionalAttr = R"(
881 else if (namedAttrIt->getName() == {0}) {{
882 tblgen_{1} = namedAttrIt->getValue();
883 })";
884
885 // Emit the start of the loop for checking trailing attributes.
886 const char *const checkTrailingAttrs = R"(while (true) {
887 if (namedAttrIt == namedAttrRange.end()) {
888 break;
889 })";
890
891 // Emit the verifier for the attribute.
892 const auto emitVerifier = [&](Attribute attr, StringRef attrName,
893 StringRef varName) {
894 std::string condition = attr.getPredicate().getCondition();
895
896 std::optional<StringRef> constraintFn;
897 if (emitHelper.isEmittingForOp() &&
898 (constraintFn = staticVerifierEmitter.getAttrConstraintFn(constraint: attr))) {
899 body << formatv(Fmt: verifyAttrUnique, Vals&: *constraintFn, Vals&: varName, Vals&: attrName);
900 } else {
901 body << formatv(Fmt: verifyAttrInline, Vals&: varName,
902 Vals: tgfmt(fmt: condition, ctx: &ctx.withSelf(subst: varName)),
903 Vals: emitHelper.emitErrorPrefix(), Vals&: attrName,
904 Vals: escapeString(value: attr.getSummary()));
905 }
906 };
907
908 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
909 const auto getVarName = [&](StringRef attrName) {
910 return (tblgenNamePrefix + attrName).str();
911 };
912
913 body.indent();
914 if (useProperties) {
915 for (const std::pair<StringRef, AttributeMetadata> &it :
916 emitHelper.getAttrMetadata()) {
917 const AttributeMetadata &metadata = it.second;
918 if (metadata.constraint && metadata.constraint->isDerivedAttr())
919 continue;
920 body << formatv(
921 Fmt: "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
922 Vals: it.first);
923 if (metadata.isRequired)
924 body << formatv(
925 Fmt: "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
926 Vals: it.first, Vals: emitHelper.emitErrorPrefix());
927 }
928 } else {
929 body << formatv(Fmt: "auto namedAttrRange = {0};\n", Vals: emitHelper.getAttrRange());
930 body << "auto namedAttrIt = namedAttrRange.begin();\n";
931
932 // Iterate over the attributes in sorted order. Keep track of the optional
933 // attributes that may be encountered along the way.
934 SmallVector<const AttributeMetadata *> optionalAttrs;
935
936 for (const std::pair<StringRef, AttributeMetadata> &it :
937 emitHelper.getAttrMetadata()) {
938 const AttributeMetadata &metadata = it.second;
939 if (!metadata.isRequired) {
940 optionalAttrs.push_back(Elt: &metadata);
941 continue;
942 }
943
944 body << formatv(Fmt: "::mlir::Attribute {0};\n", Vals: getVarName(it.first));
945 for (const AttributeMetadata *optional : optionalAttrs) {
946 body << formatv(Fmt: "::mlir::Attribute {0};\n",
947 Vals: getVarName(optional->attrName));
948 }
949 body << formatv(Fmt: findRequiredAttr, Vals: emitHelper.getAttrName(attrName: it.first),
950 Vals: emitHelper.emitErrorPrefix(), Vals: it.first);
951 for (const AttributeMetadata *optional : optionalAttrs) {
952 body << formatv(Fmt: checkOptionalAttr,
953 Vals: emitHelper.getAttrName(attrName: optional->attrName),
954 Vals: optional->attrName);
955 }
956 body << "\n ++namedAttrIt;\n}\n";
957 optionalAttrs.clear();
958 }
959 // Get trailing optional attributes.
960 if (!optionalAttrs.empty()) {
961 for (const AttributeMetadata *optional : optionalAttrs) {
962 body << formatv(Fmt: "::mlir::Attribute {0};\n",
963 Vals: getVarName(optional->attrName));
964 }
965 body << checkTrailingAttrs;
966 for (const AttributeMetadata *optional : optionalAttrs) {
967 body << formatv(Fmt: checkOptionalAttr,
968 Vals: emitHelper.getAttrName(attrName: optional->attrName),
969 Vals: optional->attrName);
970 }
971 body << "\n ++namedAttrIt;\n}\n";
972 }
973 }
974 body.unindent();
975
976 // Emit the checks for segment attributes first so that the other
977 // constraints can call operand and result getters.
978 genNativeTraitAttrVerifier(body, emitHelper);
979
980 bool isEmittingForOp = emitHelper.isEmittingForOp();
981 for (const auto &namedAttr : emitHelper.getOp().getAttributes())
982 if (canEmitAttrVerifier(attr: namedAttr.attr, isEmittingForOp))
983 emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
984}
985
986/// Include declarations specified on NativeTrait
987static std::string formatExtraDeclarations(const Operator &op) {
988 SmallVector<StringRef> extraDeclarations;
989 // Include extra class declarations from NativeTrait
990 for (const auto &trait : op.getTraits()) {
991 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
992 StringRef value = opTrait->getExtraConcreteClassDeclaration();
993 if (value.empty())
994 continue;
995 extraDeclarations.push_back(Elt: value);
996 }
997 }
998 extraDeclarations.push_back(Elt: op.getExtraClassDeclaration());
999 return llvm::join(R&: extraDeclarations, Separator: "\n");
1000}
1001
1002/// Op extra class definitions have a `$cppClass` substitution that is to be
1003/// replaced by the C++ class name.
1004/// Include declarations specified on NativeTrait
1005static std::string formatExtraDefinitions(const Operator &op) {
1006 SmallVector<StringRef> extraDefinitions;
1007 // Include extra class definitions from NativeTrait
1008 for (const auto &trait : op.getTraits()) {
1009 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
1010 StringRef value = opTrait->getExtraConcreteClassDefinition();
1011 if (value.empty())
1012 continue;
1013 extraDefinitions.push_back(Elt: value);
1014 }
1015 }
1016 extraDefinitions.push_back(Elt: op.getExtraClassDefinition());
1017 FmtContext ctx = FmtContext().addSubst(placeholder: "cppClass", subst: op.getCppClassName());
1018 return tgfmt(fmt: llvm::join(R&: extraDefinitions, Separator: "\n"), ctx: &ctx).str();
1019}
1020
1021OpEmitter::OpEmitter(const Operator &op,
1022 const StaticVerifierFunctionEmitter &staticVerifierEmitter)
1023 : def(op.getDef()), op(op),
1024 opClass(op.getCppClassName(), formatExtraDeclarations(op),
1025 formatExtraDefinitions(op)),
1026 staticVerifierEmitter(staticVerifierEmitter),
1027 emitHelper(op, /*emitForOp=*/true) {
1028 verifyCtx.addSubst(placeholder: "_op", subst: "(*this->getOperation())");
1029 verifyCtx.addSubst(placeholder: "_ctxt", subst: "this->getOperation()->getContext()");
1030
1031 genTraits();
1032
1033 // Generate C++ code for various op methods. The order here determines the
1034 // methods in the generated file.
1035 genAttrNameGetters();
1036 genOpAsmInterface();
1037 genOpNameGetter();
1038 genNamedOperandGetters();
1039 genNamedOperandSetters();
1040 genNamedResultGetters();
1041 genNamedRegionGetters();
1042 genNamedSuccessorGetters();
1043 genPropertiesSupport();
1044 genAttrGetters();
1045 genAttrSetters();
1046 genOptionalAttrRemovers();
1047 genBuilder();
1048 genPopulateDefaultAttributes();
1049 genParser();
1050 genPrinter();
1051 genVerifier();
1052 genCustomVerifier();
1053 genCanonicalizerDecls();
1054 genFolderDecls();
1055 genTypeInterfaceMethods();
1056 genOpInterfaceMethods();
1057 generateOpFormat(constOp: op, opClass);
1058 genSideEffectInterfaceMethods();
1059}
1060void OpEmitter::emitDecl(
1061 const Operator &op, raw_ostream &os,
1062 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1063 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
1064}
1065
1066void OpEmitter::emitDef(
1067 const Operator &op, raw_ostream &os,
1068 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1069 OpEmitter(op, staticVerifierEmitter).emitDef(os);
1070}
1071
1072void OpEmitter::emitDecl(raw_ostream &os) {
1073 opClass.finalize();
1074 opClass.writeDeclTo(rawOs&: os);
1075}
1076
1077void OpEmitter::emitDef(raw_ostream &os) {
1078 opClass.finalize();
1079 opClass.writeDefTo(rawOs&: os);
1080}
1081
1082static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
1083 const Operator &op) {
1084 if (m)
1085 return;
1086 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "Unexpected overlap when generating `" +
1087 methodName + "` for " +
1088 op.getOperationName() + " (from line " +
1089 Twine(line) + ")");
1090}
1091
1092#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
1093
1094void OpEmitter::genAttrNameGetters() {
1095 const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
1096 emitHelper.getAttrMetadata();
1097 bool hasOperandSegmentsSize =
1098 op.getDialect().usePropertiesForAttributes() &&
1099 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
1100 // Emit the getAttributeNames method.
1101 {
1102 auto *method = opClass.addStaticInlineMethod(
1103 retType: "::llvm::ArrayRef<::llvm::StringRef>", name: "getAttributeNames");
1104 ERROR_IF_PRUNED(method, "getAttributeNames", op);
1105 auto &body = method->body();
1106 if (!hasOperandSegmentsSize && attributes.empty()) {
1107 body << " return {};";
1108 // Nothing else to do if there are no registered attributes. Exit early.
1109 return;
1110 }
1111 body << " static ::llvm::StringRef attrNames[] = {";
1112 llvm::interleaveComma(c: llvm::make_first_range(c: attributes), os&: body,
1113 each_fn: [&](StringRef attrName) {
1114 body << "::llvm::StringRef(\"" << attrName << "\")";
1115 });
1116 if (hasOperandSegmentsSize) {
1117 if (!attributes.empty())
1118 body << ", ";
1119 body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")";
1120 }
1121 body << "};\n return ::llvm::ArrayRef(attrNames);";
1122 }
1123
1124 // Emit the getAttributeNameForIndex methods.
1125 {
1126 auto *method = opClass.addInlineMethod<Method::Private>(
1127 retType: "::mlir::StringAttr", name: "getAttributeNameForIndex",
1128 args: MethodParameter("unsigned", "index"));
1129 ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
1130 method->body()
1131 << " return getAttributeNameForIndex((*this)->getName(), index);";
1132 }
1133 {
1134 auto *method = opClass.addStaticInlineMethod<Method::Private>(
1135 retType: "::mlir::StringAttr", name: "getAttributeNameForIndex",
1136 args: MethodParameter("::mlir::OperationName", "name"),
1137 args: MethodParameter("unsigned", "index"));
1138 ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
1139
1140 if (attributes.empty()) {
1141 method->body() << " return {};";
1142 } else {
1143 const char *const getAttrName = R"(
1144 assert(index < {0} && "invalid attribute index");
1145 assert(name.getStringRef() == getOperationName() && "invalid operation name");
1146 assert(name.isRegistered() && "Operation isn't registered, missing a "
1147 "dependent dialect loading?");
1148 return name.getAttributeNames()[index];
1149)";
1150 method->body() << formatv(Fmt: getAttrName, Vals: attributes.size());
1151 }
1152 }
1153
1154 // Generate the <attr>AttrName methods, that expose the attribute names to
1155 // users.
1156 const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
1157 for (auto [index, attr] :
1158 llvm::enumerate(First: llvm::make_first_range(c: attributes))) {
1159 std::string name = op.getGetterName(name: attr);
1160 std::string methodName = name + "AttrName";
1161
1162 // Generate the non-static variant.
1163 {
1164 auto *method = opClass.addInlineMethod(retType: "::mlir::StringAttr", name&: methodName);
1165 ERROR_IF_PRUNED(method, methodName, op);
1166 method->body() << llvm::formatv(Fmt: attrNameMethodBody, Vals&: index);
1167 }
1168
1169 // Generate the static variant.
1170 {
1171 auto *method = opClass.addStaticInlineMethod(
1172 retType: "::mlir::StringAttr", name&: methodName,
1173 args: MethodParameter("::mlir::OperationName", "name"));
1174 ERROR_IF_PRUNED(method, methodName, op);
1175 method->body() << llvm::formatv(Fmt: attrNameMethodBody,
1176 Vals: "name, " + Twine(index));
1177 }
1178 }
1179 if (hasOperandSegmentsSize) {
1180 std::string name = op.getGetterName(name: operandSegmentAttrName);
1181 std::string methodName = name + "AttrName";
1182 // Generate the non-static variant.
1183 {
1184 auto *method = opClass.addInlineMethod(retType: "::mlir::StringAttr", name&: methodName);
1185 ERROR_IF_PRUNED(method, methodName, op);
1186 method->body()
1187 << " return (*this)->getName().getAttributeNames().back();";
1188 }
1189
1190 // Generate the static variant.
1191 {
1192 auto *method = opClass.addStaticInlineMethod(
1193 retType: "::mlir::StringAttr", name&: methodName,
1194 args: MethodParameter("::mlir::OperationName", "name"));
1195 ERROR_IF_PRUNED(method, methodName, op);
1196 method->body() << " return name.getAttributeNames().back();";
1197 }
1198 }
1199}
1200
1201// Emit the getter for an attribute with the return type specified.
1202// It is templated to be shared between the Op and the adaptor class.
1203template <typename OpClassOrAdaptor>
1204static void emitAttrGetterWithReturnType(FmtContext &fctx,
1205 OpClassOrAdaptor &opClass,
1206 const Operator &op, StringRef name,
1207 Attribute attr) {
1208 auto *method = opClass.addMethod(attr.getReturnType(), name);
1209 ERROR_IF_PRUNED(method, name, op);
1210 auto &body = method->body();
1211 body << " auto attr = " << name << "Attr();\n";
1212 if (attr.hasDefaultValue() && attr.isOptional()) {
1213 // Returns the default value if not set.
1214 // TODO: this is inefficient, we are recreating the attribute for every
1215 // call. This should be set instead.
1216 if (!attr.isConstBuildable()) {
1217 PrintFatalError(Msg: "DefaultValuedAttr of type " + attr.getAttrDefName() +
1218 " must have a constBuilder");
1219 }
1220 std::string defaultValue = std::string(
1221 tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue()));
1222 body << " if (!attr)\n return "
1223 << tgfmt(fmt: attr.getConvertFromStorageCall(),
1224 ctx: &fctx.withSelf(subst: defaultValue))
1225 << ";\n";
1226 }
1227 body << " return "
1228 << tgfmt(fmt: attr.getConvertFromStorageCall(), ctx: &fctx.withSelf(subst: "attr"))
1229 << ";\n";
1230}
1231
1232void OpEmitter::genPropertiesSupport() {
1233 if (!emitHelper.hasProperties())
1234 return;
1235
1236 SmallVector<ConstArgument> attrOrProperties;
1237 for (const std::pair<StringRef, AttributeMetadata> &it :
1238 emitHelper.getAttrMetadata()) {
1239 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
1240 attrOrProperties.push_back(Elt: &it.second);
1241 }
1242 for (const NamedProperty &prop : op.getProperties())
1243 attrOrProperties.push_back(Elt: &prop);
1244 if (emitHelper.getOperandSegmentsSize())
1245 attrOrProperties.push_back(Elt: &emitHelper.getOperandSegmentsSize().value());
1246 if (emitHelper.getResultSegmentsSize())
1247 attrOrProperties.push_back(Elt: &emitHelper.getResultSegmentsSize().value());
1248 if (attrOrProperties.empty())
1249 return;
1250 auto &setPropMethod =
1251 opClass
1252 .addStaticMethod(
1253 retType: "::mlir::LogicalResult", name: "setPropertiesFromAttr",
1254 args: MethodParameter("Properties &", "prop"),
1255 args: MethodParameter("::mlir::Attribute", "attr"),
1256 args: MethodParameter(
1257 "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1258 "emitError"))
1259 ->body();
1260 auto &getPropMethod =
1261 opClass
1262 .addStaticMethod(retType: "::mlir::Attribute", name: "getPropertiesAsAttr",
1263 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1264 args: MethodParameter("const Properties &", "prop"))
1265 ->body();
1266 auto &hashMethod =
1267 opClass
1268 .addStaticMethod(retType: "llvm::hash_code", name: "computePropertiesHash",
1269 args: MethodParameter("const Properties &", "prop"))
1270 ->body();
1271 auto &getInherentAttrMethod =
1272 opClass
1273 .addStaticMethod(retType: "std::optional<mlir::Attribute>", name: "getInherentAttr",
1274 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1275 args: MethodParameter("const Properties &", "prop"),
1276 args: MethodParameter("llvm::StringRef", "name"))
1277 ->body();
1278 auto &setInherentAttrMethod =
1279 opClass
1280 .addStaticMethod(retType: "void", name: "setInherentAttr",
1281 args: MethodParameter("Properties &", "prop"),
1282 args: MethodParameter("llvm::StringRef", "name"),
1283 args: MethodParameter("mlir::Attribute", "value"))
1284 ->body();
1285 auto &populateInherentAttrsMethod =
1286 opClass
1287 .addStaticMethod(retType: "void", name: "populateInherentAttrs",
1288 args: MethodParameter("::mlir::MLIRContext *", "ctx"),
1289 args: MethodParameter("const Properties &", "prop"),
1290 args: MethodParameter("::mlir::NamedAttrList &", "attrs"))
1291 ->body();
1292 auto &verifyInherentAttrsMethod =
1293 opClass
1294 .addStaticMethod(
1295 retType: "::mlir::LogicalResult", name: "verifyInherentAttrs",
1296 args: MethodParameter("::mlir::OperationName", "opName"),
1297 args: MethodParameter("::mlir::NamedAttrList &", "attrs"),
1298 args: MethodParameter(
1299 "llvm::function_ref<::mlir::InFlightDiagnostic()>",
1300 "emitError"))
1301 ->body();
1302
1303 opClass.declare<UsingDeclaration>(args: "Properties", args: "FoldAdaptor::Properties");
1304
1305 // Convert the property to the attribute form.
1306
1307 setPropMethod << R"decl(
1308 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1309 if (!dict) {
1310 emitError() << "expected DictionaryAttr to set properties";
1311 return ::mlir::failure();
1312 }
1313 )decl";
1314 // TODO: properties might be optional as well.
1315 const char *propFromAttrFmt = R"decl(;
1316 {{
1317 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1318 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
1319 {0};
1320 };
1321 {2};
1322 if (!attr) {{
1323 emitError() << "expected key entry for {1} in DictionaryAttr to set "
1324 "Properties.";
1325 return ::mlir::failure();
1326 }
1327 if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
1328 return ::mlir::failure();
1329 }
1330)decl";
1331
1332 for (const auto &attrOrProp : attrOrProperties) {
1333 if (const auto *namedProperty =
1334 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1335 StringRef name = namedProperty->name;
1336 auto &prop = namedProperty->prop;
1337 FmtContext fctx;
1338
1339 std::string getAttr;
1340 llvm::raw_string_ostream os(getAttr);
1341 os << " auto attr = dict.get(\"" << name << "\");";
1342 if (name == operandSegmentAttrName) {
1343 // Backward compat for now, TODO: Remove at some point.
1344 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1345 }
1346 if (name == resultSegmentAttrName) {
1347 // Backward compat for now, TODO: Remove at some point.
1348 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1349 }
1350 os.flush();
1351
1352 setPropMethod << formatv(Fmt: propFromAttrFmt,
1353 Vals: tgfmt(fmt: prop.getConvertFromAttributeCall(),
1354 ctx: &fctx.addSubst(placeholder: "_attr", subst: propertyAttr)
1355 .addSubst(placeholder: "_storage", subst: propertyStorage)
1356 .addSubst(placeholder: "_diag", subst: propertyDiag)),
1357 Vals&: name, Vals&: getAttr);
1358
1359 } else {
1360 const auto *namedAttr =
1361 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1362 StringRef name = namedAttr->attrName;
1363 std::string getAttr;
1364 llvm::raw_string_ostream os(getAttr);
1365 os << " auto attr = dict.get(\"" << name << "\");";
1366 if (name == operandSegmentAttrName) {
1367 // Backward compat for now
1368 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1369 }
1370 if (name == resultSegmentAttrName) {
1371 // Backward compat for now
1372 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1373 }
1374 os.flush();
1375
1376 setPropMethod << formatv(Fmt: R"decl(
1377 {{
1378 auto &propStorage = prop.{0};
1379 {2}
1380 if (attr || /*isRequired=*/{1}) {{
1381 if (!attr) {{
1382 emitError() << "expected key entry for {0} in DictionaryAttr to set "
1383 "Properties.";
1384 return ::mlir::failure();
1385 }
1386 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1387 if (convertedAttr) {{
1388 propStorage = convertedAttr;
1389 } else {{
1390 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1391 return ::mlir::failure();
1392 }
1393 }
1394 }
1395)decl",
1396 Vals&: name, Vals: namedAttr->isRequired, Vals&: getAttr);
1397 }
1398 }
1399 setPropMethod << " return ::mlir::success();\n";
1400
1401 // Convert the attribute form to the property.
1402
1403 getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
1404 << " ::mlir::Builder odsBuilder{ctx};\n";
1405 const char *propToAttrFmt = R"decl(
1406 {
1407 const auto &propStorage = prop.{0};
1408 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1409 {1}));
1410 }
1411)decl";
1412 for (const auto &attrOrProp : attrOrProperties) {
1413 if (const auto *namedProperty =
1414 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1415 StringRef name = namedProperty->name;
1416 auto &prop = namedProperty->prop;
1417 FmtContext fctx;
1418 getPropMethod << formatv(
1419 Fmt: propToAttrFmt, Vals&: name,
1420 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(),
1421 ctx: &fctx.addSubst(placeholder: "_ctxt", subst: "ctx")
1422 .addSubst(placeholder: "_storage", subst: propertyStorage)));
1423 continue;
1424 }
1425 const auto *namedAttr =
1426 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1427 StringRef name = namedAttr->attrName;
1428 getPropMethod << formatv(Fmt: R"decl(
1429 {{
1430 const auto &propStorage = prop.{0};
1431 if (propStorage)
1432 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1433 propStorage));
1434 }
1435)decl",
1436 Vals&: name);
1437 }
1438 getPropMethod << R"decl(
1439 if (!attrs.empty())
1440 return odsBuilder.getDictionaryAttr(attrs);
1441 return {};
1442)decl";
1443
1444 // Hashing for the property
1445
1446 const char *propHashFmt = R"decl(
1447 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
1448 return {1};
1449 };
1450)decl";
1451 for (const auto &attrOrProp : attrOrProperties) {
1452 if (const auto *namedProperty =
1453 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1454 StringRef name = namedProperty->name;
1455 auto &prop = namedProperty->prop;
1456 FmtContext fctx;
1457 hashMethod << formatv(Fmt: propHashFmt, Vals&: name,
1458 Vals: tgfmt(fmt: prop.getHashPropertyCall(),
1459 ctx: &fctx.addSubst(placeholder: "_storage", subst: propertyStorage)));
1460 }
1461 }
1462 hashMethod << " return llvm::hash_combine(";
1463 llvm::interleaveComma(
1464 c: attrOrProperties, os&: hashMethod, each_fn: [&](const ConstArgument &attrOrProp) {
1465 if (const auto *namedProperty =
1466 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
1467 hashMethod << "\n hash_" << namedProperty->name << "(prop."
1468 << namedProperty->name << ")";
1469 return;
1470 }
1471 const auto *namedAttr =
1472 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1473 StringRef name = namedAttr->attrName;
1474 hashMethod << "\n llvm::hash_value(prop." << name
1475 << ".getAsOpaquePointer())";
1476 });
1477 hashMethod << ");\n";
1478
1479 const char *getInherentAttrMethodFmt = R"decl(
1480 if (name == "{0}")
1481 return prop.{0};
1482)decl";
1483 const char *setInherentAttrMethodFmt = R"decl(
1484 if (name == "{0}") {{
1485 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
1486 return;
1487 }
1488)decl";
1489 const char *populateInherentAttrsMethodFmt = R"decl(
1490 if (prop.{0}) attrs.append("{0}", prop.{0});
1491)decl";
1492 for (const auto &attrOrProp : attrOrProperties) {
1493 if (const auto *namedAttr =
1494 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp)) {
1495 StringRef name = namedAttr->attrName;
1496 getInherentAttrMethod << formatv(Fmt: getInherentAttrMethodFmt, Vals&: name);
1497 setInherentAttrMethod << formatv(Fmt: setInherentAttrMethodFmt, Vals&: name);
1498 populateInherentAttrsMethod
1499 << formatv(Fmt: populateInherentAttrsMethodFmt, Vals&: name);
1500 continue;
1501 }
1502 // The ODS segment size property is "special": we expose it as an attribute
1503 // even though it is a native property.
1504 const auto *namedProperty = cast<const NamedProperty *>(Val: attrOrProp);
1505 StringRef name = namedProperty->name;
1506 if (name != operandSegmentAttrName && name != resultSegmentAttrName)
1507 continue;
1508 auto &prop = namedProperty->prop;
1509 FmtContext fctx;
1510 fctx.addSubst(placeholder: "_ctxt", subst: "ctx");
1511 fctx.addSubst(placeholder: "_storage", subst: Twine("prop.") + name);
1512 if (name == operandSegmentAttrName) {
1513 getInherentAttrMethod
1514 << formatv(Fmt: " if (name == \"operand_segment_sizes\" || name == "
1515 "\"{0}\") return ",
1516 Vals: operandSegmentAttrName);
1517 } else {
1518 getInherentAttrMethod
1519 << formatv(Fmt: " if (name == \"result_segment_sizes\" || name == "
1520 "\"{0}\") return ",
1521 Vals: resultSegmentAttrName);
1522 }
1523 getInherentAttrMethod << tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx)
1524 << ";\n";
1525
1526 if (name == operandSegmentAttrName) {
1527 setInherentAttrMethod
1528 << formatv(Fmt: " if (name == \"operand_segment_sizes\" || name == "
1529 "\"{0}\") {{",
1530 Vals: operandSegmentAttrName);
1531 } else {
1532 setInherentAttrMethod
1533 << formatv(Fmt: " if (name == \"result_segment_sizes\" || name == "
1534 "\"{0}\") {{",
1535 Vals: resultSegmentAttrName);
1536 }
1537 setInherentAttrMethod << formatv(Fmt: R"decl(
1538 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
1539 if (!arrAttr) return;
1540 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
1541 return;
1542 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
1543 return;
1544 }
1545)decl",
1546 Vals&: name);
1547 if (name == operandSegmentAttrName) {
1548 populateInherentAttrsMethod
1549 << formatv(Fmt: " attrs.append(\"{0}\", {1});\n", Vals: operandSegmentAttrName,
1550 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx));
1551 } else {
1552 populateInherentAttrsMethod
1553 << formatv(Fmt: " attrs.append(\"{0}\", {1});\n", Vals: resultSegmentAttrName,
1554 Vals: tgfmt(fmt: prop.getConvertToAttributeCall(), ctx: &fctx));
1555 }
1556 }
1557 getInherentAttrMethod << " return std::nullopt;\n";
1558
1559 // Emit the verifiers method for backward compatibility with the generic
1560 // syntax. This method verifies the constraint on the properties attributes
1561 // before they are set, since dyn_cast<> will silently omit failures.
1562 for (const auto &attrOrProp : attrOrProperties) {
1563 const auto *namedAttr =
1564 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
1565 if (!namedAttr || !namedAttr->constraint)
1566 continue;
1567 Attribute attr = *namedAttr->constraint;
1568 std::optional<StringRef> constraintFn =
1569 staticVerifierEmitter.getAttrConstraintFn(constraint: attr);
1570 if (!constraintFn)
1571 continue;
1572 if (canEmitAttrVerifier(attr,
1573 /*isEmittingForOp=*/false)) {
1574 std::string name = op.getGetterName(name: namedAttr->attrName);
1575 verifyInherentAttrsMethod
1576 << formatv(Fmt: R"(
1577 {{
1578 ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
1579 if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
1580 return ::mlir::failure();
1581 }
1582)",
1583 Vals&: name, Vals&: constraintFn, Vals: namedAttr->attrName);
1584 }
1585 }
1586 verifyInherentAttrsMethod << " return ::mlir::success();";
1587
1588 // Generate methods to interact with bytecode.
1589 genPropertiesSupportForBytecode(attrOrProperties);
1590}
1591
1592void OpEmitter::genPropertiesSupportForBytecode(
1593 ArrayRef<ConstArgument> attrOrProperties) {
1594 if (op.useCustomPropertiesEncoding()) {
1595 opClass.declareStaticMethod(
1596 retType: "::mlir::LogicalResult", name: "readProperties",
1597 args: MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1598 args: MethodParameter("::mlir::OperationState &", "state"));
1599 opClass.declareMethod(
1600 retType: "void", name: "writeProperties",
1601 args: MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
1602 return;
1603 }
1604
1605 auto &readPropertiesMethod =
1606 opClass
1607 .addStaticMethod(
1608 retType: "::mlir::LogicalResult", name: "readProperties",
1609 args: MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1610 args: MethodParameter("::mlir::OperationState &", "state"))
1611 ->body();
1612
1613 auto &writePropertiesMethod =
1614 opClass
1615 .addMethod(
1616 retType: "void", name: "writeProperties",
1617 args: MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
1618 ->body();
1619
1620 // Populate bytecode serialization logic.
1621 readPropertiesMethod
1622 << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
1623 writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n";
1624 for (const auto &item : llvm::enumerate(First&: attrOrProperties)) {
1625 auto &attrOrProp = item.value();
1626 FmtContext fctx;
1627 fctx.addSubst(placeholder: "_reader", subst: "reader")
1628 .addSubst(placeholder: "_writer", subst: "writer")
1629 .addSubst(placeholder: "_storage", subst: propertyStorage)
1630 .addSubst(placeholder: "_ctxt", subst: "this->getContext()");
1631 // If the op emits operand/result segment sizes as a property, emit the
1632 // legacy reader/writer in the appropriate order to allow backward
1633 // compatibility and back deployment.
1634 if (emitHelper.getOperandSegmentsSize().has_value() &&
1635 item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) {
1636 FmtContext fmtCtxt(fctx);
1637 fmtCtxt.addSubst(placeholder: "_propName", subst: operandSegmentAttrName);
1638 readPropertiesMethod << tgfmt(fmt: readBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1639 writePropertiesMethod << tgfmt(fmt: writeBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1640 }
1641 if (emitHelper.getResultSegmentsSize().has_value() &&
1642 item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) {
1643 FmtContext fmtCtxt(fctx);
1644 fmtCtxt.addSubst(placeholder: "_propName", subst: resultSegmentAttrName);
1645 readPropertiesMethod << tgfmt(fmt: readBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1646 writePropertiesMethod << tgfmt(fmt: writeBytecodeSegmentSizeLegacy, ctx: &fmtCtxt);
1647 }
1648 if (const auto *namedProperty =
1649 attrOrProp.dyn_cast<const NamedProperty *>()) {
1650 StringRef name = namedProperty->name;
1651 readPropertiesMethod << formatv(
1652 Fmt: R"(
1653 {{
1654 auto &propStorage = prop.{0};
1655 auto readProp = [&]() {
1656 {1};
1657 return ::mlir::success();
1658 };
1659 if (::mlir::failed(readProp()))
1660 return ::mlir::failure();
1661 }
1662)",
1663 Vals&: name,
1664 Vals: tgfmt(fmt: namedProperty->prop.getReadFromMlirBytecodeCall(), ctx: &fctx));
1665 writePropertiesMethod << formatv(
1666 Fmt: R"(
1667 {{
1668 auto &propStorage = prop.{0};
1669 {1};
1670 }
1671)",
1672 Vals&: name, Vals: tgfmt(fmt: namedProperty->prop.getWriteToMlirBytecodeCall(), ctx: &fctx));
1673 continue;
1674 }
1675 const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
1676 StringRef name = namedAttr->attrName;
1677 if (namedAttr->isRequired) {
1678 readPropertiesMethod << formatv(Fmt: R"(
1679 if (::mlir::failed(reader.readAttribute(prop.{0})))
1680 return ::mlir::failure();
1681)",
1682 Vals&: name);
1683 writePropertiesMethod
1684 << formatv(Fmt: " writer.writeAttribute(prop.{0});\n", Vals&: name);
1685 } else {
1686 readPropertiesMethod << formatv(Fmt: R"(
1687 if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
1688 return ::mlir::failure();
1689)",
1690 Vals&: name);
1691 writePropertiesMethod << formatv(Fmt: R"(
1692 writer.writeOptionalAttribute(prop.{0});
1693)",
1694 Vals&: name);
1695 }
1696 }
1697 readPropertiesMethod << " return ::mlir::success();";
1698}
1699
1700void OpEmitter::genAttrGetters() {
1701 FmtContext fctx;
1702 fctx.withBuilder(subst: "::mlir::Builder((*this)->getContext())");
1703
1704 // Emit the derived attribute body.
1705 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
1706 if (auto *method = opClass.addMethod(retType: attr.getReturnType(), name))
1707 method->body() << " " << attr.getDerivedCodeBody() << "\n";
1708 };
1709
1710 // Generate named accessor with Attribute return type. This is a wrapper
1711 // class that allows referring to the attributes via accessors instead of
1712 // having to use the string interface for better compile time verification.
1713 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
1714 Attribute attr) {
1715 // The method body for this getter is trivial. Emit it inline.
1716 auto *method =
1717 opClass.addInlineMethod(retType: attr.getStorageType(), name: name + "Attr");
1718 if (!method)
1719 return;
1720 method->body() << formatv(
1721 Fmt: " return ::llvm::{1}<{2}>({0});", Vals: emitHelper.getAttr(attrName),
1722 Vals: attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1723 : "cast",
1724 Vals: attr.getStorageType());
1725 };
1726
1727 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1728 std::string name = op.getGetterName(name: namedAttr.name);
1729 if (namedAttr.attr.isDerivedAttr()) {
1730 emitDerivedAttr(name, namedAttr.attr);
1731 } else {
1732 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1733 emitAttrGetterWithReturnType(fctx, opClass, op, name, attr: namedAttr.attr);
1734 }
1735 }
1736
1737 auto derivedAttrs = make_filter_range(Range: op.getAttributes(),
1738 Pred: [](const NamedAttribute &namedAttr) {
1739 return namedAttr.attr.isDerivedAttr();
1740 });
1741 if (derivedAttrs.empty())
1742 return;
1743
1744 opClass.addTrait(trait: "::mlir::DerivedAttributeOpInterface::Trait");
1745 // Generate helper method to query whether a named attribute is a derived
1746 // attribute. This enables, for example, avoiding adding an attribute that
1747 // overlaps with a derived attribute.
1748 {
1749 auto *method =
1750 opClass.addStaticMethod(retType: "bool", name: "isDerivedAttribute",
1751 args: MethodParameter("::llvm::StringRef", "name"));
1752 ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1753 auto &body = method->body();
1754 for (auto namedAttr : derivedAttrs)
1755 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
1756 body << " return false;";
1757 }
1758 // Generate method to materialize derived attributes as a DictionaryAttr.
1759 {
1760 auto *method = opClass.addMethod(retType: "::mlir::DictionaryAttr",
1761 name: "materializeDerivedAttributes");
1762 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1763 auto &body = method->body();
1764
1765 auto nonMaterializable =
1766 make_filter_range(Range&: derivedAttrs, Pred: [](const NamedAttribute &namedAttr) {
1767 return namedAttr.attr.getConvertFromStorageCall().empty();
1768 });
1769 if (!nonMaterializable.empty()) {
1770 std::string attrs;
1771 llvm::raw_string_ostream os(attrs);
1772 interleaveComma(c: nonMaterializable, os, each_fn: [&](const NamedAttribute &attr) {
1773 os << op.getGetterName(name: attr.name);
1774 });
1775 PrintWarning(
1776 WarningLoc: op.getLoc(),
1777 Msg: formatv(
1778 Fmt: "op has non-materializable derived attributes '{0}', skipping",
1779 Vals&: os.str()));
1780 body << formatv(Fmt: " emitOpError(\"op has non-materializable derived "
1781 "attributes '{0}'\");\n",
1782 Vals&: attrs);
1783 body << " return nullptr;";
1784 return;
1785 }
1786
1787 body << " ::mlir::MLIRContext* ctx = getContext();\n";
1788 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1789 body << " return ::mlir::DictionaryAttr::get(";
1790 body << " ctx, {\n";
1791 interleave(
1792 c: derivedAttrs, os&: body,
1793 each_fn: [&](const NamedAttribute &namedAttr) {
1794 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1795 std::string name = op.getGetterName(name: namedAttr.name);
1796 body << " {" << name << "AttrName(),\n"
1797 << tgfmt(fmt: tmpl, ctx: &fctx.withSelf(subst: name + "()")
1798 .withBuilder(subst: "odsBuilder")
1799 .addSubst(placeholder: "_ctxt", subst: "ctx")
1800 .addSubst(placeholder: "_storage", subst: "ctx"))
1801 << "}";
1802 },
1803 separator: ",\n");
1804 body << "});";
1805 }
1806}
1807
1808void OpEmitter::genAttrSetters() {
1809 bool useProperties = op.getDialect().usePropertiesForAttributes();
1810
1811 // Generate the code to set an attribute.
1812 auto emitSetAttr = [&](Method *method, StringRef getterName,
1813 StringRef attrName, StringRef attrVar) {
1814 if (useProperties) {
1815 method->body() << formatv(Fmt: " getProperties().{0} = {1};", Vals&: attrName,
1816 Vals&: attrVar);
1817 } else {
1818 method->body() << formatv(Fmt: " (*this)->setAttr({0}AttrName(), {1});",
1819 Vals&: getterName, Vals&: attrVar);
1820 }
1821 };
1822
1823 // Generate raw named setter type. This is a wrapper class that allows setting
1824 // to the attributes via setters instead of having to use the string interface
1825 // for better compile time verification.
1826 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1827 StringRef attrName, Attribute attr) {
1828 // This method body is trivial, so emit it inline.
1829 auto *method =
1830 opClass.addInlineMethod(retType: "void", name: setterName + "Attr",
1831 args: MethodParameter(attr.getStorageType(), "attr"));
1832 if (method)
1833 emitSetAttr(method, getterName, attrName, "attr");
1834 };
1835
1836 // Generate a setter that accepts the underlying C++ type as opposed to the
1837 // attribute type.
1838 auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
1839 StringRef attrName, Attribute attr) {
1840 Attribute baseAttr = attr.getBaseAttr();
1841 if (!canUseUnwrappedRawValue(attr: baseAttr))
1842 return;
1843 FmtContext fctx;
1844 fctx.withBuilder(subst: "::mlir::Builder((*this)->getContext())");
1845 bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
1846 bool isOptional = attr.isOptional();
1847
1848 auto createMethod = [&](const Twine &paramType) {
1849 return opClass.addMethod(retType: "void", name&: setterName,
1850 args: MethodParameter(paramType.str(), "attrValue"));
1851 };
1852
1853 // Build the method using the correct parameter type depending on
1854 // optionality.
1855 Method *method = nullptr;
1856 if (isUnitAttr)
1857 method = createMethod("bool");
1858 else if (isOptional)
1859 method =
1860 createMethod("::std::optional<" + baseAttr.getReturnType() + ">");
1861 else
1862 method = createMethod(attr.getReturnType());
1863 if (!method)
1864 return;
1865
1866 // If the value isn't optional, just set it directly.
1867 if (!isOptional) {
1868 emitSetAttr(method, getterName, attrName,
1869 constBuildAttrFromParam(attr, fctx, paramName: "attrValue"));
1870 return;
1871 }
1872
1873 // Otherwise, we only set if the provided value is valid. If it isn't, we
1874 // remove the attribute.
1875
1876 // TODO: Handle unit attr parameters specially, given that it is treated as
1877 // optional but not in the same way as the others (i.e. it uses bool over
1878 // std::optional<>).
1879 StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
1880 if (!useProperties) {
1881 const char *optionalCodeBody = R"(
1882 if (attrValue)
1883 return (*this)->setAttr({0}AttrName(), {1});
1884 (*this)->removeAttr({0}AttrName());)";
1885 method->body() << formatv(
1886 Fmt: optionalCodeBody, Vals&: getterName,
1887 Vals: constBuildAttrFromParam(attr: baseAttr, fctx, paramName: paramStr));
1888 } else {
1889 const char *optionalCodeBody = R"(
1890 auto &odsProp = getProperties().{0};
1891 if (attrValue)
1892 odsProp = {1};
1893 else
1894 odsProp = nullptr;)";
1895 method->body() << formatv(
1896 Fmt: optionalCodeBody, Vals&: attrName,
1897 Vals: constBuildAttrFromParam(attr: baseAttr, fctx, paramName: paramStr));
1898 }
1899 };
1900
1901 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1902 if (namedAttr.attr.isDerivedAttr())
1903 continue;
1904 std::string setterName = op.getSetterName(name: namedAttr.name);
1905 std::string getterName = op.getGetterName(name: namedAttr.name);
1906 emitAttrWithStorageType(setterName, getterName, namedAttr.name,
1907 namedAttr.attr);
1908 emitAttrWithReturnType(setterName, getterName, namedAttr.name,
1909 namedAttr.attr);
1910 }
1911}
1912
1913void OpEmitter::genOptionalAttrRemovers() {
1914 // Generate methods for removing optional attributes, instead of having to
1915 // use the string interface. Enables better compile time verification.
1916 auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
1917 auto upperInitial = name.take_front().upper();
1918 auto *method = opClass.addInlineMethod(retType: "::mlir::Attribute",
1919 name: op.getRemoverName(name) + "Attr");
1920 if (!method)
1921 return;
1922 if (useProperties) {
1923 method->body() << formatv(Fmt: R"(
1924 auto &attr = getProperties().{0};
1925 attr = {{};
1926 return attr;
1927)",
1928 Vals&: name);
1929 return;
1930 }
1931 method->body() << formatv(Fmt: "return (*this)->removeAttr({0}AttrName());",
1932 Vals: op.getGetterName(name));
1933 };
1934
1935 for (const NamedAttribute &namedAttr : op.getAttributes())
1936 if (namedAttr.attr.isOptional())
1937 emitRemoveAttr(namedAttr.name,
1938 op.getDialect().usePropertiesForAttributes());
1939}
1940
1941// Generates the code to compute the start and end index of an operand or result
1942// range.
1943template <typename RangeT>
1944static void generateValueRangeStartAndEnd(
1945 Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
1946 int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
1947 bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
1948
1949 SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
1950 if (isGenericAdaptorBase) {
1951 parameters.emplace_back(Args: "unsigned", Args: "odsOperandsSize");
1952 // The range size is passed per parameter for generic adaptor bases as
1953 // using the rangeSizeCall would require the operands, which are not
1954 // accessible in the base class.
1955 rangeSizeCall = "odsOperandsSize";
1956 }
1957
1958 // The method is trivial if the operation does not have any variadic operands.
1959 // In that case, make sure to generate it in-line.
1960 auto *method = opClass.addMethod(retType: "std::pair<unsigned, unsigned>", name&: methodName,
1961 properties: numVariadic == 0 ? Method::Properties::Inline
1962 : Method::Properties::None,
1963 args&: parameters);
1964 if (!method)
1965 return;
1966 auto &body = method->body();
1967 if (numVariadic == 0) {
1968 body << " return {index, 1};\n";
1969 } else if (hasAttrSegmentSize) {
1970 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
1971 } else {
1972 // Because the op can have arbitrarily interleaved variadic and non-variadic
1973 // operands, we need to embed a list in the "sink" getter method for
1974 // calculation at run-time.
1975 SmallVector<StringRef, 4> isVariadic;
1976 isVariadic.reserve(N: llvm::size(odsValues));
1977 for (auto &it : odsValues)
1978 isVariadic.push_back(Elt: it.isVariableLength() ? "true" : "false");
1979 std::string isVariadicList = llvm::join(R&: isVariadic, Separator: ", ");
1980 body << formatv(Fmt: sameVariadicSizeValueRangeCalcCode, Vals&: isVariadicList,
1981 Vals&: numNonVariadic, Vals&: numVariadic, Vals&: rangeSizeCall, Vals: "operand");
1982 }
1983}
1984
1985static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
1986 std::string str = "::mlir::Value";
1987 /// If the CPPClassName is not a fully qualified type. Uses of types
1988 /// across Dialect fail because they are not in the correct namespace. So we
1989 /// dont generate TypedValue unless the type is fully qualified.
1990 /// getCPPClassName doesn't return the fully qualified path for
1991 /// `mlir::pdl::OperationType` see
1992 /// https://github.com/llvm/llvm-project/issues/57279.
1993 /// Adaptor will have values that are not from the type of their operation and
1994 /// this is expected, so we dont generate TypedValue for Adaptor
1995 if (value.constraint.getCPPClassName() != "::mlir::Type" &&
1996 StringRef(value.constraint.getCPPClassName()).starts_with(Prefix: "::"))
1997 str = llvm::formatv(Fmt: "::mlir::TypedValue<{0}>",
1998 Vals: value.constraint.getCPPClassName())
1999 .str();
2000 return str;
2001}
2002
2003// Generates the named operand getter methods for the given Operator `op` and
2004// puts them in `opClass`. Uses `rangeType` as the return type of getters that
2005// return a range of operands (individual operands are `Value ` and each
2006// element in the range must also be `Value `); use `rangeBeginCall` to get
2007// an iterator to the beginning of the operand range; use `rangeSizeCall` to
2008// obtain the number of operands. `getOperandCallPattern` contains the code
2009// necessary to obtain a single operand whose position will be substituted
2010// instead of
2011// "{0}" marker in the pattern. Note that the pattern should work for any kind
2012// of ops, in particular for one-operand ops that may not have the
2013// `getOperand(unsigned)` method.
2014static void
2015generateNamedOperandGetters(const Operator &op, Class &opClass,
2016 Class *genericAdaptorBase, StringRef sizeAttrInit,
2017 StringRef rangeType, StringRef rangeElementType,
2018 StringRef rangeBeginCall, StringRef rangeSizeCall,
2019 StringRef getOperandCallPattern) {
2020 const int numOperands = op.getNumOperands();
2021 const int numVariadicOperands = op.getNumVariableLengthOperands();
2022 const int numNormalOperands = numOperands - numVariadicOperands;
2023
2024 const auto *sameVariadicSize =
2025 op.getTrait(trait: "::mlir::OpTrait::SameVariadicOperandSize");
2026 const auto *attrSizedOperands =
2027 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
2028
2029 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
2030 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op has multiple variadic operands but no "
2031 "specification over their sizes");
2032 }
2033
2034 if (numVariadicOperands < 2 && attrSizedOperands) {
2035 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op must have at least two variadic operands "
2036 "to use 'AttrSizedOperandSegments' trait");
2037 }
2038
2039 if (attrSizedOperands && sameVariadicSize) {
2040 PrintFatalError(ErrorLoc: op.getLoc(),
2041 Msg: "op cannot have both 'AttrSizedOperandSegments' and "
2042 "'SameVariadicOperandSize' traits");
2043 }
2044
2045 // First emit a few "sink" getter methods upon which we layer all nicer named
2046 // getter methods.
2047 // If generating for an adaptor, the method is put into the non-templated
2048 // generic base class, to not require being defined in the header.
2049 // Since the operand size can't be determined from the base class however,
2050 // it has to be passed as an additional argument. The trampoline below
2051 // generates the function with the same signature as the Op in the generic
2052 // adaptor.
2053 bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
2054 generateValueRangeStartAndEnd(
2055 /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
2056 isGenericAdaptorBase,
2057 /*methodName=*/"getODSOperandIndexAndLength", numVariadic: numVariadicOperands,
2058 numNonVariadic: numNormalOperands, rangeSizeCall, hasAttrSegmentSize: attrSizedOperands, sizeAttrInit,
2059 odsValues: const_cast<Operator &>(op).getOperands());
2060 if (isGenericAdaptorBase) {
2061 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
2062 // the index. This just calls the implementation in the base class but
2063 // passes the operand size as parameter.
2064 Method *method = opClass.addInlineMethod(
2065 retType: "std::pair<unsigned, unsigned>", name: "getODSOperandIndexAndLength",
2066 args: MethodParameter("unsigned", "index"));
2067 ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
2068 MethodBody &body = method->body();
2069 body.indent() << formatv(
2070 Fmt: "return Base::getODSOperandIndexAndLength(index, {0});", Vals&: rangeSizeCall);
2071 }
2072
2073 // The implementation of this method is trivial and it is very load-bearing.
2074 // Generate it inline.
2075 auto *m = opClass.addInlineMethod(retType&: rangeType, name: "getODSOperands",
2076 args: MethodParameter("unsigned", "index"));
2077 ERROR_IF_PRUNED(m, "getODSOperands", op);
2078 auto &body = m->body();
2079 body << formatv(Fmt: valueRangeReturnCode, Vals&: rangeBeginCall,
2080 Vals: "getODSOperandIndexAndLength(index)");
2081
2082 // Then we emit nicer named getter methods by redirecting to the "sink" getter
2083 // method.
2084 for (int i = 0; i != numOperands; ++i) {
2085 const auto &operand = op.getOperand(index: i);
2086 if (operand.name.empty())
2087 continue;
2088 std::string name = op.getGetterName(name: operand.name);
2089 if (operand.isOptional()) {
2090 m = opClass.addInlineMethod(retType: isGenericAdaptorBase
2091 ? rangeElementType
2092 : generateTypeForGetter(value: operand),
2093 name);
2094 ERROR_IF_PRUNED(m, name, op);
2095 m->body().indent() << formatv(Fmt: "auto operands = getODSOperands({0});\n"
2096 "return operands.empty() ? {1}{{} : ",
2097 Vals&: i, Vals: m->getReturnType());
2098 if (!isGenericAdaptorBase)
2099 m->body() << llvm::formatv(Fmt: "::llvm::cast<{0}>", Vals: m->getReturnType());
2100 m->body() << "(*operands.begin());";
2101 } else if (operand.isVariadicOfVariadic()) {
2102 std::string segmentAttr = op.getGetterName(
2103 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
2104 if (genericAdaptorBase) {
2105 m = opClass.addMethod(retType: "::llvm::SmallVector<" + rangeType + ">", name);
2106 ERROR_IF_PRUNED(m, name, op);
2107 m->body() << llvm::formatv(Fmt: variadicOfVariadicAdaptorCalcCode,
2108 Vals&: segmentAttr, Vals&: i, Vals&: rangeType);
2109 continue;
2110 }
2111
2112 m = opClass.addInlineMethod(retType: "::mlir::OperandRangeRange", name);
2113 ERROR_IF_PRUNED(m, name, op);
2114 m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
2115 << "Attr());";
2116 } else if (operand.isVariadic()) {
2117 m = opClass.addInlineMethod(retType&: rangeType, name);
2118 ERROR_IF_PRUNED(m, name, op);
2119 m->body() << " return getODSOperands(" << i << ");";
2120 } else {
2121 m = opClass.addInlineMethod(retType: isGenericAdaptorBase
2122 ? rangeElementType
2123 : generateTypeForGetter(value: operand),
2124 name);
2125 ERROR_IF_PRUNED(m, name, op);
2126 m->body().indent() << "return ";
2127 if (!isGenericAdaptorBase)
2128 m->body() << llvm::formatv(Fmt: "::llvm::cast<{0}>", Vals: m->getReturnType());
2129 m->body() << llvm::formatv(Fmt: "(*getODSOperands({0}).begin());", Vals&: i);
2130 }
2131 }
2132}
2133
2134void OpEmitter::genNamedOperandGetters() {
2135 // Build the code snippet used for initializing the operand_segment_size)s
2136 // array.
2137 std::string attrSizeInitCode;
2138 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
2139 if (op.getDialect().usePropertiesForAttributes())
2140 attrSizeInitCode = formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
2141 Vals: "getProperties().operandSegmentSizes");
2142
2143 else
2144 attrSizeInitCode = formatv(Fmt: opSegmentSizeAttrInitCode,
2145 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName));
2146 }
2147
2148 generateNamedOperandGetters(
2149 op, opClass,
2150 /*genericAdaptorBase=*/nullptr,
2151 /*sizeAttrInit=*/attrSizeInitCode,
2152 /*rangeType=*/"::mlir::Operation::operand_range",
2153 /*rangeElementType=*/"::mlir::Value",
2154 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2155 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2156 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2157}
2158
2159void OpEmitter::genNamedOperandSetters() {
2160 auto *attrSizedOperands =
2161 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
2162 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
2163 const auto &operand = op.getOperand(index: i);
2164 if (operand.name.empty())
2165 continue;
2166 std::string name = op.getGetterName(name: operand.name);
2167
2168 StringRef returnType;
2169 if (operand.isVariadicOfVariadic()) {
2170 returnType = "::mlir::MutableOperandRangeRange";
2171 } else if (operand.isVariableLength()) {
2172 returnType = "::mlir::MutableOperandRange";
2173 } else {
2174 returnType = "::mlir::OpOperand &";
2175 }
2176 bool isVariadicOperand =
2177 operand.isVariadicOfVariadic() || operand.isVariableLength();
2178 auto *m = opClass.addMethod(retType&: returnType, name: name + "Mutable",
2179 properties: isVariadicOperand ? Method::Properties::None
2180 : Method::Properties::Inline);
2181 ERROR_IF_PRUNED(m, name, op);
2182 auto &body = m->body();
2183 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
2184
2185 if (!isVariadicOperand) {
2186 // In case of a single operand, return a single OpOperand.
2187 body << " return getOperation()->getOpOperand(range.first);\n";
2188 continue;
2189 }
2190
2191 body << " auto mutableRange = "
2192 "::mlir::MutableOperandRange(getOperation(), "
2193 "range.first, range.second";
2194 if (attrSizedOperands) {
2195 if (emitHelper.hasProperties())
2196 body << formatv(Fmt: ", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2197 "{{getOperandSegmentSizesAttrName(), "
2198 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2199 "getProperties().operandSegmentSizes)})",
2200 Vals&: i);
2201 else
2202 body << formatv(
2203 Fmt: ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", Vals&: i,
2204 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName, /*isNamed=*/true));
2205 }
2206 body << ");\n";
2207
2208 // If this operand is a nested variadic, we split the range into a
2209 // MutableOperandRangeRange that provides a range over all of the
2210 // sub-ranges.
2211 if (operand.isVariadicOfVariadic()) {
2212 body << " return "
2213 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2214 << op.getGetterName(
2215 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2216 << "AttrName()));\n";
2217 } else {
2218 // Otherwise, we use the full range directly.
2219 body << " return mutableRange;\n";
2220 }
2221 }
2222}
2223
2224void OpEmitter::genNamedResultGetters() {
2225 const int numResults = op.getNumResults();
2226 const int numVariadicResults = op.getNumVariableLengthResults();
2227 const int numNormalResults = numResults - numVariadicResults;
2228
2229 // If we have more than one variadic results, we need more complicated logic
2230 // to calculate the value range for each result.
2231
2232 const auto *sameVariadicSize =
2233 op.getTrait(trait: "::mlir::OpTrait::SameVariadicResultSize");
2234 const auto *attrSizedResults =
2235 op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments");
2236
2237 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
2238 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op has multiple variadic results but no "
2239 "specification over their sizes");
2240 }
2241
2242 if (numVariadicResults < 2 && attrSizedResults) {
2243 PrintFatalError(ErrorLoc: op.getLoc(), Msg: "op must have at least two variadic results "
2244 "to use 'AttrSizedResultSegments' trait");
2245 }
2246
2247 if (attrSizedResults && sameVariadicSize) {
2248 PrintFatalError(ErrorLoc: op.getLoc(),
2249 Msg: "op cannot have both 'AttrSizedResultSegments' and "
2250 "'SameVariadicResultSize' traits");
2251 }
2252
2253 // Build the initializer string for the result segment size attribute.
2254 std::string attrSizeInitCode;
2255 if (attrSizedResults) {
2256 if (op.getDialect().usePropertiesForAttributes())
2257 attrSizeInitCode = formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
2258 Vals: "getProperties().resultSegmentSizes");
2259
2260 else
2261 attrSizeInitCode = formatv(Fmt: opSegmentSizeAttrInitCode,
2262 Vals: emitHelper.getAttr(attrName: resultSegmentAttrName));
2263 }
2264
2265 generateValueRangeStartAndEnd(
2266 opClass, /*isGenericAdaptorBase=*/false, methodName: "getODSResultIndexAndLength",
2267 numVariadic: numVariadicResults, numNonVariadic: numNormalResults, rangeSizeCall: "getOperation()->getNumResults()",
2268 hasAttrSegmentSize: attrSizedResults, sizeAttrInit: attrSizeInitCode, odsValues: op.getResults());
2269
2270 // The implementation of this method is trivial and it is very load-bearing.
2271 // Generate it inline.
2272 auto *m = opClass.addInlineMethod(retType: "::mlir::Operation::result_range",
2273 name: "getODSResults",
2274 args: MethodParameter("unsigned", "index"));
2275 ERROR_IF_PRUNED(m, "getODSResults", op);
2276 m->body() << formatv(Fmt: valueRangeReturnCode, Vals: "getOperation()->result_begin()",
2277 Vals: "getODSResultIndexAndLength(index)");
2278
2279 for (int i = 0; i != numResults; ++i) {
2280 const auto &result = op.getResult(index: i);
2281 if (result.name.empty())
2282 continue;
2283 std::string name = op.getGetterName(name: result.name);
2284 if (result.isOptional()) {
2285 m = opClass.addInlineMethod(retType: generateTypeForGetter(value: result), name);
2286 ERROR_IF_PRUNED(m, name, op);
2287 m->body() << " auto results = getODSResults(" << i << ");\n"
2288 << llvm::formatv(Fmt: " return results.empty()"
2289 " ? {0}()"
2290 " : ::llvm::cast<{0}>(*results.begin());",
2291 Vals: m->getReturnType());
2292 } else if (result.isVariadic()) {
2293 m = opClass.addInlineMethod(retType: "::mlir::Operation::result_range", name);
2294 ERROR_IF_PRUNED(m, name, op);
2295 m->body() << " return getODSResults(" << i << ");";
2296 } else {
2297 m = opClass.addInlineMethod(retType: generateTypeForGetter(value: result), name);
2298 ERROR_IF_PRUNED(m, name, op);
2299 m->body() << llvm::formatv(
2300 Fmt: " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2301 Vals: m->getReturnType(), Vals&: i);
2302 }
2303 }
2304}
2305
2306void OpEmitter::genNamedRegionGetters() {
2307 unsigned numRegions = op.getNumRegions();
2308 for (unsigned i = 0; i < numRegions; ++i) {
2309 const auto &region = op.getRegion(index: i);
2310 if (region.name.empty())
2311 continue;
2312 std::string name = op.getGetterName(name: region.name);
2313
2314 // Generate the accessors for a variadic region.
2315 if (region.isVariadic()) {
2316 auto *m = opClass.addInlineMethod(
2317 retType: "::mlir::MutableArrayRef<::mlir::Region>", name);
2318 ERROR_IF_PRUNED(m, name, op);
2319 m->body() << formatv(Fmt: " return (*this)->getRegions().drop_front({0});",
2320 Vals&: i);
2321 continue;
2322 }
2323
2324 auto *m = opClass.addInlineMethod(retType: "::mlir::Region &", name);
2325 ERROR_IF_PRUNED(m, name, op);
2326 m->body() << formatv(Fmt: " return (*this)->getRegion({0});", Vals&: i);
2327 }
2328}
2329
2330void OpEmitter::genNamedSuccessorGetters() {
2331 unsigned numSuccessors = op.getNumSuccessors();
2332 for (unsigned i = 0; i < numSuccessors; ++i) {
2333 const NamedSuccessor &successor = op.getSuccessor(index: i);
2334 if (successor.name.empty())
2335 continue;
2336 std::string name = op.getGetterName(name: successor.name);
2337 // Generate the accessors for a variadic successor list.
2338 if (successor.isVariadic()) {
2339 auto *m = opClass.addInlineMethod(retType: "::mlir::SuccessorRange", name);
2340 ERROR_IF_PRUNED(m, name, op);
2341 m->body() << formatv(
2342 Fmt: " return {std::next((*this)->successor_begin(), {0}), "
2343 "(*this)->successor_end()};",
2344 Vals&: i);
2345 continue;
2346 }
2347
2348 auto *m = opClass.addInlineMethod(retType: "::mlir::Block *", name);
2349 ERROR_IF_PRUNED(m, name, op);
2350 m->body() << formatv(Fmt: " return (*this)->getSuccessor({0});", Vals&: i);
2351 }
2352}
2353
2354static bool canGenerateUnwrappedBuilder(const Operator &op) {
2355 // If this op does not have native attributes at all, return directly to avoid
2356 // redefining builders.
2357 if (op.getNumNativeAttributes() == 0)
2358 return false;
2359
2360 bool canGenerate = false;
2361 // We are generating builders that take raw values for attributes. We need to
2362 // make sure the native attributes have a meaningful "unwrapped" value type
2363 // different from the wrapped mlir::Attribute type to avoid redefining
2364 // builders. This checks for the op has at least one such native attribute.
2365 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
2366 const NamedAttribute &namedAttr = op.getAttribute(index: i);
2367 if (canUseUnwrappedRawValue(attr: namedAttr.attr)) {
2368 canGenerate = true;
2369 break;
2370 }
2371 }
2372 return canGenerate;
2373}
2374
2375static bool canInferType(const Operator &op) {
2376 return op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait");
2377}
2378
2379void OpEmitter::genSeparateArgParamBuilder() {
2380 SmallVector<AttrParamKind, 2> attrBuilderType;
2381 attrBuilderType.push_back(Elt: AttrParamKind::WrappedAttr);
2382 if (canGenerateUnwrappedBuilder(op))
2383 attrBuilderType.push_back(Elt: AttrParamKind::UnwrappedValue);
2384
2385 // Emit with separate builders with or without unwrapped attributes and/or
2386 // inferring result type.
2387 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
2388 bool inferType) {
2389 SmallVector<MethodParameter> paramList;
2390 SmallVector<std::string, 4> resultNames;
2391 llvm::StringSet<> inferredAttributes;
2392 buildParamList(paramList, inferredAttributes, resultTypeNames&: resultNames, typeParamKind: paramKind,
2393 attrParamKind: attrType);
2394
2395 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2396 // If the builder is redundant, skip generating the method.
2397 if (!m)
2398 return;
2399 auto &body = m->body();
2400 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2401 /*isRawValueAttr=*/attrType ==
2402 AttrParamKind::UnwrappedValue);
2403
2404 // Push all result types to the operation state
2405
2406 if (inferType) {
2407 // Generate builder that infers type too.
2408 // TODO: Subsume this with general checking if type can be
2409 // inferred automatically.
2410 body << formatv(Fmt: R"(
2411 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2412 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2413 {1}.location, {1}.operands,
2414 {1}.attributes.getDictionary({1}.getContext()),
2415 {1}.getRawProperties(),
2416 {1}.regions, inferredReturnTypes)))
2417 {1}.addTypes(inferredReturnTypes);
2418 else
2419 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
2420 Vals: opClass.getClassName(), Vals: builderOpState);
2421 return;
2422 }
2423
2424 switch (paramKind) {
2425 case TypeParamKind::None:
2426 return;
2427 case TypeParamKind::Separate:
2428 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
2429 if (op.getResult(index: i).isOptional())
2430 body << " if (" << resultNames[i] << ")\n ";
2431 body << " " << builderOpState << ".addTypes(" << resultNames[i]
2432 << ");\n";
2433 }
2434
2435 // Automatically create the 'resultSegmentSizes' attribute using
2436 // the length of the type ranges.
2437 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedResultSegments")) {
2438 if (op.getDialect().usePropertiesForAttributes()) {
2439 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
2440 } else {
2441 std::string getterName = op.getGetterName(name: resultSegmentAttrName);
2442 body << " " << builderOpState << ".addAttribute(" << getterName
2443 << "AttrName(" << builderOpState << ".name), "
2444 << "odsBuilder.getDenseI32ArrayAttr({";
2445 }
2446 interleaveComma(
2447 c: llvm::seq<int>(Begin: 0, End: op.getNumResults()), os&: body, each_fn: [&](int i) {
2448 const NamedTypeConstraint &result = op.getResult(index: i);
2449 if (!result.isVariableLength()) {
2450 body << "1";
2451 } else if (result.isOptional()) {
2452 body << "(" << resultNames[i] << " ? 1 : 0)";
2453 } else {
2454 // VariadicOfVariadic of results are currently unsupported in
2455 // MLIR, hence it can only be a simple variadic.
2456 // TODO: Add implementation for VariadicOfVariadic results here
2457 // once supported.
2458 assert(result.isVariadic());
2459 body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
2460 }
2461 });
2462 if (op.getDialect().usePropertiesForAttributes()) {
2463 body << "}), " << builderOpState
2464 << ".getOrAddProperties<Properties>()."
2465 "resultSegmentSizes.begin());\n";
2466 } else {
2467 body << "}));\n";
2468 }
2469 }
2470
2471 return;
2472 case TypeParamKind::Collective: {
2473 int numResults = op.getNumResults();
2474 int numVariadicResults = op.getNumVariableLengthResults();
2475 int numNonVariadicResults = numResults - numVariadicResults;
2476 bool hasVariadicResult = numVariadicResults != 0;
2477
2478 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
2479 if (!hasVariadicResult || numNonVariadicResults != 0)
2480 body << " "
2481 << "assert(resultTypes.size() "
2482 << (hasVariadicResult ? ">=" : "==") << " "
2483 << numNonVariadicResults
2484 << "u && \"mismatched number of results\");\n";
2485 body << " " << builderOpState << ".addTypes(resultTypes);\n";
2486 }
2487 return;
2488 }
2489 llvm_unreachable("unhandled TypeParamKind");
2490 };
2491
2492 // Some of the build methods generated here may be ambiguous, but TableGen's
2493 // ambiguous function detection will elide those ones.
2494 for (auto attrType : attrBuilderType) {
2495 emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
2496 if (canInferType(op))
2497 emit(attrType, TypeParamKind::None, /*inferType=*/true);
2498 emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
2499 }
2500}
2501
2502void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
2503 int numResults = op.getNumResults();
2504
2505 // Signature
2506 SmallVector<MethodParameter> paramList;
2507 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2508 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2509 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2510 // Provide default value for `attributes` when its the last parameter
2511 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2512 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2513 Args: "attributes", Args&: attributesDefaultValue);
2514 if (op.getNumVariadicRegions())
2515 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
2516
2517 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2518 // If the builder is redundant, skip generating the method
2519 if (!m)
2520 return;
2521 auto &body = m->body();
2522
2523 // Operands
2524 body << " " << builderOpState << ".addOperands(operands);\n";
2525
2526 // Attributes
2527 body << " " << builderOpState << ".addAttributes(attributes);\n";
2528
2529 // Create the correct number of regions
2530 if (int numRegions = op.getNumRegions()) {
2531 body << llvm::formatv(
2532 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
2533 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2534 body << " (void)" << builderOpState << ".addRegion();\n";
2535 }
2536
2537 // Result types
2538 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
2539 body << " " << builderOpState << ".addTypes({"
2540 << llvm::join(R&: resultTypes, Separator: ", ") << "});\n\n";
2541}
2542
2543void OpEmitter::genPopulateDefaultAttributes() {
2544 // All done if no attributes, except optional ones, have default values.
2545 if (llvm::all_of(Range: op.getAttributes(), P: [](const NamedAttribute &named) {
2546 return !named.attr.hasDefaultValue() || named.attr.isOptional();
2547 }))
2548 return;
2549
2550 if (emitHelper.hasProperties()) {
2551 SmallVector<MethodParameter> paramList;
2552 paramList.emplace_back(Args: "::mlir::OperationName", Args: "opName");
2553 paramList.emplace_back(Args: "Properties &", Args: "properties");
2554 auto *m =
2555 opClass.addStaticMethod(retType: "void", name: "populateDefaultProperties", args&: paramList);
2556 ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
2557 auto &body = m->body();
2558 body.indent();
2559 body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
2560 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2561 auto &attr = namedAttr.attr;
2562 if (!attr.hasDefaultValue() || attr.isOptional())
2563 continue;
2564 StringRef name = namedAttr.name;
2565 FmtContext fctx;
2566 fctx.withBuilder(subst: odsBuilder);
2567 body << "if (!properties." << name << ")\n"
2568 << " properties." << name << " = "
2569 << std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
2570 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)))
2571 << ";\n";
2572 }
2573 return;
2574 }
2575
2576 SmallVector<MethodParameter> paramList;
2577 paramList.emplace_back(Args: "const ::mlir::OperationName &", Args: "opName");
2578 paramList.emplace_back(Args: "::mlir::NamedAttrList &", Args: "attributes");
2579 auto *m = opClass.addStaticMethod(retType: "void", name: "populateDefaultAttrs", args&: paramList);
2580 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
2581 auto &body = m->body();
2582 body.indent();
2583
2584 // Set default attributes that are unset.
2585 body << "auto attrNames = opName.getAttributeNames();\n";
2586 body << "::mlir::Builder " << odsBuilder
2587 << "(attrNames.front().getContext());\n";
2588 StringMap<int> attrIndex;
2589 for (const auto &it : llvm::enumerate(First: emitHelper.getAttrMetadata())) {
2590 attrIndex[it.value().first] = it.index();
2591 }
2592 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2593 auto &attr = namedAttr.attr;
2594 if (!attr.hasDefaultValue() || attr.isOptional())
2595 continue;
2596 auto index = attrIndex[namedAttr.name];
2597 body << "if (!attributes.get(attrNames[" << index << "])) {\n";
2598 FmtContext fctx;
2599 fctx.withBuilder(subst: odsBuilder);
2600
2601 std::string defaultValue =
2602 std::string(tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx,
2603 vals: tgfmt(fmt: attr.getDefaultValue(), ctx: &fctx)));
2604 body.indent() << formatv(Fmt: "attributes.append(attrNames[{0}], {1});\n", Vals&: index,
2605 Vals&: defaultValue);
2606 body.unindent() << "}\n";
2607 }
2608}
2609
2610void OpEmitter::genInferredTypeCollectiveParamBuilder() {
2611 SmallVector<MethodParameter> paramList;
2612 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2613 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2614 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2615 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2616 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2617 Args: "attributes", Args&: attributesDefaultValue);
2618 if (op.getNumVariadicRegions())
2619 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
2620
2621 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2622 // If the builder is redundant, skip generating the method
2623 if (!m)
2624 return;
2625 auto &body = m->body();
2626
2627 int numResults = op.getNumResults();
2628 int numVariadicResults = op.getNumVariableLengthResults();
2629 int numNonVariadicResults = numResults - numVariadicResults;
2630
2631 int numOperands = op.getNumOperands();
2632 int numVariadicOperands = op.getNumVariableLengthOperands();
2633 int numNonVariadicOperands = numOperands - numVariadicOperands;
2634
2635 // Operands
2636 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2637 body << " assert(operands.size()"
2638 << (numVariadicOperands != 0 ? " >= " : " == ")
2639 << numNonVariadicOperands
2640 << "u && \"mismatched number of parameters\");\n";
2641 body << " " << builderOpState << ".addOperands(operands);\n";
2642 body << " " << builderOpState << ".addAttributes(attributes);\n";
2643
2644 // Create the correct number of regions
2645 if (int numRegions = op.getNumRegions()) {
2646 body << llvm::formatv(
2647 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
2648 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2649 body << " (void)" << builderOpState << ".addRegion();\n";
2650 }
2651
2652 // Result types
2653 body << formatv(Fmt: R"(
2654 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2655 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2656 {1}.location, operands,
2657 {1}.attributes.getDictionary({1}.getContext()),
2658 {1}.getRawProperties(),
2659 {1}.regions, inferredReturnTypes))) {{)",
2660 Vals: opClass.getClassName(), Vals: builderOpState);
2661 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2662 body << "\n assert(inferredReturnTypes.size()"
2663 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2664 << "u && \"mismatched number of return types\");";
2665 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
2666
2667 body << formatv(Fmt: R"(
2668 } else {{
2669 ::llvm::report_fatal_error("Failed to infer result type(s).");
2670 })",
2671 Vals: opClass.getClassName(), Vals: builderOpState);
2672}
2673
2674void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2675 auto emit = [&](AttrParamKind attrType) {
2676 SmallVector<MethodParameter> paramList;
2677 SmallVector<std::string, 4> resultNames;
2678 llvm::StringSet<> inferredAttributes;
2679 buildParamList(paramList, inferredAttributes, resultTypeNames&: resultNames,
2680 typeParamKind: TypeParamKind::None, attrParamKind: attrType);
2681
2682 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2683 // If the builder is redundant, skip generating the method
2684 if (!m)
2685 return;
2686 auto &body = m->body();
2687 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2688 /*isRawValueAttr=*/attrType ==
2689 AttrParamKind::UnwrappedValue);
2690
2691 auto numResults = op.getNumResults();
2692 if (numResults == 0)
2693 return;
2694
2695 // Push all result types to the operation state
2696 const char *index = op.getOperand(index: 0).isVariadic() ? ".front()" : "";
2697 std::string resultType =
2698 formatv(Fmt: "{0}{1}.getType()", Vals: getArgumentName(op, index: 0), Vals&: index).str();
2699 body << " " << builderOpState << ".addTypes({" << resultType;
2700 for (int i = 1; i != numResults; ++i)
2701 body << ", " << resultType;
2702 body << "});\n\n";
2703 };
2704
2705 emit(AttrParamKind::WrappedAttr);
2706 // Generate additional builder(s) if any attributes can be "unwrapped"
2707 if (canGenerateUnwrappedBuilder(op))
2708 emit(AttrParamKind::UnwrappedValue);
2709}
2710
2711void OpEmitter::genUseAttrAsResultTypeBuilder() {
2712 SmallVector<MethodParameter> paramList;
2713 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "odsBuilder");
2714 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2715 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2716 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2717 Args: "attributes", Args: "{}");
2718 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2719 // If the builder is redundant, skip generating the method
2720 if (!m)
2721 return;
2722
2723 auto &body = m->body();
2724
2725 // Push all result types to the operation state
2726 std::string resultType;
2727 const auto &namedAttr = op.getAttribute(index: 0);
2728
2729 body << " auto attrName = " << op.getGetterName(name: namedAttr.name)
2730 << "AttrName(" << builderOpState
2731 << ".name);\n"
2732 " for (auto attr : attributes) {\n"
2733 " if (attr.getName() != attrName) continue;\n";
2734 if (namedAttr.attr.isTypeAttr()) {
2735 resultType = "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()";
2736 } else {
2737 resultType = "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()";
2738 }
2739
2740 // Operands
2741 body << " " << builderOpState << ".addOperands(operands);\n";
2742
2743 // Attributes
2744 body << " " << builderOpState << ".addAttributes(attributes);\n";
2745
2746 // Result types
2747 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
2748 body << " " << builderOpState << ".addTypes({"
2749 << llvm::join(R&: resultTypes, Separator: ", ") << "});\n";
2750 body << " }\n";
2751}
2752
2753/// Returns a signature of the builder. Updates the context `fctx` to enable
2754/// replacement of $_builder and $_state in the body.
2755static SmallVector<MethodParameter>
2756getBuilderSignature(const Builder &builder) {
2757 ArrayRef<Builder::Parameter> params(builder.getParameters());
2758
2759 // Inject builder and state arguments.
2760 SmallVector<MethodParameter> arguments;
2761 arguments.reserve(N: params.size() + 2);
2762 arguments.emplace_back(Args: "::mlir::OpBuilder &", Args: odsBuilder);
2763 arguments.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2764
2765 for (unsigned i = 0, e = params.size(); i < e; ++i) {
2766 // If no name is provided, generate one.
2767 std::optional<StringRef> paramName = params[i].getName();
2768 std::string name =
2769 paramName ? paramName->str() : "odsArg" + std::to_string(val: i);
2770
2771 StringRef defaultValue;
2772 if (std::optional<StringRef> defaultParamValue =
2773 params[i].getDefaultValue())
2774 defaultValue = *defaultParamValue;
2775
2776 arguments.emplace_back(Args: params[i].getCppType(), Args: std::move(name),
2777 Args&: defaultValue);
2778 }
2779
2780 return arguments;
2781}
2782
2783void OpEmitter::genBuilder() {
2784 // Handle custom builders if provided.
2785 for (const Builder &builder : op.getBuilders()) {
2786 SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
2787
2788 std::optional<StringRef> body = builder.getBody();
2789 auto properties = body ? Method::Static : Method::StaticDeclaration;
2790 auto *method =
2791 opClass.addMethod(retType: "void", name: "build", properties, args: std::move(arguments));
2792 if (body)
2793 ERROR_IF_PRUNED(method, "build", op);
2794
2795 if (method)
2796 method->setDeprecated(builder.getDeprecatedMessage());
2797
2798 FmtContext fctx;
2799 fctx.withBuilder(subst: odsBuilder);
2800 fctx.addSubst(placeholder: "_state", subst: builderOpState);
2801 if (body)
2802 method->body() << tgfmt(fmt: *body, ctx: &fctx);
2803 }
2804
2805 // Generate default builders that requires all result type, operands, and
2806 // attributes as parameters.
2807 if (op.skipDefaultBuilders())
2808 return;
2809
2810 // We generate three classes of builders here:
2811 // 1. one having a stand-alone parameter for each operand / attribute, and
2812 genSeparateArgParamBuilder();
2813 // 2. one having an aggregated parameter for all result types / operands /
2814 // attributes, and
2815 genCollectiveParamBuilder();
2816 // 3. one having a stand-alone parameter for each operand and attribute,
2817 // use the first operand or attribute's type as all result types
2818 // to facilitate different call patterns.
2819 if (op.getNumVariableLengthResults() == 0) {
2820 if (op.getTrait(trait: "::mlir::OpTrait::SameOperandsAndResultType")) {
2821 genUseOperandAsResultTypeSeparateParamBuilder();
2822 genUseOperandAsResultTypeCollectiveParamBuilder();
2823 }
2824 if (op.getTrait(trait: "::mlir::OpTrait::FirstAttrDerivedResultType"))
2825 genUseAttrAsResultTypeBuilder();
2826 }
2827}
2828
2829void OpEmitter::genCollectiveParamBuilder() {
2830 int numResults = op.getNumResults();
2831 int numVariadicResults = op.getNumVariableLengthResults();
2832 int numNonVariadicResults = numResults - numVariadicResults;
2833
2834 int numOperands = op.getNumOperands();
2835 int numVariadicOperands = op.getNumVariableLengthOperands();
2836 int numNonVariadicOperands = numOperands - numVariadicOperands;
2837
2838 SmallVector<MethodParameter> paramList;
2839 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: "");
2840 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2841 paramList.emplace_back(Args: "::mlir::TypeRange", Args: "resultTypes");
2842 paramList.emplace_back(Args: "::mlir::ValueRange", Args: "operands");
2843 // Provide default value for `attributes` when its the last parameter
2844 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2845 paramList.emplace_back(Args: "::llvm::ArrayRef<::mlir::NamedAttribute>",
2846 Args: "attributes", Args&: attributesDefaultValue);
2847 if (op.getNumVariadicRegions())
2848 paramList.emplace_back(Args: "unsigned", Args: "numRegions");
2849
2850 auto *m = opClass.addStaticMethod(retType: "void", name: "build", args: std::move(paramList));
2851 // If the builder is redundant, skip generating the method
2852 if (!m)
2853 return;
2854 auto &body = m->body();
2855
2856 // Operands
2857 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2858 body << " assert(operands.size()"
2859 << (numVariadicOperands != 0 ? " >= " : " == ")
2860 << numNonVariadicOperands
2861 << "u && \"mismatched number of parameters\");\n";
2862 body << " " << builderOpState << ".addOperands(operands);\n";
2863
2864 // Attributes
2865 body << " " << builderOpState << ".addAttributes(attributes);\n";
2866
2867 // Create the correct number of regions
2868 if (int numRegions = op.getNumRegions()) {
2869 body << llvm::formatv(
2870 Fmt: " for (unsigned i = 0; i != {0}; ++i)\n",
2871 Vals: (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2872 body << " (void)" << builderOpState << ".addRegion();\n";
2873 }
2874
2875 // Result types
2876 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2877 body << " assert(resultTypes.size()"
2878 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2879 << "u && \"mismatched number of return types\");\n";
2880 body << " " << builderOpState << ".addTypes(resultTypes);\n";
2881
2882 // Generate builder that infers type too.
2883 // TODO: Expand to handle successors.
2884 if (canInferType(op) && op.getNumSuccessors() == 0)
2885 genInferredTypeCollectiveParamBuilder();
2886}
2887
2888void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
2889 llvm::StringSet<> &inferredAttributes,
2890 SmallVectorImpl<std::string> &resultTypeNames,
2891 TypeParamKind typeParamKind,
2892 AttrParamKind attrParamKind) {
2893 resultTypeNames.clear();
2894 auto numResults = op.getNumResults();
2895 resultTypeNames.reserve(N: numResults);
2896
2897 paramList.emplace_back(Args: "::mlir::OpBuilder &", Args: odsBuilder);
2898 paramList.emplace_back(Args: "::mlir::OperationState &", Args: builderOpState);
2899
2900 switch (typeParamKind) {
2901 case TypeParamKind::None:
2902 break;
2903 case TypeParamKind::Separate: {
2904 // Add parameters for all return types
2905 for (int i = 0; i < numResults; ++i) {
2906 const auto &result = op.getResult(index: i);
2907 std::string resultName = std::string(result.name);
2908 if (resultName.empty())
2909 resultName = std::string(formatv(Fmt: "resultType{0}", Vals&: i));
2910
2911 StringRef type =
2912 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
2913
2914 paramList.emplace_back(Args&: type, Args&: resultName, Args: result.isOptional());
2915 resultTypeNames.emplace_back(Args: std::move(resultName));
2916 }
2917 } break;
2918 case TypeParamKind::Collective: {
2919 paramList.emplace_back(Args: "::mlir::TypeRange", Args: "resultTypes");
2920 resultTypeNames.push_back(Elt: "resultTypes");
2921 } break;
2922 }
2923
2924 // Add parameters for all arguments (operands and attributes).
2925 int defaultValuedAttrStartIndex = op.getNumArgs();
2926 // Successors and variadic regions go at the end of the parameter list, so no
2927 // default arguments are possible.
2928 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
2929 if (!hasTrailingParams) {
2930 // Calculate the start index from which we can attach default values in the
2931 // builder declaration.
2932 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
2933 auto *namedAttr =
2934 llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(Val: op.getArg(index: i));
2935 if (!namedAttr)
2936 break;
2937
2938 Attribute attr = namedAttr->attr;
2939 // TODO: Currently we can't differentiate between optional meaning do not
2940 // verify/not always error if missing or optional meaning need not be
2941 // specified in builder. Expand isOptional once we can differentiate.
2942 if (!attr.hasDefaultValue() && !attr.isDerivedAttr())
2943 break;
2944
2945 // Creating an APInt requires us to provide bitwidth, value, and
2946 // signedness, which is complicated compared to others. Similarly
2947 // for APFloat.
2948 // TODO: Adjust the 'returnType' field of such attributes
2949 // to support them.
2950 StringRef retType = namedAttr->attr.getReturnType();
2951 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
2952 break;
2953
2954 defaultValuedAttrStartIndex = i;
2955 }
2956 }
2957 // Avoid generating build methods that are ambiguous due to default values by
2958 // requiring at least one attribute.
2959 if (defaultValuedAttrStartIndex < op.getNumArgs()) {
2960 // TODO: This should have been possible as a cast<NamedAttribute> but
2961 // required template instantiations is not yet defined for the tblgen helper
2962 // classes.
2963 auto *namedAttr =
2964 cast<NamedAttribute *>(Val: op.getArg(index: defaultValuedAttrStartIndex));
2965 Attribute attr = namedAttr->attr;
2966 if ((attrParamKind == AttrParamKind::WrappedAttr &&
2967 canUseUnwrappedRawValue(attr)) ||
2968 (attrParamKind == AttrParamKind::UnwrappedValue &&
2969 !canUseUnwrappedRawValue(attr)))
2970 ++defaultValuedAttrStartIndex;
2971 }
2972
2973 /// Collect any inferred attributes.
2974 for (const NamedTypeConstraint &operand : op.getOperands()) {
2975 if (operand.isVariadicOfVariadic()) {
2976 inferredAttributes.insert(
2977 key: operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
2978 }
2979 }
2980
2981 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
2982 Argument arg = op.getArg(index: i);
2983 if (const auto *operand =
2984 llvm::dyn_cast_if_present<NamedTypeConstraint *>(Val&: arg)) {
2985 StringRef type;
2986 if (operand->isVariadicOfVariadic())
2987 type = "::llvm::ArrayRef<::mlir::ValueRange>";
2988 else if (operand->isVariadic())
2989 type = "::mlir::ValueRange";
2990 else
2991 type = "::mlir::Value";
2992
2993 paramList.emplace_back(Args&: type, Args: getArgumentName(op, index: numOperands++),
2994 Args: operand->isOptional());
2995 continue;
2996 }
2997 if (llvm::isa_and_present<NamedProperty *>(Val: arg)) {
2998 // TODO
2999 continue;
3000 }
3001 const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
3002 const Attribute &attr = namedAttr.attr;
3003
3004 // Inferred attributes don't need to be added to the param list.
3005 if (inferredAttributes.contains(key: namedAttr.name))
3006 continue;
3007
3008 StringRef type;
3009 switch (attrParamKind) {
3010 case AttrParamKind::WrappedAttr:
3011 type = attr.getStorageType();
3012 break;
3013 case AttrParamKind::UnwrappedValue:
3014 if (canUseUnwrappedRawValue(attr))
3015 type = attr.getReturnType();
3016 else
3017 type = attr.getStorageType();
3018 break;
3019 }
3020
3021 // Attach default value if requested and possible.
3022 std::string defaultValue;
3023 if (i >= defaultValuedAttrStartIndex) {
3024 if (attrParamKind == AttrParamKind::UnwrappedValue &&
3025 canUseUnwrappedRawValue(attr))
3026 defaultValue += attr.getDefaultValue();
3027 else
3028 defaultValue += "nullptr";
3029 }
3030 paramList.emplace_back(Args&: type, Args: namedAttr.name, Args: StringRef(defaultValue),
3031 Args: attr.isOptional());
3032 }
3033
3034 /// Insert parameters for each successor.
3035 for (const NamedSuccessor &succ : op.getSuccessors()) {
3036 StringRef type =
3037 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
3038 paramList.emplace_back(Args&: type, Args: succ.name);
3039 }
3040
3041 /// Insert parameters for variadic regions.
3042 for (const NamedRegion &region : op.getRegions())
3043 if (region.isVariadic())
3044 paramList.emplace_back(Args: "unsigned",
3045 Args: llvm::formatv(Fmt: "{0}Count", Vals: region.name).str());
3046}
3047
3048void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
3049 MethodBody &body, llvm::StringSet<> &inferredAttributes,
3050 bool isRawValueAttr) {
3051 // Push all operands to the result.
3052 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
3053 std::string argName = getArgumentName(op, index: i);
3054 const NamedTypeConstraint &operand = op.getOperand(index: i);
3055 if (operand.constraint.isVariadicOfVariadic()) {
3056 body << " for (::mlir::ValueRange range : " << argName << ")\n "
3057 << builderOpState << ".addOperands(range);\n";
3058
3059 // Add the segment attribute.
3060 body << " {\n"
3061 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
3062 << " for (::mlir::ValueRange range : " << argName << ")\n"
3063 << " rangeSegments.push_back(range.size());\n"
3064 << " auto rangeAttr = " << odsBuilder
3065 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3066 if (op.getDialect().usePropertiesForAttributes()) {
3067 body << " " << builderOpState << ".getOrAddProperties<Properties>()."
3068 << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
3069 << " = rangeAttr;";
3070 } else {
3071 body << " " << builderOpState << ".addAttribute("
3072 << op.getGetterName(
3073 name: operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
3074 << "AttrName(" << builderOpState << ".name), rangeAttr);";
3075 }
3076 body << " }\n";
3077 continue;
3078 }
3079
3080 if (operand.isOptional())
3081 body << " if (" << argName << ")\n ";
3082 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
3083 }
3084
3085 // If the operation has the operand segment size attribute, add it here.
3086 auto emitSegment = [&]() {
3087 interleaveComma(c: llvm::seq<int>(Begin: 0, End: op.getNumOperands()), os&: body, each_fn: [&](int i) {
3088 const NamedTypeConstraint &operand = op.getOperand(index: i);
3089 if (!operand.isVariableLength()) {
3090 body << "1";
3091 return;
3092 }
3093
3094 std::string operandName = getArgumentName(op, index: i);
3095 if (operand.isOptional()) {
3096 body << "(" << operandName << " ? 1 : 0)";
3097 } else if (operand.isVariadicOfVariadic()) {
3098 body << llvm::formatv(
3099 Fmt: "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3100 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3101 "static_cast<int32_t>(range.size()); }))",
3102 Vals&: operandName);
3103 } else {
3104 body << "static_cast<int32_t>(" << getArgumentName(op, index: i) << ".size())";
3105 }
3106 });
3107 };
3108 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
3109 std::string sizes = op.getGetterName(name: operandSegmentAttrName);
3110 if (op.getDialect().usePropertiesForAttributes()) {
3111 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3112 emitSegment();
3113 body << "}), " << builderOpState
3114 << ".getOrAddProperties<Properties>()."
3115 "operandSegmentSizes.begin());\n";
3116 } else {
3117 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
3118 << builderOpState << ".name), "
3119 << "odsBuilder.getDenseI32ArrayAttr({";
3120 emitSegment();
3121 body << "}));\n";
3122 }
3123 }
3124
3125 // Push all attributes to the result.
3126 for (const auto &namedAttr : op.getAttributes()) {
3127 auto &attr = namedAttr.attr;
3128 if (attr.isDerivedAttr() || inferredAttributes.contains(key: namedAttr.name))
3129 continue;
3130
3131 // TODO: The wrapping of optional is different for default or not, so don't
3132 // unwrap for default ones that would fail below.
3133 bool emitNotNullCheck =
3134 (attr.isOptional() && !attr.hasDefaultValue()) ||
3135 (attr.hasDefaultValue() && !isRawValueAttr) ||
3136 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3137 // the constant materialization is only for true case.
3138 (isRawValueAttr && attr.getAttrDefName() == "UnitAttr");
3139 if (emitNotNullCheck)
3140 body.indent() << formatv(Fmt: "if ({0}) ", Vals: namedAttr.name) << "{\n";
3141
3142 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
3143 // If this is a raw value, then we need to wrap it in an Attribute
3144 // instance.
3145 FmtContext fctx;
3146 fctx.withBuilder(subst: "odsBuilder");
3147 if (op.getDialect().usePropertiesForAttributes()) {
3148 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3149 Vals: builderOpState, Vals: namedAttr.name,
3150 Vals: constBuildAttrFromParam(attr, fctx, paramName: namedAttr.name));
3151 } else {
3152 body << formatv(Fmt: " {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3153 Vals: builderOpState, Vals: op.getGetterName(name: namedAttr.name),
3154 Vals: constBuildAttrFromParam(attr, fctx, paramName: namedAttr.name));
3155 }
3156 } else {
3157 if (op.getDialect().usePropertiesForAttributes()) {
3158 body << formatv(Fmt: " {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3159 Vals: builderOpState, Vals: namedAttr.name);
3160 } else {
3161 body << formatv(Fmt: " {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3162 Vals: builderOpState, Vals: op.getGetterName(name: namedAttr.name),
3163 Vals: namedAttr.name);
3164 }
3165 }
3166 if (emitNotNullCheck)
3167 body.unindent() << " }\n";
3168 }
3169
3170 // Create the correct number of regions.
3171 for (const NamedRegion &region : op.getRegions()) {
3172 if (region.isVariadic())
3173 body << formatv(Fmt: " for (unsigned i = 0; i < {0}Count; ++i)\n ",
3174 Vals: region.name);
3175
3176 body << " (void)" << builderOpState << ".addRegion();\n";
3177 }
3178
3179 // Push all successors to the result.
3180 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
3181 body << formatv(Fmt: " {0}.addSuccessors({1});\n", Vals: builderOpState,
3182 Vals: namedSuccessor.name);
3183 }
3184}
3185
3186void OpEmitter::genCanonicalizerDecls() {
3187 bool hasCanonicalizeMethod = def.getValueAsBit(FieldName: "hasCanonicalizeMethod");
3188 if (hasCanonicalizeMethod) {
3189 // static LogicResult FooOp::
3190 // canonicalize(FooOp op, PatternRewriter &rewriter);
3191 SmallVector<MethodParameter> paramList;
3192 paramList.emplace_back(Args: op.getCppClassName(), Args: "op");
3193 paramList.emplace_back(Args: "::mlir::PatternRewriter &", Args: "rewriter");
3194 auto *m = opClass.declareStaticMethod(retType: "::mlir::LogicalResult",
3195 name: "canonicalize", args: std::move(paramList));
3196 ERROR_IF_PRUNED(m, "canonicalize", op);
3197 }
3198
3199 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3200 // or if using a 'canonicalize' method.
3201 bool hasCanonicalizer = def.getValueAsBit(FieldName: "hasCanonicalizer");
3202 if (!hasCanonicalizeMethod && !hasCanonicalizer)
3203 return;
3204
3205 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3206 // method, but not implementing 'getCanonicalizationPatterns' manually.
3207 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
3208
3209 // Add a signature for getCanonicalizationPatterns if implemented by the
3210 // dialect or if synthesized to call 'canonicalize'.
3211 SmallVector<MethodParameter> paramList;
3212 paramList.emplace_back(Args: "::mlir::RewritePatternSet &", Args: "results");
3213 paramList.emplace_back(Args: "::mlir::MLIRContext *", Args: "context");
3214 auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
3215 auto *method = opClass.addMethod(retType: "void", name: "getCanonicalizationPatterns", properties: kind,
3216 args: std::move(paramList));
3217
3218 // If synthesizing the method, fill it.
3219 if (hasBody) {
3220 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
3221 method->body() << " results.add(canonicalize);\n";
3222 }
3223}
3224
3225void OpEmitter::genFolderDecls() {
3226 if (!op.hasFolder())
3227 return;
3228
3229 SmallVector<MethodParameter> paramList;
3230 paramList.emplace_back(Args: "FoldAdaptor", Args: "adaptor");
3231
3232 StringRef retType;
3233 bool hasSingleResult =
3234 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
3235 if (hasSingleResult) {
3236 retType = "::mlir::OpFoldResult";
3237 } else {
3238 paramList.emplace_back(Args: "::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3239 Args: "results");
3240 retType = "::mlir::LogicalResult";
3241 }
3242
3243 auto *m = opClass.declareMethod(retType, name: "fold", args: std::move(paramList));
3244 ERROR_IF_PRUNED(m, "fold", op);
3245}
3246
3247void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
3248 Interface interface = opTrait->getInterface();
3249
3250 // Get the set of methods that should always be declared.
3251 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
3252 llvm::StringSet<> alwaysDeclaredMethods;
3253 alwaysDeclaredMethods.insert(begin: alwaysDeclaredMethodsVec.begin(),
3254 end: alwaysDeclaredMethodsVec.end());
3255
3256 for (const InterfaceMethod &method : interface.getMethods()) {
3257 // Don't declare if the method has a body.
3258 if (method.getBody())
3259 continue;
3260 // Don't declare if the method has a default implementation and the op
3261 // didn't request that it always be declared.
3262 if (method.getDefaultImplementation() &&
3263 !alwaysDeclaredMethods.count(Key: method.getName()))
3264 continue;
3265 // Interface methods are allowed to overlap with existing methods, so don't
3266 // check if pruned.
3267 (void)genOpInterfaceMethod(method);
3268 }
3269}
3270
3271Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
3272 bool declaration) {
3273 SmallVector<MethodParameter> paramList;
3274 for (const InterfaceMethod::Argument &arg : method.getArguments())
3275 paramList.emplace_back(Args: arg.type, Args: arg.name);
3276
3277 auto props = (method.isStatic() ? Method::Static : Method::None) |
3278 (declaration ? Method::Declaration : Method::None);
3279 return opClass.addMethod(retType: method.getReturnType(), name: method.getName(), properties: props,
3280 args: std::move(paramList));
3281}
3282
3283void OpEmitter::genOpInterfaceMethods() {
3284 for (const auto &trait : op.getTraits()) {
3285 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(Val: &trait))
3286 if (opTrait->shouldDeclareMethods())
3287 genOpInterfaceMethods(opTrait);
3288 }
3289}
3290
3291void OpEmitter::genSideEffectInterfaceMethods() {
3292 enum EffectKind { Operand, Result, Symbol, Static };
3293 struct EffectLocation {
3294 /// The effect applied.
3295 SideEffect effect;
3296
3297 /// The index if the kind is not static.
3298 unsigned index;
3299
3300 /// The kind of the location.
3301 unsigned kind;
3302 };
3303
3304 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
3305 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
3306 unsigned index, unsigned kind) {
3307 for (auto decorator : decorators)
3308 if (SideEffect *effect = dyn_cast<SideEffect>(Val: &decorator)) {
3309 opClass.addTrait(trait: effect->getInterfaceTrait());
3310 interfaceEffects[effect->getBaseEffectName()].push_back(
3311 Elt: EffectLocation{.effect: *effect, .index: index, .kind: kind});
3312 }
3313 };
3314
3315 // Collect effects that were specified via:
3316 /// Traits.
3317 for (const auto &trait : op.getTraits()) {
3318 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(Val: &trait);
3319 if (!opTrait)
3320 continue;
3321 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
3322 for (auto decorator : opTrait->getEffects())
3323 effects.push_back(Elt: EffectLocation{.effect: cast<SideEffect>(Val&: decorator),
3324 /*index=*/0, .kind: EffectKind::Static});
3325 }
3326 /// Attributes and Operands.
3327 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
3328 Argument arg = op.getArg(index: i);
3329 if (arg.is<NamedTypeConstraint *>()) {
3330 resolveDecorators(op.getArgDecorators(index: i), operandIt, EffectKind::Operand);
3331 ++operandIt;
3332 continue;
3333 }
3334 if (arg.is<NamedProperty *>())
3335 continue;
3336 const NamedAttribute *attr = arg.get<NamedAttribute *>();
3337 if (attr->attr.getBaseAttr().isSymbolRefAttr())
3338 resolveDecorators(op.getArgDecorators(index: i), i, EffectKind::Symbol);
3339 }
3340 /// Results.
3341 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3342 resolveDecorators(op.getResultDecorators(index: i), i, EffectKind::Result);
3343
3344 // The code used to add an effect instance.
3345 // {0}: The effect class.
3346 // {1}: Optional value or symbol reference.
3347 // {2}: The side effect stage.
3348 // {3}: Does this side effect act on every single value of resource.
3349 // {4}: The resource class.
3350 const char *addEffectCode =
3351 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3352
3353 for (auto &it : interfaceEffects) {
3354 // Generate the 'getEffects' method.
3355 std::string type = llvm::formatv(Fmt: "::llvm::SmallVectorImpl<::mlir::"
3356 "SideEffects::EffectInstance<{0}>> &",
3357 Vals: it.first())
3358 .str();
3359 auto *getEffects = opClass.addMethod(retType: "void", name: "getEffects",
3360 args: MethodParameter(type, "effects"));
3361 ERROR_IF_PRUNED(getEffects, "getEffects", op);
3362 auto &body = getEffects->body();
3363
3364 // Add effect instances for each of the locations marked on the operation.
3365 for (auto &location : it.second) {
3366 StringRef effect = location.effect.getName();
3367 StringRef resource = location.effect.getResource();
3368 int stage = (int)location.effect.getStage();
3369 bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
3370 if (location.kind == EffectKind::Static) {
3371 // A static instance has no attached value.
3372 body << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "", Vals&: stage,
3373 Vals&: effectOnFullRegion, Vals&: resource)
3374 .str();
3375 } else if (location.kind == EffectKind::Symbol) {
3376 // A symbol reference requires adding the proper attribute.
3377 const auto *attr = op.getArg(index: location.index).get<NamedAttribute *>();
3378 std::string argName = op.getGetterName(name: attr->name);
3379 if (attr->attr.isOptional()) {
3380 body << " if (auto symbolRef = " << argName << "Attr())\n "
3381 << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "symbolRef, ", Vals&: stage,
3382 Vals&: effectOnFullRegion, Vals&: resource)
3383 .str();
3384 } else {
3385 body << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: argName + "Attr(), ",
3386 Vals&: stage, Vals&: effectOnFullRegion, Vals&: resource)
3387 .str();
3388 }
3389 } else {
3390 // Otherwise this is an operand/result, so we need to attach the Value.
3391 body << " for (::mlir::Value value : getODS"
3392 << (location.kind == EffectKind::Operand ? "Operands" : "Results")
3393 << "(" << location.index << "))\n "
3394 << llvm::formatv(Fmt: addEffectCode, Vals&: effect, Vals: "value, ", Vals&: stage,
3395 Vals&: effectOnFullRegion, Vals&: resource)
3396 .str();
3397 }
3398 }
3399 }
3400}
3401
3402void OpEmitter::genTypeInterfaceMethods() {
3403 if (!op.allResultTypesKnown())
3404 return;
3405 // Generate 'inferReturnTypes' method declaration using the interface method
3406 // declared in 'InferTypeOpInterface' op interface.
3407 const auto *trait =
3408 cast<InterfaceTrait>(Val: op.getTrait(trait: "::mlir::InferTypeOpInterface::Trait"));
3409 Interface interface = trait->getInterface();
3410 Method *method = [&]() -> Method * {
3411 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
3412 if (interfaceMethod.getName() == "inferReturnTypes") {
3413 return genOpInterfaceMethod(method: interfaceMethod, /*declaration=*/false);
3414 }
3415 }
3416 assert(0 && "unable to find inferReturnTypes interface method");
3417 return nullptr;
3418 }();
3419 ERROR_IF_PRUNED(method, "inferReturnTypes", op);
3420 auto &body = method->body();
3421 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
3422
3423 FmtContext fctx;
3424 fctx.withBuilder(subst: "odsBuilder");
3425 fctx.addSubst(placeholder: "_ctxt", subst: "context");
3426 body << " ::mlir::Builder odsBuilder(context);\n";
3427
3428 // Process the type inference graph in topological order, starting from types
3429 // that are always fully-inferred: operands and results with constructible
3430 // types. The type inference graph here will always be a DAG, so this gives
3431 // us the correct order for generating the types. -1 is a placeholder to
3432 // indicate the type for a result has not been generated.
3433 SmallVector<int> constructedIndices(op.getNumResults(), -1);
3434 int inferredTypeIdx = 0;
3435 for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
3436 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3437 if (constructedIndices[i] >= 0)
3438 continue;
3439 const InferredResultType &infer = op.getInferredResultType(index: i);
3440 std::string typeStr;
3441 if (infer.isArg()) {
3442 // If this is an operand, just index into operand list to access the
3443 // type.
3444 auto arg = op.getArgToOperandOrAttribute(index: infer.getIndex());
3445 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3446 typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
3447 "].getType()")
3448 .str();
3449
3450 // If this is an attribute, index into the attribute dictionary.
3451 } else {
3452 auto *attr =
3453 op.getArg(index: arg.operandOrAttributeIndex()).get<NamedAttribute *>();
3454 body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3455 << " = ";
3456 if (op.getDialect().usePropertiesForAttributes()) {
3457 body << "(properties ? properties.as<Properties *>()->"
3458 << attr->name
3459 << " : "
3460 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3461 "get(\"" +
3462 attr->name + "\")));\n";
3463 } else {
3464 body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3465 "get(\"" +
3466 attr->name + "\"));\n";
3467 }
3468 body << " if (!odsInferredTypeAttr" << inferredTypeIdx
3469 << ") return ::mlir::failure();\n";
3470 typeStr =
3471 ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
3472 .str();
3473 }
3474 } else if (std::optional<StringRef> builder =
3475 op.getResult(index: infer.getResultIndex())
3476 .constraint.getBuilderCall()) {
3477 typeStr = tgfmt(fmt: *builder, ctx: &fctx).str();
3478 } else if (int index = constructedIndices[infer.getResultIndex()];
3479 index >= 0) {
3480 typeStr = ("odsInferredType" + Twine(index)).str();
3481 } else {
3482 continue;
3483 }
3484 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
3485 << tgfmt(fmt: infer.getTransformer(), ctx: &fctx.withSelf(subst: typeStr)) << ";\n";
3486 constructedIndices[i] = inferredTypeIdx - 1;
3487 }
3488 }
3489 for (auto [i, index] : llvm::enumerate(First&: constructedIndices))
3490 body << " inferredReturnTypes[" << i << "] = odsInferredType" << index
3491 << ";\n";
3492 body << " return ::mlir::success();";
3493}
3494
3495void OpEmitter::genParser() {
3496 if (hasStringAttribute(record: def, fieldName: "assemblyFormat"))
3497 return;
3498
3499 if (!def.getValueAsBit(FieldName: "hasCustomAssemblyFormat"))
3500 return;
3501
3502 SmallVector<MethodParameter> paramList;
3503 paramList.emplace_back(Args: "::mlir::OpAsmParser &", Args: "parser");
3504 paramList.emplace_back(Args: "::mlir::OperationState &", Args: "result");
3505
3506 auto *method = opClass.declareStaticMethod(retType: "::mlir::ParseResult", name: "parse",
3507 args: std::move(paramList));
3508 ERROR_IF_PRUNED(method, "parse", op);
3509}
3510
3511void OpEmitter::genPrinter() {
3512 if (hasStringAttribute(record: def, fieldName: "assemblyFormat"))
3513 return;
3514
3515 // Check to see if this op uses a c++ format.
3516 if (!def.getValueAsBit(FieldName: "hasCustomAssemblyFormat"))
3517 return;
3518 auto *method = opClass.declareMethod(
3519 retType: "void", name: "print", args: MethodParameter("::mlir::OpAsmPrinter &", "p"));
3520 ERROR_IF_PRUNED(method, "print", op);
3521}
3522
3523void OpEmitter::genVerifier() {
3524 auto *implMethod =
3525 opClass.addMethod(retType: "::mlir::LogicalResult", name: "verifyInvariantsImpl");
3526 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
3527 auto &implBody = implMethod->body();
3528 bool useProperties = emitHelper.hasProperties();
3529
3530 populateSubstitutions(emitHelper, ctx&: verifyCtx);
3531 genAttributeVerifier(emitHelper, ctx&: verifyCtx, body&: implBody, staticVerifierEmitter,
3532 useProperties);
3533 genOperandResultVerifier(body&: implBody, values: op.getOperands(), valueKind: "operand");
3534 genOperandResultVerifier(body&: implBody, values: op.getResults(), valueKind: "result");
3535
3536 for (auto &trait : op.getTraits()) {
3537 if (auto *t = dyn_cast<tblgen::PredTrait>(Val: &trait)) {
3538 implBody << tgfmt(fmt: " if (!($0))\n "
3539 "return emitOpError(\"failed to verify that $1\");\n",
3540 ctx: &verifyCtx, vals: tgfmt(fmt: t->getPredTemplate(), ctx: &verifyCtx),
3541 vals: t->getSummary());
3542 }
3543 }
3544
3545 genRegionVerifier(body&: implBody);
3546 genSuccessorVerifier(body&: implBody);
3547
3548 implBody << " return ::mlir::success();\n";
3549
3550 // TODO: Some places use the `verifyInvariants` to do operation verification.
3551 // This may not act as their expectation because this doesn't call any
3552 // verifiers of native/interface traits. Needs to review those use cases and
3553 // see if we should use the mlir::verify() instead.
3554 auto *method = opClass.addMethod(retType: "::mlir::LogicalResult", name: "verifyInvariants");
3555 ERROR_IF_PRUNED(method, "verifyInvariants", op);
3556 auto &body = method->body();
3557 if (def.getValueAsBit(FieldName: "hasVerifier")) {
3558 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3559 "::mlir::succeeded(verify()))\n";
3560 body << " return ::mlir::success();\n";
3561 body << " return ::mlir::failure();";
3562 } else {
3563 body << " return verifyInvariantsImpl();";
3564 }
3565}
3566
3567void OpEmitter::genCustomVerifier() {
3568 if (def.getValueAsBit(FieldName: "hasVerifier")) {
3569 auto *method = opClass.declareMethod(retType: "::mlir::LogicalResult", name: "verify");
3570 ERROR_IF_PRUNED(method, "verify", op);
3571 }
3572
3573 if (def.getValueAsBit(FieldName: "hasRegionVerifier")) {
3574 auto *method =
3575 opClass.declareMethod(retType: "::mlir::LogicalResult", name: "verifyRegions");
3576 ERROR_IF_PRUNED(method, "verifyRegions", op);
3577 }
3578}
3579
3580void OpEmitter::genOperandResultVerifier(MethodBody &body,
3581 Operator::const_value_range values,
3582 StringRef valueKind) {
3583 // Check that an optional value is at most 1 element.
3584 //
3585 // {0}: Value index.
3586 // {1}: "operand" or "result"
3587 const char *const verifyOptional = R"(
3588 if (valueGroup{0}.size() > 1) {
3589 return emitOpError("{1} group starting at #") << index
3590 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
3591 }
3592)";
3593 // Check the types of a range of values.
3594 //
3595 // {0}: Value index.
3596 // {1}: Type constraint function.
3597 // {2}: "operand" or "result"
3598 const char *const verifyValues = R"(
3599 for (auto v : valueGroup{0}) {
3600 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
3601 return ::mlir::failure();
3602 }
3603)";
3604
3605 const auto canSkip = [](const NamedTypeConstraint &value) {
3606 return !value.hasPredicate() && !value.isOptional() &&
3607 !value.isVariadicOfVariadic();
3608 };
3609 if (values.empty() || llvm::all_of(Range&: values, P: canSkip))
3610 return;
3611
3612 FmtContext fctx;
3613
3614 body << " {\n unsigned index = 0; (void)index;\n";
3615
3616 for (const auto &staticValue : llvm::enumerate(First&: values)) {
3617 const NamedTypeConstraint &value = staticValue.value();
3618
3619 bool hasPredicate = value.hasPredicate();
3620 bool isOptional = value.isOptional();
3621 bool isVariadicOfVariadic = value.isVariadicOfVariadic();
3622 if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
3623 continue;
3624 body << formatv(Fmt: " auto valueGroup{2} = getODS{0}{1}s({2});\n",
3625 // Capitalize the first letter to match the function name
3626 Vals: valueKind.substr(Start: 0, N: 1).upper(), Vals: valueKind.substr(Start: 1),
3627 Vals: staticValue.index());
3628
3629 // If the constraint is optional check that the value group has at most 1
3630 // value.
3631 if (isOptional) {
3632 body << formatv(Fmt: verifyOptional, Vals: staticValue.index(), Vals&: valueKind);
3633 } else if (isVariadicOfVariadic) {
3634 body << formatv(
3635 Fmt: " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
3636 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
3637 " return ::mlir::failure();\n",
3638 Vals: value.constraint.getVariadicOfVariadicSegmentSizeAttr(), Vals: value.name,
3639 Vals: staticValue.index());
3640 }
3641
3642 // Otherwise, if there is no predicate there is nothing left to do.
3643 if (!hasPredicate)
3644 continue;
3645 // Emit a loop to check all the dynamic values in the pack.
3646 StringRef constraintFn =
3647 staticVerifierEmitter.getTypeConstraintFn(constraint: value.constraint);
3648 body << formatv(Fmt: verifyValues, Vals: staticValue.index(), Vals&: constraintFn, Vals&: valueKind);
3649 }
3650
3651 body << " }\n";
3652}
3653
3654void OpEmitter::genRegionVerifier(MethodBody &body) {
3655 /// Code to verify a region.
3656 ///
3657 /// {0}: Getter for the regions.
3658 /// {1}: The region constraint.
3659 /// {2}: The region's name.
3660 /// {3}: The region description.
3661 const char *const verifyRegion = R"(
3662 for (auto &region : {0})
3663 if (::mlir::failed({1}(*this, region, "{2}", index++)))
3664 return ::mlir::failure();
3665)";
3666 /// Get a single region.
3667 ///
3668 /// {0}: The region's index.
3669 const char *const getSingleRegion =
3670 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
3671
3672 // If we have no regions, there is nothing more to do.
3673 const auto canSkip = [](const NamedRegion &region) {
3674 return region.constraint.getPredicate().isNull();
3675 };
3676 auto regions = op.getRegions();
3677 if (regions.empty() && llvm::all_of(Range&: regions, P: canSkip))
3678 return;
3679
3680 body << " {\n unsigned index = 0; (void)index;\n";
3681 for (const auto &it : llvm::enumerate(First&: regions)) {
3682 const auto &region = it.value();
3683 if (canSkip(region))
3684 continue;
3685
3686 auto getRegion = region.isVariadic()
3687 ? formatv(Fmt: "{0}()", Vals: op.getGetterName(name: region.name)).str()
3688 : formatv(Fmt: getSingleRegion, Vals: it.index()).str();
3689 auto constraintFn =
3690 staticVerifierEmitter.getRegionConstraintFn(constraint: region.constraint);
3691 body << formatv(Fmt: verifyRegion, Vals&: getRegion, Vals&: constraintFn, Vals: region.name);
3692 }
3693 body << " }\n";
3694}
3695
3696void OpEmitter::genSuccessorVerifier(MethodBody &body) {
3697 const char *const verifySuccessor = R"(
3698 for (auto *successor : {0})
3699 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
3700 return ::mlir::failure();
3701)";
3702 /// Get a single successor.
3703 ///
3704 /// {0}: The successor's name.
3705 const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())";
3706
3707 // If we have no successors, there is nothing more to do.
3708 const auto canSkip = [](const NamedSuccessor &successor) {
3709 return successor.constraint.getPredicate().isNull();
3710 };
3711 auto successors = op.getSuccessors();
3712 if (successors.empty() && llvm::all_of(Range&: successors, P: canSkip))
3713 return;
3714
3715 body << " {\n unsigned index = 0; (void)index;\n";
3716
3717 for (auto it : llvm::enumerate(First&: successors)) {
3718 const auto &successor = it.value();
3719 if (canSkip(successor))
3720 continue;
3721
3722 auto getSuccessor =
3723 formatv(Fmt: successor.isVariadic() ? "{0}()" : getSingleSuccessor,
3724 Vals: successor.name, Vals: it.index())
3725 .str();
3726 auto constraintFn =
3727 staticVerifierEmitter.getSuccessorConstraintFn(constraint: successor.constraint);
3728 body << formatv(Fmt: verifySuccessor, Vals&: getSuccessor, Vals&: constraintFn,
3729 Vals: successor.name);
3730 }
3731 body << " }\n";
3732}
3733
3734/// Add a size count trait to the given operation class.
3735static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
3736 int numTotal, int numVariadic) {
3737 if (numVariadic != 0) {
3738 if (numTotal == numVariadic)
3739 opClass.addTrait(trait: "::mlir::OpTrait::Variadic" + traitKind + "s");
3740 else
3741 opClass.addTrait(trait: "::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
3742 Twine(numTotal - numVariadic) + ">::Impl");
3743 return;
3744 }
3745 switch (numTotal) {
3746 case 0:
3747 opClass.addTrait(trait: "::mlir::OpTrait::Zero" + traitKind + "s");
3748 break;
3749 case 1:
3750 opClass.addTrait(trait: "::mlir::OpTrait::One" + traitKind);
3751 break;
3752 default:
3753 opClass.addTrait(trait: "::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
3754 ">::Impl");
3755 break;
3756 }
3757}
3758
3759void OpEmitter::genTraits() {
3760 // Add region size trait.
3761 unsigned numRegions = op.getNumRegions();
3762 unsigned numVariadicRegions = op.getNumVariadicRegions();
3763 addSizeCountTrait(opClass, traitKind: "Region", numTotal: numRegions, numVariadic: numVariadicRegions);
3764
3765 // Add result size traits.
3766 int numResults = op.getNumResults();
3767 int numVariadicResults = op.getNumVariableLengthResults();
3768 addSizeCountTrait(opClass, traitKind: "Result", numTotal: numResults, numVariadic: numVariadicResults);
3769
3770 // For single result ops with a known specific type, generate a OneTypedResult
3771 // trait.
3772 if (numResults == 1 && numVariadicResults == 0) {
3773 auto cppName = op.getResults().begin()->constraint.getCPPClassName();
3774 opClass.addTrait(trait: "::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
3775 }
3776
3777 // Add successor size trait.
3778 unsigned numSuccessors = op.getNumSuccessors();
3779 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
3780 addSizeCountTrait(opClass, traitKind: "Successor", numTotal: numSuccessors, numVariadic: numVariadicSuccessors);
3781
3782 // Add variadic size trait and normal op traits.
3783 int numOperands = op.getNumOperands();
3784 int numVariadicOperands = op.getNumVariableLengthOperands();
3785
3786 // Add operand size trait.
3787 addSizeCountTrait(opClass, traitKind: "Operand", numTotal: numOperands, numVariadic: numVariadicOperands);
3788
3789 // The op traits defined internal are ensured that they can be verified
3790 // earlier.
3791 for (const auto &trait : op.getTraits()) {
3792 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
3793 if (opTrait->isStructuralOpTrait())
3794 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
3795 }
3796 }
3797
3798 // OpInvariants wrapps the verifyInvariants which needs to be run before
3799 // native/interface traits and after all the traits with `StructuralOpTrait`.
3800 opClass.addTrait(trait: "::mlir::OpTrait::OpInvariants");
3801
3802 if (emitHelper.hasProperties())
3803 opClass.addTrait(trait: "::mlir::BytecodeOpInterface::Trait");
3804
3805 // Add the native and interface traits.
3806 for (const auto &trait : op.getTraits()) {
3807 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(Val: &trait)) {
3808 if (!opTrait->isStructuralOpTrait())
3809 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
3810 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(Val: &trait)) {
3811 opClass.addTrait(trait: opTrait->getFullyQualifiedTraitName());
3812 }
3813 }
3814}
3815
3816void OpEmitter::genOpNameGetter() {
3817 auto *method = opClass.addStaticMethod<Method::Constexpr>(
3818 retType: "::llvm::StringLiteral", name: "getOperationName");
3819 ERROR_IF_PRUNED(method, "getOperationName", op);
3820 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
3821 << "\");";
3822}
3823
3824void OpEmitter::genOpAsmInterface() {
3825 // If the user only has one results or specifically added the Asm trait,
3826 // then don't generate it for them. We specifically only handle multi result
3827 // operations, because the name of a single result in the common case is not
3828 // interesting(generally 'result'/'output'/etc.).
3829 // TODO: We could also add a flag to allow operations to opt in to this
3830 // generation, even if they only have a single operation.
3831 int numResults = op.getNumResults();
3832 if (numResults <= 1 || op.getTrait(trait: "::mlir::OpAsmOpInterface::Trait"))
3833 return;
3834
3835 SmallVector<StringRef, 4> resultNames(numResults);
3836 for (int i = 0; i != numResults; ++i)
3837 resultNames[i] = op.getResultName(index: i);
3838
3839 // Don't add the trait if none of the results have a valid name.
3840 if (llvm::all_of(Range&: resultNames, P: [](StringRef name) { return name.empty(); }))
3841 return;
3842 opClass.addTrait(trait: "::mlir::OpAsmOpInterface::Trait");
3843
3844 // Generate the right accessor for the number of results.
3845 auto *method = opClass.addMethod(
3846 retType: "void", name: "getAsmResultNames",
3847 args: MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
3848 ERROR_IF_PRUNED(method, "getAsmResultNames", op);
3849 auto &body = method->body();
3850 for (int i = 0; i != numResults; ++i) {
3851 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
3852 << " if (!resultGroup" << i << ".empty())\n"
3853 << " setNameFn(*resultGroup" << i << ".begin(), \""
3854 << resultNames[i] << "\");\n";
3855 }
3856}
3857
3858//===----------------------------------------------------------------------===//
3859// OpOperandAdaptor emitter
3860//===----------------------------------------------------------------------===//
3861
3862namespace {
3863// Helper class to emit Op operand adaptors to an output stream. Operand
3864// adaptors are wrappers around random access ranges that provide named operand
3865// getters identical to those defined in the Op.
3866// This currently generates 3 classes per Op:
3867// * A Base class within the 'detail' namespace, which contains all logic and
3868// members independent of the random access range that is indexed into.
3869// In other words, it contains all the attribute and region getters.
3870// * A templated class named '{OpName}GenericAdaptor' with a template parameter
3871// 'RangeT' that is indexed into by the getters to access the operands.
3872// It contains all getters to access operands and inherits from the previous
3873// class.
3874// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
3875// with 'mlir::ValueRange' as template parameter. It adds a constructor from
3876// an instance of the op type and a verify function.
3877class OpOperandAdaptorEmitter {
3878public:
3879 static void
3880 emitDecl(const Operator &op,
3881 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
3882 raw_ostream &os);
3883 static void
3884 emitDef(const Operator &op,
3885 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
3886 raw_ostream &os);
3887
3888private:
3889 explicit OpOperandAdaptorEmitter(
3890 const Operator &op,
3891 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
3892
3893 // Add verification function. This generates a verify method for the adaptor
3894 // which verifies all the op-independent attribute constraints.
3895 void addVerification();
3896
3897 // The operation for which to emit an adaptor.
3898 const Operator &op;
3899
3900 // The generated adaptor classes.
3901 Class genericAdaptorBase;
3902 Class genericAdaptor;
3903 Class adaptor;
3904
3905 // The emitter containing all of the locally emitted verification functions.
3906 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
3907
3908 // Helper for emitting adaptor code.
3909 OpOrAdaptorHelper emitHelper;
3910};
3911} // namespace
3912
3913OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
3914 const Operator &op,
3915 const StaticVerifierFunctionEmitter &staticVerifierEmitter)
3916 : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
3917 genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
3918 staticVerifierEmitter(staticVerifierEmitter),
3919 emitHelper(op, /*emitForOp=*/false) {
3920
3921 genericAdaptorBase.declare<VisibilityDeclaration>(args: Visibility::Public);
3922 bool useProperties = emitHelper.hasProperties();
3923 if (useProperties) {
3924 // Define the properties struct with multiple members.
3925 using ConstArgument =
3926 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
3927 SmallVector<ConstArgument> attrOrProperties;
3928 for (const std::pair<StringRef, AttributeMetadata> &it :
3929 emitHelper.getAttrMetadata()) {
3930 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
3931 attrOrProperties.push_back(Elt: &it.second);
3932 }
3933 for (const NamedProperty &prop : op.getProperties())
3934 attrOrProperties.push_back(Elt: &prop);
3935 if (emitHelper.getOperandSegmentsSize())
3936 attrOrProperties.push_back(Elt: &emitHelper.getOperandSegmentsSize().value());
3937 if (emitHelper.getResultSegmentsSize())
3938 attrOrProperties.push_back(Elt: &emitHelper.getResultSegmentsSize().value());
3939 assert(!attrOrProperties.empty());
3940 std::string declarations = " struct Properties {\n";
3941 llvm::raw_string_ostream os(declarations);
3942 std::string comparator =
3943 " bool operator==(const Properties &rhs) const {\n"
3944 " return \n";
3945 llvm::raw_string_ostream comparatorOs(comparator);
3946 for (const auto &attrOrProp : attrOrProperties) {
3947 if (const auto *namedProperty =
3948 llvm::dyn_cast_if_present<const NamedProperty *>(Val: attrOrProp)) {
3949 StringRef name = namedProperty->name;
3950 if (name.empty())
3951 report_fatal_error(reason: "missing name for property");
3952 std::string camelName =
3953 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
3954 auto &prop = namedProperty->prop;
3955 // Generate the data member using the storage type.
3956 os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
3957 << " " << name << "Ty " << name;
3958 if (prop.hasDefaultValue())
3959 os << " = " << prop.getDefaultValue();
3960 comparatorOs << " rhs." << name << " == this->" << name
3961 << " &&\n";
3962 // Emit accessors using the interface type.
3963 const char *accessorFmt = R"decl(;
3964 {0} get{1}() {
3965 auto &propStorage = this->{2};
3966 return {3};
3967 }
3968 void set{1}(const {0} &propValue) {
3969 auto &propStorage = this->{2};
3970 {4};
3971 }
3972)decl";
3973 FmtContext fctx;
3974 os << formatv(Fmt: accessorFmt, Vals: prop.getInterfaceType(), Vals&: camelName, Vals&: name,
3975 Vals: tgfmt(fmt: prop.getConvertFromStorageCall(),
3976 ctx: &fctx.addSubst(placeholder: "_storage", subst: propertyStorage)),
3977 Vals: tgfmt(fmt: prop.getAssignToStorageCall(),
3978 ctx: &fctx.addSubst(placeholder: "_value", subst: propertyValue)
3979 .addSubst(placeholder: "_storage", subst: propertyStorage)));
3980 continue;
3981 }
3982 const auto *namedAttr =
3983 llvm::dyn_cast_if_present<const AttributeMetadata *>(Val: attrOrProp);
3984 const Attribute *attr = nullptr;
3985 if (namedAttr->constraint)
3986 attr = &*namedAttr->constraint;
3987 StringRef name = namedAttr->attrName;
3988 if (name.empty())
3989 report_fatal_error(reason: "missing name for property attr");
3990 std::string camelName =
3991 convertToCamelFromSnakeCase(input: name, /*capitalizeFirst=*/true);
3992 // Generate the data member using the storage type.
3993 StringRef storageType;
3994 if (attr) {
3995 storageType = attr->getStorageType();
3996 } else {
3997 if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
3998 report_fatal_error(reason: "unexpected AttributeMetadata");
3999 }
4000 // TODO: update to use native integers.
4001 storageType = "::mlir::DenseI32ArrayAttr";
4002 }
4003 os << " using " << name << "Ty = " << storageType << ";\n"
4004 << " " << name << "Ty " << name << ";\n";
4005 comparatorOs << " rhs." << name << " == this->" << name << " &&\n";
4006
4007 // Emit accessors using the interface type.
4008 if (attr) {
4009 const char *accessorFmt = R"decl(
4010 auto get{0}() {
4011 auto &propStorage = this->{1};
4012 return ::llvm::{2}<{3}>(propStorage);
4013 }
4014 void set{0}(const {3} &propValue) {
4015 this->{1} = propValue;
4016 }
4017)decl";
4018 os << formatv(Fmt: accessorFmt, Vals&: camelName, Vals&: name,
4019 Vals: attr->isOptional() || attr->hasDefaultValue()
4020 ? "dyn_cast_or_null"
4021 : "cast",
4022 Vals&: storageType);
4023 }
4024 }
4025 comparatorOs << " true;\n }\n"
4026 " bool operator!=(const Properties &rhs) const {\n"
4027 " return !(*this == rhs);\n"
4028 " }\n";
4029 comparatorOs.flush();
4030 os << comparator;
4031 os << " };\n";
4032 os.flush();
4033
4034 genericAdaptorBase.declare<ExtraClassDeclaration>(args: std::move(declarations));
4035 }
4036 genericAdaptorBase.declare<VisibilityDeclaration>(args: Visibility::Protected);
4037 genericAdaptorBase.declare<Field>(args: "::mlir::DictionaryAttr", args: "odsAttrs");
4038 genericAdaptorBase.declare<Field>(args: "::std::optional<::mlir::OperationName>",
4039 args: "odsOpName");
4040 if (useProperties)
4041 genericAdaptorBase.declare<Field>(args: "Properties", args: "properties");
4042 genericAdaptorBase.declare<Field>(args: "::mlir::RegionRange", args: "odsRegions");
4043
4044 genericAdaptor.addTemplateParam(param: "RangeT");
4045 genericAdaptor.addField(type: "RangeT", name: "odsOperands");
4046 genericAdaptor.addParent(
4047 parent: ParentClass("detail::" + genericAdaptorBase.getClassName()));
4048 genericAdaptor.declare<UsingDeclaration>(
4049 args: "ValueT", args: "::llvm::detail::ValueOfRange<RangeT>");
4050 genericAdaptor.declare<UsingDeclaration>(
4051 args: "Base", args: "detail::" + genericAdaptorBase.getClassName());
4052
4053 const auto *attrSizedOperands =
4054 op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments");
4055 {
4056 SmallVector<MethodParameter> paramList;
4057 paramList.emplace_back(Args: "::mlir::DictionaryAttr", Args: "attrs",
4058 Args: attrSizedOperands ? "" : "nullptr");
4059 if (useProperties)
4060 paramList.emplace_back(Args: "const Properties &", Args: "properties", Args: "{}");
4061 else
4062 paramList.emplace_back(Args: "const ::mlir::EmptyProperties &", Args: "properties",
4063 Args: "{}");
4064 paramList.emplace_back(Args: "::mlir::RegionRange", Args: "regions", Args: "{}");
4065 auto *baseConstructor = genericAdaptorBase.addConstructor(args&: paramList);
4066 baseConstructor->addMemberInitializer(name: "odsAttrs", value: "attrs");
4067 if (useProperties)
4068 baseConstructor->addMemberInitializer(name: "properties", value: "properties");
4069 baseConstructor->addMemberInitializer(name: "odsRegions", value: "regions");
4070
4071 MethodBody &body = baseConstructor->body();
4072 body.indent() << "if (odsAttrs)\n";
4073 body.indent() << formatv(
4074 Fmt: "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
4075 Vals: op.getOperationName());
4076
4077 paramList.insert(I: paramList.begin(), Elt: MethodParameter("RangeT", "values"));
4078 auto *constructor = genericAdaptor.addConstructor(args&: paramList);
4079 constructor->addMemberInitializer(name: "Base", value: "attrs, properties, regions");
4080 constructor->addMemberInitializer(name: "odsOperands", value: "values");
4081
4082 // Add a forwarding constructor to the previous one that accepts
4083 // OpaqueProperties instead and check for null and perform the cast to the
4084 // actual properties type.
4085 paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
4086 paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
4087 auto *opaquePropertiesConstructor =
4088 genericAdaptor.addConstructor(args: std::move(paramList));
4089 if (useProperties) {
4090 opaquePropertiesConstructor->addMemberInitializer(
4091 name: genericAdaptor.getClassName(),
4092 value: "values, "
4093 "attrs, "
4094 "(properties ? *properties.as<Properties *>() : Properties{}), "
4095 "regions");
4096 } else {
4097 opaquePropertiesConstructor->addMemberInitializer(
4098 name: genericAdaptor.getClassName(),
4099 value: "values, "
4100 "attrs, "
4101 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
4102 "::mlir::EmptyProperties{}), "
4103 "regions");
4104 }
4105 }
4106
4107 // Create constructors constructing the adaptor from an instance of the op.
4108 // This takes the attributes, properties and regions from the op instance
4109 // and the value range from the parameter.
4110 {
4111 // Base class is in the cpp file and can simply access the members of the op
4112 // class to initialize the template independent fields.
4113 auto *constructor = genericAdaptorBase.addConstructor(
4114 args: MethodParameter(op.getCppClassName(), "op"));
4115 constructor->addMemberInitializer(
4116 name: genericAdaptorBase.getClassName(),
4117 value: llvm::Twine(!useProperties ? "op->getAttrDictionary()"
4118 : "op->getDiscardableAttrDictionary()") +
4119 ", op.getProperties(), op->getRegions()");
4120
4121 // Generic adaptor is templated and therefore defined inline in the header.
4122 // We cannot use the Op class here as it is an incomplete type (we have a
4123 // circular reference between the two).
4124 // Use a template trick to make the constructor be instantiated at call site
4125 // when the op class is complete.
4126 constructor = genericAdaptor.addConstructor(
4127 args: MethodParameter("RangeT", "values"), args: MethodParameter("LateInst", "op"));
4128 constructor->addTemplateParam(param: "LateInst = " + op.getCppClassName());
4129 constructor->addTemplateParam(
4130 param: "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
4131 ">>");
4132 constructor->addMemberInitializer(name: "Base", value: "op");
4133 constructor->addMemberInitializer(name: "odsOperands", value: "values");
4134 }
4135
4136 std::string sizeAttrInit;
4137 if (op.getTrait(trait: "::mlir::OpTrait::AttrSizedOperandSegments")) {
4138 if (op.getDialect().usePropertiesForAttributes())
4139 sizeAttrInit =
4140 formatv(Fmt: adapterSegmentSizeAttrInitCodeProperties,
4141 Vals: llvm::formatv(Fmt: "getProperties().operandSegmentSizes"));
4142 else
4143 sizeAttrInit = formatv(Fmt: adapterSegmentSizeAttrInitCode,
4144 Vals: emitHelper.getAttr(attrName: operandSegmentAttrName));
4145 }
4146 generateNamedOperandGetters(op, opClass&: genericAdaptor,
4147 /*genericAdaptorBase=*/&genericAdaptorBase,
4148 /*sizeAttrInit=*/sizeAttrInit,
4149 /*rangeType=*/"RangeT",
4150 /*rangeElementType=*/"ValueT",
4151 /*rangeBeginCall=*/"odsOperands.begin()",
4152 /*rangeSizeCall=*/"odsOperands.size()",
4153 /*getOperandCallPattern=*/"odsOperands[{0}]");
4154
4155 // Any invalid overlap for `getOperands` will have been diagnosed before
4156 // here already.
4157 if (auto *m = genericAdaptor.addMethod(retType: "RangeT", name: "getOperands"))
4158 m->body() << " return odsOperands;";
4159
4160 FmtContext fctx;
4161 fctx.withBuilder(subst: "::mlir::Builder(odsAttrs.getContext())");
4162
4163 // Generate named accessor with Attribute return type.
4164 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
4165 Attribute attr) {
4166 // The method body is trivial if the attribute does not have a default
4167 // value, in which case the default value may be arbitrary code.
4168 auto *method = genericAdaptorBase.addMethod(
4169 retType: attr.getStorageType(), name: emitName + "Attr",
4170 properties: attr.hasDefaultValue() || !useProperties ? Method::Properties::None
4171 : Method::Properties::Inline);
4172 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
4173 auto &body = method->body().indent();
4174 if (!useProperties)
4175 body << "assert(odsAttrs && \"no attributes when constructing "
4176 "adapter\");\n";
4177 body << formatv(
4178 Fmt: "auto attr = ::llvm::{1}<{2}>({0});\n", Vals: emitHelper.getAttr(attrName: name),
4179 Vals: attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null"
4180 : "cast",
4181 Vals: attr.getStorageType());
4182
4183 if (attr.hasDefaultValue() && attr.isOptional()) {
4184 // Use the default value if attribute is not set.
4185 // TODO: this is inefficient, we are recreating the attribute for every
4186 // call. This should be set instead.
4187 std::string defaultValue = std::string(
4188 tgfmt(fmt: attr.getConstBuilderTemplate(), ctx: &fctx, vals: attr.getDefaultValue()));
4189 body << "if (!attr)\n attr = " << defaultValue << ";\n";
4190 }
4191 body << "return attr;\n";
4192 };
4193
4194 if (useProperties) {
4195 auto *m = genericAdaptorBase.addInlineMethod(retType: "const Properties &",
4196 name: "getProperties");
4197 ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
4198 m->body() << " return properties;";
4199 }
4200 {
4201 auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::DictionaryAttr",
4202 name: "getAttributes");
4203 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
4204 m->body() << " return odsAttrs;";
4205 }
4206 for (auto &namedAttr : op.getAttributes()) {
4207 const auto &name = namedAttr.name;
4208 const auto &attr = namedAttr.attr;
4209 if (attr.isDerivedAttr())
4210 continue;
4211 std::string emitName = op.getGetterName(name);
4212 emitAttrWithStorageType(name, emitName, attr);
4213 emitAttrGetterWithReturnType(fctx, opClass&: genericAdaptorBase, op, name: emitName, attr);
4214 }
4215
4216 unsigned numRegions = op.getNumRegions();
4217 for (unsigned i = 0; i < numRegions; ++i) {
4218 const auto &region = op.getRegion(index: i);
4219 if (region.name.empty())
4220 continue;
4221
4222 // Generate the accessors for a variadic region.
4223 std::string name = op.getGetterName(name: region.name);
4224 if (region.isVariadic()) {
4225 auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::RegionRange", name);
4226 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
4227 m->body() << formatv(Fmt: " return odsRegions.drop_front({0});", Vals&: i);
4228 continue;
4229 }
4230
4231 auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::Region &", name);
4232 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
4233 m->body() << formatv(Fmt: " return *odsRegions[{0}];", Vals&: i);
4234 }
4235 if (numRegions > 0) {
4236 // Any invalid overlap for `getRegions` will have been diagnosed before
4237 // here already.
4238 if (auto *m = genericAdaptorBase.addInlineMethod(retType: "::mlir::RegionRange",
4239 name: "getRegions"))
4240 m->body() << " return odsRegions;";
4241 }
4242
4243 StringRef genericAdaptorClassName = genericAdaptor.getClassName();
4244 adaptor.addParent(parent: ParentClass(genericAdaptorClassName))
4245 .addTemplateParam(param: "::mlir::ValueRange");
4246 adaptor.declare<VisibilityDeclaration>(args: Visibility::Public);
4247 adaptor.declare<UsingDeclaration>(args: genericAdaptorClassName +
4248 "::" + genericAdaptorClassName);
4249 {
4250 // Constructor taking the Op as single parameter.
4251 auto *constructor =
4252 adaptor.addConstructor(args: MethodParameter(op.getCppClassName(), "op"));
4253 constructor->addMemberInitializer(name&: genericAdaptorClassName,
4254 value: "op->getOperands(), op");
4255 }
4256
4257 // Add verification function.
4258 addVerification();
4259
4260 genericAdaptorBase.finalize();
4261 genericAdaptor.finalize();
4262 adaptor.finalize();
4263}
4264
4265void OpOperandAdaptorEmitter::addVerification() {
4266 auto *method = adaptor.addMethod(retType: "::mlir::LogicalResult", name: "verify",
4267 args: MethodParameter("::mlir::Location", "loc"));
4268 ERROR_IF_PRUNED(method, "verify", op);
4269 auto &body = method->body();
4270 bool useProperties = emitHelper.hasProperties();
4271
4272 FmtContext verifyCtx;
4273 populateSubstitutions(emitHelper, ctx&: verifyCtx);
4274 genAttributeVerifier(emitHelper, ctx&: verifyCtx, body, staticVerifierEmitter,
4275 useProperties);
4276
4277 body << " return ::mlir::success();";
4278}
4279
4280void OpOperandAdaptorEmitter::emitDecl(
4281 const Operator &op,
4282 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4283 raw_ostream &os) {
4284 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4285 {
4286 NamespaceEmitter ns(os, "detail");
4287 emitter.genericAdaptorBase.writeDeclTo(rawOs&: os);
4288 }
4289 emitter.genericAdaptor.writeDeclTo(rawOs&: os);
4290 emitter.adaptor.writeDeclTo(rawOs&: os);
4291}
4292
4293void OpOperandAdaptorEmitter::emitDef(
4294 const Operator &op,
4295 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4296 raw_ostream &os) {
4297 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4298 {
4299 NamespaceEmitter ns(os, "detail");
4300 emitter.genericAdaptorBase.writeDefTo(rawOs&: os);
4301 }
4302 emitter.genericAdaptor.writeDefTo(rawOs&: os);
4303 emitter.adaptor.writeDefTo(rawOs&: os);
4304}
4305
4306// Emits the opcode enum and op classes.
4307static void emitOpClasses(const RecordKeeper &recordKeeper,
4308 const std::vector<Record *> &defs, raw_ostream &os,
4309 bool emitDecl) {
4310 // First emit forward declaration for each class, this allows them to refer
4311 // to each others in traits for example.
4312 if (emitDecl) {
4313 os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
4314 os << "#undef GET_OP_FWD_DEFINES\n";
4315 for (auto *def : defs) {
4316 Operator op(*def);
4317 NamespaceEmitter emitter(os, op.getCppNamespace());
4318 os << "class " << op.getCppClassName() << ";\n";
4319 }
4320 os << "#endif\n\n";
4321 }
4322
4323 IfDefScope scope("GET_OP_CLASSES", os);
4324 if (defs.empty())
4325 return;
4326
4327 // Generate all of the locally instantiated methods first.
4328 StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
4329 os << formatv(Fmt: opCommentHeader, Vals: "Local Utility Method", Vals: "Definitions");
4330 staticVerifierEmitter.emitOpConstraints(opDefs: defs, emitDecl);
4331
4332 for (auto *def : defs) {
4333 Operator op(*def);
4334 if (emitDecl) {
4335 {
4336 NamespaceEmitter emitter(os, op.getCppNamespace());
4337 os << formatv(Fmt: opCommentHeader, Vals: op.getQualCppClassName(),
4338 Vals: "declarations");
4339 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
4340 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
4341 }
4342 // Emit the TypeID explicit specialization to have a single definition.
4343 if (!op.getCppNamespace().empty())
4344 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4345 << "::" << op.getCppClassName() << ")\n\n";
4346 } else {
4347 {
4348 NamespaceEmitter emitter(os, op.getCppNamespace());
4349 os << formatv(Fmt: opCommentHeader, Vals: op.getQualCppClassName(), Vals: "definitions");
4350 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
4351 OpEmitter::emitDef(op, os, staticVerifierEmitter);
4352 }
4353 // Emit the TypeID explicit specialization to have a single definition.
4354 if (!op.getCppNamespace().empty())
4355 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4356 << "::" << op.getCppClassName() << ")\n\n";
4357 }
4358 }
4359}
4360
4361// Emits a comma-separated list of the ops.
4362static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
4363 IfDefScope scope("GET_OP_LIST", os);
4364
4365 interleave(
4366 // TODO: We are constructing the Operator wrapper instance just for
4367 // getting it's qualified class name here. Reduce the overhead by having a
4368 // lightweight version of Operator class just for that purpose.
4369 c: defs, each_fn: [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
4370 between_fn: [&os]() { os << ",\n"; });
4371}
4372
4373static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
4374 emitSourceFileHeader(Desc: "Op Declarations", OS&: os, Record: recordKeeper);
4375
4376 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
4377 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
4378
4379 return false;
4380}
4381
4382static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
4383 emitSourceFileHeader(Desc: "Op Definitions", OS&: os, Record: recordKeeper);
4384
4385 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
4386 emitOpList(defs, os);
4387 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
4388
4389 return false;
4390}
4391
4392static mlir::GenRegistration
4393 genOpDecls("gen-op-decls", "Generate op declarations",
4394 [](const RecordKeeper &records, raw_ostream &os) {
4395 return emitOpDecls(recordKeeper: records, os);
4396 });
4397
4398static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
4399 [](const RecordKeeper &records,
4400 raw_ostream &os) {
4401 return emitOpDefs(recordKeeper: records, os);
4402 });
4403

source code of mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp