1 | //===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
10 | |
11 | using namespace mlir; |
12 | |
13 | //===----------------------------------------------------------------------===// |
14 | // Printing and parsing for match ops. |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | /// Keyword syntax for positional specification inversion. |
18 | constexpr const static llvm::StringLiteral kDimExceptKeyword = "except" ; |
19 | |
20 | /// Keyword syntax for full inclusion in positional specification. |
21 | constexpr const static llvm::StringLiteral kDimAllKeyword = "all" ; |
22 | |
23 | ParseResult transform::parseTransformMatchDims(OpAsmParser &parser, |
24 | DenseI64ArrayAttr &rawDimList, |
25 | UnitAttr &isInverted, |
26 | UnitAttr &isAll) { |
27 | Builder &builder = parser.getBuilder(); |
28 | if (parser.parseOptionalKeyword(keyword: kDimAllKeyword).succeeded()) { |
29 | rawDimList = builder.getDenseI64ArrayAttr({}); |
30 | isInverted = nullptr; |
31 | isAll = builder.getUnitAttr(); |
32 | return success(); |
33 | } |
34 | |
35 | isAll = nullptr; |
36 | isInverted = nullptr; |
37 | if (parser.parseOptionalKeyword(keyword: kDimExceptKeyword).succeeded()) { |
38 | isInverted = builder.getUnitAttr(); |
39 | } |
40 | |
41 | if (isInverted) { |
42 | if (parser.parseLParen().failed()) |
43 | return failure(); |
44 | } |
45 | |
46 | SmallVector<int64_t> values; |
47 | ParseResult listResult = parser.parseCommaSeparatedList( |
48 | parseElementFn: [&]() { return parser.parseInteger(result&: values.emplace_back()); }); |
49 | if (listResult.failed()) |
50 | return failure(); |
51 | |
52 | rawDimList = builder.getDenseI64ArrayAttr(values); |
53 | |
54 | if (isInverted) { |
55 | if (parser.parseRParen().failed()) |
56 | return failure(); |
57 | } |
58 | return success(); |
59 | } |
60 | |
61 | void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op, |
62 | DenseI64ArrayAttr rawDimList, |
63 | UnitAttr isInverted, UnitAttr isAll) { |
64 | if (isAll) { |
65 | printer << kDimAllKeyword; |
66 | return; |
67 | } |
68 | if (isInverted) { |
69 | printer << kDimExceptKeyword << "(" ; |
70 | } |
71 | llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(), |
72 | [&](int64_t value) { printer << value; }); |
73 | if (isInverted) { |
74 | printer << ")" ; |
75 | } |
76 | } |
77 | |
78 | LogicalResult transform::verifyTransformMatchDimsOp(Operation *op, |
79 | ArrayRef<int64_t> raw, |
80 | bool inverted, bool all) { |
81 | if (all) { |
82 | if (inverted) { |
83 | return op->emitOpError() |
84 | << "cannot request both 'all' and 'inverted' values in the list" ; |
85 | } |
86 | if (!raw.empty()) { |
87 | return op->emitOpError() |
88 | << "cannot both request 'all' and specific values in the list" ; |
89 | } |
90 | } |
91 | if (!all && raw.empty()) { |
92 | return op->emitOpError() << "must request specific values in the list if " |
93 | "'all' is not specified" ; |
94 | } |
95 | SmallVector<int64_t> rawVector = llvm::to_vector(Range&: raw); |
96 | auto *it = std::unique(first: rawVector.begin(), last: rawVector.end()); |
97 | if (it != rawVector.end()) |
98 | return op->emitOpError() << "expected the listed values to be unique" ; |
99 | |
100 | return success(); |
101 | } |
102 | |
103 | DiagnosedSilenceableFailure transform::expandTargetSpecification( |
104 | Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList, |
105 | int64_t maxNumber, SmallVectorImpl<int64_t> &result) { |
106 | assert(maxNumber > 0 && "expected size to be positive" ); |
107 | assert(!(isAll && isInverted) && "cannot invert all" ); |
108 | if (isAll) { |
109 | result = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: maxNumber)); |
110 | return DiagnosedSilenceableFailure::success(); |
111 | } |
112 | |
113 | SmallVector<int64_t> expanded; |
114 | llvm::SmallDenseSet<int64_t> visited; |
115 | expanded.reserve(N: rawList.size()); |
116 | SmallVectorImpl<int64_t> &target = isInverted ? expanded : result; |
117 | for (int64_t raw : rawList) { |
118 | int64_t updated = raw < 0 ? maxNumber + raw : raw; |
119 | if (updated >= maxNumber) { |
120 | return emitSilenceableFailure(loc) |
121 | << "position overflow " << updated << " (updated from " << raw |
122 | << ") for maximum " << maxNumber; |
123 | } |
124 | if (updated < 0) { |
125 | return emitSilenceableFailure(loc) << "position underflow " << updated |
126 | << " (updated from " << raw << ")" ; |
127 | } |
128 | if (!visited.insert(V: updated).second) { |
129 | return emitSilenceableFailure(loc) << "repeated position " << updated |
130 | << " (updated from " << raw << ")" ; |
131 | } |
132 | target.push_back(Elt: updated); |
133 | } |
134 | |
135 | if (!isInverted) |
136 | return DiagnosedSilenceableFailure::success(); |
137 | |
138 | result.reserve(N: result.size() + (maxNumber - expanded.size())); |
139 | for (int64_t candidate : llvm::seq<int64_t>(Begin: 0, End: maxNumber)) { |
140 | if (llvm::is_contained(Range&: expanded, Element: candidate)) |
141 | continue; |
142 | result.push_back(Elt: candidate); |
143 | } |
144 | |
145 | return DiagnosedSilenceableFailure::success(); |
146 | } |
147 | |
148 | //===----------------------------------------------------------------------===// |
149 | // Generated interface implementation. |
150 | //===----------------------------------------------------------------------===// |
151 | |
152 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc" |
153 | |