| 1 | //===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // OpDefinitionsGen uses the description of operations to generate C++ |
| 10 | // definitions for ops. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/TableGen/CodeGenHelpers.h" |
| 15 | #include "mlir/TableGen/Operator.h" |
| 16 | #include "mlir/TableGen/Pattern.h" |
| 17 | #include "llvm/Support/FormatVariadic.h" |
| 18 | #include "llvm/Support/Path.h" |
| 19 | #include "llvm/TableGen/Record.h" |
| 20 | |
| 21 | using namespace llvm; |
| 22 | using namespace mlir; |
| 23 | using namespace mlir::tblgen; |
| 24 | |
| 25 | /// Generate a unique label based on the current file name to prevent name |
| 26 | /// collisions if multiple generated files are included at once. |
| 27 | static std::string getUniqueOutputLabel(const RecordKeeper &records, |
| 28 | StringRef tag) { |
| 29 | // Use the input file name when generating a unique name. |
| 30 | StringRef inputFilename = records.getInputFilename(); |
| 31 | |
| 32 | // Drop all but the base filename. |
| 33 | StringRef nameRef = sys::path::filename(path: inputFilename); |
| 34 | nameRef.consume_back(Suffix: ".td" ); |
| 35 | |
| 36 | // Sanitize any invalid characters. |
| 37 | std::string uniqueName(tag); |
| 38 | for (char c : nameRef) { |
| 39 | if (isAlnum(C: c) || c == '_') |
| 40 | uniqueName.push_back(c: c); |
| 41 | else |
| 42 | uniqueName.append(str: utohexstr(X: (unsigned char)c)); |
| 43 | } |
| 44 | return uniqueName; |
| 45 | } |
| 46 | |
| 47 | StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( |
| 48 | raw_ostream &os, const RecordKeeper &records, StringRef tag) |
| 49 | : os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {} |
| 50 | |
| 51 | void StaticVerifierFunctionEmitter::emitOpConstraints( |
| 52 | ArrayRef<const Record *> opDefs) { |
| 53 | NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace()); |
| 54 | emitTypeConstraints(); |
| 55 | emitAttrConstraints(); |
| 56 | emitPropConstraints(); |
| 57 | emitSuccessorConstraints(); |
| 58 | emitRegionConstraints(); |
| 59 | } |
| 60 | |
| 61 | void StaticVerifierFunctionEmitter::emitPatternConstraints( |
| 62 | const ArrayRef<DagLeaf> constraints) { |
| 63 | collectPatternConstraints(constraints); |
| 64 | emitPatternConstraints(); |
| 65 | } |
| 66 | |
| 67 | //===----------------------------------------------------------------------===// |
| 68 | // Constraint Getters |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | |
| 71 | StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn( |
| 72 | const Constraint &constraint) const { |
| 73 | const auto *it = typeConstraints.find(Key: constraint); |
| 74 | assert(it != typeConstraints.end() && "expected to find a type constraint" ); |
| 75 | return it->second; |
| 76 | } |
| 77 | |
| 78 | // Find a uniqued attribute constraint. Since not all attribute constraints can |
| 79 | // be uniqued, return std::nullopt if one was not found. |
| 80 | std::optional<StringRef> StaticVerifierFunctionEmitter::getAttrConstraintFn( |
| 81 | const Constraint &constraint) const { |
| 82 | const auto *it = attrConstraints.find(Key: constraint); |
| 83 | return it == attrConstraints.end() ? std::optional<StringRef>() |
| 84 | : StringRef(it->second); |
| 85 | } |
| 86 | |
| 87 | // Find a uniqued property constraint. Since not all property constraints can |
| 88 | // be uniqued, return std::nullopt if one was not found. |
| 89 | std::optional<StringRef> StaticVerifierFunctionEmitter::getPropConstraintFn( |
| 90 | const Constraint &constraint) const { |
| 91 | const auto *it = propConstraints.find(Key: constraint); |
| 92 | return it == propConstraints.end() ? std::optional<StringRef>() |
| 93 | : StringRef(it->second); |
| 94 | } |
| 95 | |
| 96 | StringRef StaticVerifierFunctionEmitter::getSuccessorConstraintFn( |
| 97 | const Constraint &constraint) const { |
| 98 | const auto *it = successorConstraints.find(Key: constraint); |
| 99 | assert(it != successorConstraints.end() && |
| 100 | "expected to find a sucessor constraint" ); |
| 101 | return it->second; |
| 102 | } |
| 103 | |
| 104 | StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn( |
| 105 | const Constraint &constraint) const { |
| 106 | const auto *it = regionConstraints.find(Key: constraint); |
| 107 | assert(it != regionConstraints.end() && |
| 108 | "expected to find a region constraint" ); |
| 109 | return it->second; |
| 110 | } |
| 111 | |
| 112 | //===----------------------------------------------------------------------===// |
| 113 | // Constraint Emission |
| 114 | //===----------------------------------------------------------------------===// |
| 115 | |
/// Code templates for emitting type, attribute, property, successor, and region
/// constraints. Each of these templates requires at least the following
/// arguments (some templates take an additional {3}, documented with them):
///
/// {0}: The unique constraint name.
/// {1}: The constraint code.
/// {2}: The constraint description.
| 122 | |
/// Code for a type constraint. These may be called on the type of either
/// operands or results. The generated function checks condition {1} against
/// `type`. On failure it reports through `op->emitOpError`, naming the
/// caller-supplied value kind and index.
///
/// NOTE(review): unlike attrConstraintCode, the literal `{` braces here are
/// not escaped as `{{`. This relies on how formatv treats braces that do not
/// form a replacement field. Confirm this still holds, or escape them.
static const char *const typeConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef valueKind,
    unsigned valueIndex) {
  if (!({1})) {
    return op->emitOpError(valueKind) << " #" << valueIndex
           << " must be {2}, but got " << type;
  }
  return ::mlir::success();
}
)";
| 136 | |
/// Code for an attribute constraint. These may be called from ops only.
/// Attribute constraints cannot reference anything other than `$_self` and
/// `$_op`.
///
/// Two overloads are generated:
///  - one that reports failures through the `emitError` callback. A null
///    `attr` skips the check and succeeds;
///  - one taking the op, which forwards to the first with a callback that
///    calls `op->emitOpError()`.
/// Literal `{` braces are written as `{{` so that formatv emits a single `{`.
///
/// TODO: Unique constraints for adaptors. However, most Adaptor::verify
/// functions are stripped anyways.
static const char *const attrConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Attribute attr, ::llvm::StringRef attrName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (attr && !({1}))
    return emitError() << "attribute '" << attrName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Attribute attr, ::llvm::StringRef attrName) {{
  return {0}(attr, attrName, [op]() {{
    return op->emitOpError();
  });
}
)";
| 158 | |
/// Code for a property constraint. These may be called from ops only.
/// Property constraints cannot reference anything other than `$_self` and
/// `$_op`. {3} is the interface type of the property, used as the type of the
/// `prop` parameter.
///
/// As with attributes, two overloads are generated: one reporting through an
/// `emitError` callback, and one taking the op that forwards to it with
/// `op->emitOpError()`. Unlike the attribute template, no null check is done.
static const char *const propConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    {3} prop, ::llvm::StringRef propName, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
  if (!({1}))
    return emitError() << "property '" << propName
        << "' failed to satisfy constraint: {2}";
  return ::mlir::success();
}
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, {3} prop, ::llvm::StringRef propName) {{
  return {0}(prop, propName, [op]() {{
    return op->emitOpError();
  });
}
)";
| 177 | |
/// Code for a successor constraint. On failure the generated function reports
/// `successor #<index> ('<name>') failed to verify constraint: <summary>`
/// through `op->emitOpError`.
///
/// Literal `{` braces are escaped as `{{` (as in the attribute and property
/// templates), so only {0}-{2} are treated as formatv replacement fields.
static const char *const successorConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Block *successor,
    ::llvm::StringRef successorName, unsigned successorIndex) {{
  if (!({1})) {{
    return op->emitOpError("successor #") << successorIndex << " ('"
        << successorName << "') failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
| 190 | |
/// Code for a region constraint. Callers pass in the region's name for the
/// error message. An empty name is allowed and is simply left out of the
/// message.
///
/// The region is taken by reference (`::mlir::Region &region`). Literal `{`
/// braces are escaped as `{{` so that only {0}-{2} are treated as formatv
/// replacement fields.
static const char *const regionConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {{
  if (!({1})) {{
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: {2}";
  }
  return ::mlir::success();
}
)";
| 205 | |
/// Code for a pattern type or attribute constraint. On failure the generated
/// function calls `rewriter.notifyMatchFailure`, attaching `failureStr` and
/// the constraint summary to the diagnostic.
///
/// {3}: "Type type" or "Attribute attr" (the parameter declaration, which is
///      prefixed with `::mlir::`).
///
/// NOTE(review): literal `{` braces here are not escaped as `{{`. Confirm that
/// formatv still emits them verbatim, or escape them.
static const char *const patternAttrOrTypeConstraintCode = R"(
static ::llvm::LogicalResult {0}(
    ::mlir::PatternRewriter &rewriter, ::mlir::Operation *op, ::mlir::{3},
    ::llvm::StringRef failureStr) {
  if (!({1})) {
    return rewriter.notifyMatchFailure(op, [&](::mlir::Diagnostic &diag) {
      diag << failureStr << ": {2}";
    });
  }
  return ::mlir::success();
}
)";
| 221 | |
| 222 | void StaticVerifierFunctionEmitter::emitConstraints( |
| 223 | const ConstraintMap &constraints, StringRef selfName, |
| 224 | const char *const codeTemplate) { |
| 225 | FmtContext ctx; |
| 226 | ctx.addSubst(placeholder: "_op" , subst: "*op" ).withSelf(subst: selfName); |
| 227 | for (auto &it : constraints) { |
| 228 | os << formatv(Fmt: codeTemplate, Vals: it.second, |
| 229 | Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx), |
| 230 | Vals: escapeString(value: it.first.getSummary())); |
| 231 | } |
| 232 | } |
| 233 | |
| 234 | void StaticVerifierFunctionEmitter::emitTypeConstraints() { |
| 235 | emitConstraints(constraints: typeConstraints, selfName: "type" , codeTemplate: typeConstraintCode); |
| 236 | } |
| 237 | |
| 238 | void StaticVerifierFunctionEmitter::emitAttrConstraints() { |
| 239 | emitConstraints(constraints: attrConstraints, selfName: "attr" , codeTemplate: attrConstraintCode); |
| 240 | } |
| 241 | |
| 242 | /// Unlike with the other helpers, this one has to substitute in the interface |
| 243 | /// type of the property, so we can't just use the generic function. |
| 244 | void StaticVerifierFunctionEmitter::emitPropConstraints() { |
| 245 | FmtContext ctx; |
| 246 | ctx.addSubst(placeholder: "_op" , subst: "*op" ).withSelf(subst: "prop" ); |
| 247 | for (auto &it : propConstraints) { |
| 248 | auto propConstraint = cast<PropConstraint>(Val&: it.first); |
| 249 | os << formatv(Fmt: propConstraintCode, Vals&: it.second, |
| 250 | Vals: tgfmt(fmt: propConstraint.getConditionTemplate(), ctx: &ctx), |
| 251 | Vals: escapeString(value: it.first.getSummary()), |
| 252 | Vals: propConstraint.getInterfaceType()); |
| 253 | } |
| 254 | } |
| 255 | |
| 256 | void StaticVerifierFunctionEmitter::emitSuccessorConstraints() { |
| 257 | emitConstraints(constraints: successorConstraints, selfName: "successor" , codeTemplate: successorConstraintCode); |
| 258 | } |
| 259 | |
| 260 | void StaticVerifierFunctionEmitter::emitRegionConstraints() { |
| 261 | emitConstraints(constraints: regionConstraints, selfName: "region" , codeTemplate: regionConstraintCode); |
| 262 | } |
| 263 | |
| 264 | void StaticVerifierFunctionEmitter::emitPatternConstraints() { |
| 265 | FmtContext ctx; |
| 266 | ctx.addSubst(placeholder: "_op" , subst: "*op" ).withBuilder(subst: "rewriter" ).withSelf(subst: "type" ); |
| 267 | for (auto &it : typeConstraints) { |
| 268 | os << formatv(Fmt: patternAttrOrTypeConstraintCode, Vals&: it.second, |
| 269 | Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx), |
| 270 | Vals: escapeString(value: it.first.getSummary()), Vals: "Type type" ); |
| 271 | } |
| 272 | ctx.withSelf(subst: "attr" ); |
| 273 | for (auto &it : attrConstraints) { |
| 274 | os << formatv(Fmt: patternAttrOrTypeConstraintCode, Vals&: it.second, |
| 275 | Vals: tgfmt(fmt: it.first.getConditionTemplate(), ctx: &ctx), |
| 276 | Vals: escapeString(value: it.first.getSummary()), Vals: "Attribute attr" ); |
| 277 | } |
| 278 | } |
| 279 | |
| 280 | //===----------------------------------------------------------------------===// |
| 281 | // Constraint Uniquing |
| 282 | //===----------------------------------------------------------------------===// |
| 283 | |
| 284 | /// An attribute constraint that references anything other than itself and the |
| 285 | /// current op cannot be generically extracted into a function. Most |
| 286 | /// prohibitive are operands and results, which require calls to |
| 287 | /// `getODSOperands` or `getODSResults`. Attribute references are tricky too |
| 288 | /// because ops use cached identifiers. |
| 289 | static bool canUniqueAttrConstraint(Attribute attr) { |
| 290 | FmtContext ctx; |
| 291 | auto test = tgfmt(fmt: attr.getConditionTemplate(), |
| 292 | ctx: &ctx.withSelf(subst: "attr" ).addSubst(placeholder: "_op" , subst: "*op" )) |
| 293 | .str(); |
| 294 | return !StringRef(test).contains(Other: "<no-subst-found>" ); |
| 295 | } |
| 296 | |
| 297 | /// A property constraint that references anything other than itself and the |
| 298 | /// current op cannot be generically extracted into a function, just as with |
| 299 | /// canUnequePropConstraint(). Additionally, property constraints without |
| 300 | /// an interface type specified can't be uniqued, and ones that are a literal |
| 301 | /// "true" shouldn't be constrained. |
| 302 | static bool canUniquePropConstraint(Property prop) { |
| 303 | FmtContext ctx; |
| 304 | auto test = tgfmt(fmt: prop.getConditionTemplate(), |
| 305 | ctx: &ctx.withSelf(subst: "prop" ).addSubst(placeholder: "_op" , subst: "*op" )) |
| 306 | .str(); |
| 307 | return !StringRef(test).contains(Other: "<no-subst-found>" ) && test != "true" && |
| 308 | !prop.getInterfaceType().empty(); |
| 309 | } |
| 310 | |
| 311 | std::string StaticVerifierFunctionEmitter::getUniqueName(StringRef kind, |
| 312 | unsigned index) { |
| 313 | return ("__mlir_ods_local_" + kind + "_constraint_" + uniqueOutputLabel + |
| 314 | Twine(index)) |
| 315 | .str(); |
| 316 | } |
| 317 | |
| 318 | void StaticVerifierFunctionEmitter::collectConstraint(ConstraintMap &map, |
| 319 | StringRef kind, |
| 320 | Constraint constraint) { |
| 321 | auto [it, inserted] = map.try_emplace(Key: constraint); |
| 322 | if (inserted) |
| 323 | it->second = getUniqueName(kind, index: map.size()); |
| 324 | } |
| 325 | |
| 326 | void StaticVerifierFunctionEmitter::collectOpConstraints( |
| 327 | ArrayRef<const Record *> opDefs) { |
| 328 | const auto collectTypeConstraints = [&](Operator::const_value_range values) { |
| 329 | for (const NamedTypeConstraint &value : values) |
| 330 | if (value.hasPredicate()) |
| 331 | collectConstraint(map&: typeConstraints, kind: "type" , constraint: value.constraint); |
| 332 | }; |
| 333 | |
| 334 | for (const Record *def : opDefs) { |
| 335 | Operator op(*def); |
| 336 | /// Collect type constraints. |
| 337 | collectTypeConstraints(op.getOperands()); |
| 338 | collectTypeConstraints(op.getResults()); |
| 339 | /// Collect attribute constraints. |
| 340 | for (const NamedAttribute &namedAttr : op.getAttributes()) { |
| 341 | if (!namedAttr.attr.getPredicate().isNull() && |
| 342 | !namedAttr.attr.isDerivedAttr() && |
| 343 | canUniqueAttrConstraint(attr: namedAttr.attr)) |
| 344 | collectConstraint(map&: attrConstraints, kind: "attr" , constraint: namedAttr.attr); |
| 345 | } |
| 346 | /// Collect non-trivial property constraints. |
| 347 | for (const NamedProperty &namedProp : op.getProperties()) { |
| 348 | if (!namedProp.prop.getPredicate().isNull() && |
| 349 | canUniquePropConstraint(prop: namedProp.prop)) { |
| 350 | collectConstraint(map&: propConstraints, kind: "prop" , constraint: namedProp.prop); |
| 351 | } |
| 352 | } |
| 353 | /// Collect successor constraints. |
| 354 | for (const NamedSuccessor &successor : op.getSuccessors()) { |
| 355 | if (!successor.constraint.getPredicate().isNull()) { |
| 356 | collectConstraint(map&: successorConstraints, kind: "successor" , |
| 357 | constraint: successor.constraint); |
| 358 | } |
| 359 | } |
| 360 | /// Collect region constraints. |
| 361 | for (const NamedRegion ®ion : op.getRegions()) |
| 362 | if (!region.constraint.getPredicate().isNull()) |
| 363 | collectConstraint(map&: regionConstraints, kind: "region" , constraint: region.constraint); |
| 364 | } |
| 365 | } |
| 366 | |
| 367 | void StaticVerifierFunctionEmitter::collectPatternConstraints( |
| 368 | const ArrayRef<DagLeaf> constraints) { |
| 369 | for (auto &leaf : constraints) { |
| 370 | assert(leaf.isOperandMatcher() || leaf.isAttrMatcher()); |
| 371 | collectConstraint( |
| 372 | map&: leaf.isOperandMatcher() ? typeConstraints : attrConstraints, |
| 373 | kind: leaf.isOperandMatcher() ? "type" : "attr" , constraint: leaf.getAsConstraint()); |
| 374 | } |
| 375 | } |
| 376 | |
| 377 | //===----------------------------------------------------------------------===// |
| 378 | // Public Utility Functions |
| 379 | //===----------------------------------------------------------------------===// |
| 380 | |
| 381 | std::string mlir::tblgen::escapeString(StringRef value) { |
| 382 | std::string ret; |
| 383 | raw_string_ostream os(ret); |
| 384 | os.write_escaped(Str: value); |
| 385 | return ret; |
| 386 | } |
| 387 | |