| 1 | //===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
| 10 | |
| 11 | #include "llvm/Support/InterleavedRange.h" |
| 12 | |
| 13 | using namespace mlir; |
| 14 | |
| 15 | //===----------------------------------------------------------------------===// |
| 16 | // Printing and parsing for match ops. |
| 17 | //===----------------------------------------------------------------------===// |
| 18 | |
| 19 | /// Keyword syntax for positional specification inversion. |
| 20 | constexpr const static llvm::StringLiteral kDimExceptKeyword = "except" ; |
| 21 | |
| 22 | /// Keyword syntax for full inclusion in positional specification. |
| 23 | constexpr const static llvm::StringLiteral kDimAllKeyword = "all" ; |
| 24 | |
| 25 | ParseResult transform::parseTransformMatchDims(OpAsmParser &parser, |
| 26 | DenseI64ArrayAttr &rawDimList, |
| 27 | UnitAttr &isInverted, |
| 28 | UnitAttr &isAll) { |
| 29 | Builder &builder = parser.getBuilder(); |
| 30 | if (parser.parseOptionalKeyword(keyword: kDimAllKeyword).succeeded()) { |
| 31 | rawDimList = builder.getDenseI64ArrayAttr({}); |
| 32 | isInverted = nullptr; |
| 33 | isAll = builder.getUnitAttr(); |
| 34 | return success(); |
| 35 | } |
| 36 | |
| 37 | isAll = nullptr; |
| 38 | isInverted = nullptr; |
| 39 | if (parser.parseOptionalKeyword(keyword: kDimExceptKeyword).succeeded()) { |
| 40 | isInverted = builder.getUnitAttr(); |
| 41 | } |
| 42 | |
| 43 | if (isInverted) { |
| 44 | if (parser.parseLParen().failed()) |
| 45 | return failure(); |
| 46 | } |
| 47 | |
| 48 | SmallVector<int64_t> values; |
| 49 | ParseResult listResult = parser.parseCommaSeparatedList( |
| 50 | parseElementFn: [&]() { return parser.parseInteger(result&: values.emplace_back()); }); |
| 51 | if (listResult.failed()) |
| 52 | return failure(); |
| 53 | |
| 54 | rawDimList = builder.getDenseI64ArrayAttr(values); |
| 55 | |
| 56 | if (isInverted) { |
| 57 | if (parser.parseRParen().failed()) |
| 58 | return failure(); |
| 59 | } |
| 60 | return success(); |
| 61 | } |
| 62 | |
| 63 | void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op, |
| 64 | DenseI64ArrayAttr rawDimList, |
| 65 | UnitAttr isInverted, UnitAttr isAll) { |
| 66 | if (isAll) { |
| 67 | printer << kDimAllKeyword; |
| 68 | return; |
| 69 | } |
| 70 | if (isInverted) { |
| 71 | printer << kDimExceptKeyword << "(" ; |
| 72 | } |
| 73 | printer << llvm::interleaved(rawDimList.asArrayRef()); |
| 74 | if (isInverted) { |
| 75 | printer << ")" ; |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | LogicalResult transform::verifyTransformMatchDimsOp(Operation *op, |
| 80 | ArrayRef<int64_t> raw, |
| 81 | bool inverted, bool all) { |
| 82 | if (all) { |
| 83 | if (inverted) { |
| 84 | return op->emitOpError() |
| 85 | << "cannot request both 'all' and 'inverted' values in the list" ; |
| 86 | } |
| 87 | if (!raw.empty()) { |
| 88 | return op->emitOpError() |
| 89 | << "cannot both request 'all' and specific values in the list" ; |
| 90 | } |
| 91 | } |
| 92 | if (!all && raw.empty()) { |
| 93 | return op->emitOpError() << "must request specific values in the list if " |
| 94 | "'all' is not specified" ; |
| 95 | } |
| 96 | SmallVector<int64_t> rawVector = llvm::to_vector(Range&: raw); |
| 97 | auto *it = llvm::unique(R&: rawVector); |
| 98 | if (it != rawVector.end()) |
| 99 | return op->emitOpError() << "expected the listed values to be unique" ; |
| 100 | |
| 101 | return success(); |
| 102 | } |
| 103 | |
| 104 | DiagnosedSilenceableFailure transform::expandTargetSpecification( |
| 105 | Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList, |
| 106 | int64_t maxNumber, SmallVectorImpl<int64_t> &result) { |
| 107 | assert(maxNumber > 0 && "expected size to be positive" ); |
| 108 | assert(!(isAll && isInverted) && "cannot invert all" ); |
| 109 | if (isAll) { |
| 110 | result = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: maxNumber)); |
| 111 | return DiagnosedSilenceableFailure::success(); |
| 112 | } |
| 113 | |
| 114 | SmallVector<int64_t> expanded; |
| 115 | llvm::SmallDenseSet<int64_t> visited; |
| 116 | expanded.reserve(N: rawList.size()); |
| 117 | SmallVectorImpl<int64_t> &target = isInverted ? expanded : result; |
| 118 | for (int64_t raw : rawList) { |
| 119 | int64_t updated = raw < 0 ? maxNumber + raw : raw; |
| 120 | if (updated >= maxNumber) { |
| 121 | return emitSilenceableFailure(loc) |
| 122 | << "position overflow " << updated << " (updated from " << raw |
| 123 | << ") for maximum " << maxNumber; |
| 124 | } |
| 125 | if (updated < 0) { |
| 126 | return emitSilenceableFailure(loc) << "position underflow " << updated |
| 127 | << " (updated from " << raw << ")" ; |
| 128 | } |
| 129 | if (!visited.insert(V: updated).second) { |
| 130 | return emitSilenceableFailure(loc) << "repeated position " << updated |
| 131 | << " (updated from " << raw << ")" ; |
| 132 | } |
| 133 | target.push_back(Elt: updated); |
| 134 | } |
| 135 | |
| 136 | if (!isInverted) |
| 137 | return DiagnosedSilenceableFailure::success(); |
| 138 | |
| 139 | result.reserve(N: result.size() + (maxNumber - expanded.size())); |
| 140 | for (int64_t candidate : llvm::seq<int64_t>(Begin: 0, End: maxNumber)) { |
| 141 | if (llvm::is_contained(Range&: expanded, Element: candidate)) |
| 142 | continue; |
| 143 | result.push_back(Elt: candidate); |
| 144 | } |
| 145 | |
| 146 | return DiagnosedSilenceableFailure::success(); |
| 147 | } |
| 148 | |
| 149 | //===----------------------------------------------------------------------===// |
| 150 | // Generated interface implementation. |
| 151 | //===----------------------------------------------------------------------===// |
| 152 | |
| 153 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc" |
| 154 | |