1 | //===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" |
10 | |
11 | #include "llvm/Support/InterleavedRange.h" |
12 | |
13 | using namespace mlir; |
14 | |
15 | //===----------------------------------------------------------------------===// |
16 | // Printing and parsing for match ops. |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | /// Keyword syntax for positional specification inversion. |
20 | constexpr const static llvm::StringLiteral kDimExceptKeyword = "except" ; |
21 | |
22 | /// Keyword syntax for full inclusion in positional specification. |
23 | constexpr const static llvm::StringLiteral kDimAllKeyword = "all" ; |
24 | |
25 | ParseResult transform::parseTransformMatchDims(OpAsmParser &parser, |
26 | DenseI64ArrayAttr &rawDimList, |
27 | UnitAttr &isInverted, |
28 | UnitAttr &isAll) { |
29 | Builder &builder = parser.getBuilder(); |
30 | if (parser.parseOptionalKeyword(keyword: kDimAllKeyword).succeeded()) { |
31 | rawDimList = builder.getDenseI64ArrayAttr({}); |
32 | isInverted = nullptr; |
33 | isAll = builder.getUnitAttr(); |
34 | return success(); |
35 | } |
36 | |
37 | isAll = nullptr; |
38 | isInverted = nullptr; |
39 | if (parser.parseOptionalKeyword(keyword: kDimExceptKeyword).succeeded()) { |
40 | isInverted = builder.getUnitAttr(); |
41 | } |
42 | |
43 | if (isInverted) { |
44 | if (parser.parseLParen().failed()) |
45 | return failure(); |
46 | } |
47 | |
48 | SmallVector<int64_t> values; |
49 | ParseResult listResult = parser.parseCommaSeparatedList( |
50 | parseElementFn: [&]() { return parser.parseInteger(result&: values.emplace_back()); }); |
51 | if (listResult.failed()) |
52 | return failure(); |
53 | |
54 | rawDimList = builder.getDenseI64ArrayAttr(values); |
55 | |
56 | if (isInverted) { |
57 | if (parser.parseRParen().failed()) |
58 | return failure(); |
59 | } |
60 | return success(); |
61 | } |
62 | |
63 | void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op, |
64 | DenseI64ArrayAttr rawDimList, |
65 | UnitAttr isInverted, UnitAttr isAll) { |
66 | if (isAll) { |
67 | printer << kDimAllKeyword; |
68 | return; |
69 | } |
70 | if (isInverted) { |
71 | printer << kDimExceptKeyword << "(" ; |
72 | } |
73 | printer << llvm::interleaved(rawDimList.asArrayRef()); |
74 | if (isInverted) { |
75 | printer << ")" ; |
76 | } |
77 | } |
78 | |
79 | LogicalResult transform::verifyTransformMatchDimsOp(Operation *op, |
80 | ArrayRef<int64_t> raw, |
81 | bool inverted, bool all) { |
82 | if (all) { |
83 | if (inverted) { |
84 | return op->emitOpError() |
85 | << "cannot request both 'all' and 'inverted' values in the list" ; |
86 | } |
87 | if (!raw.empty()) { |
88 | return op->emitOpError() |
89 | << "cannot both request 'all' and specific values in the list" ; |
90 | } |
91 | } |
92 | if (!all && raw.empty()) { |
93 | return op->emitOpError() << "must request specific values in the list if " |
94 | "'all' is not specified" ; |
95 | } |
96 | SmallVector<int64_t> rawVector = llvm::to_vector(Range&: raw); |
97 | auto *it = llvm::unique(R&: rawVector); |
98 | if (it != rawVector.end()) |
99 | return op->emitOpError() << "expected the listed values to be unique" ; |
100 | |
101 | return success(); |
102 | } |
103 | |
104 | DiagnosedSilenceableFailure transform::expandTargetSpecification( |
105 | Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList, |
106 | int64_t maxNumber, SmallVectorImpl<int64_t> &result) { |
107 | assert(maxNumber > 0 && "expected size to be positive" ); |
108 | assert(!(isAll && isInverted) && "cannot invert all" ); |
109 | if (isAll) { |
110 | result = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: maxNumber)); |
111 | return DiagnosedSilenceableFailure::success(); |
112 | } |
113 | |
114 | SmallVector<int64_t> expanded; |
115 | llvm::SmallDenseSet<int64_t> visited; |
116 | expanded.reserve(N: rawList.size()); |
117 | SmallVectorImpl<int64_t> &target = isInverted ? expanded : result; |
118 | for (int64_t raw : rawList) { |
119 | int64_t updated = raw < 0 ? maxNumber + raw : raw; |
120 | if (updated >= maxNumber) { |
121 | return emitSilenceableFailure(loc) |
122 | << "position overflow " << updated << " (updated from " << raw |
123 | << ") for maximum " << maxNumber; |
124 | } |
125 | if (updated < 0) { |
126 | return emitSilenceableFailure(loc) << "position underflow " << updated |
127 | << " (updated from " << raw << ")" ; |
128 | } |
129 | if (!visited.insert(V: updated).second) { |
130 | return emitSilenceableFailure(loc) << "repeated position " << updated |
131 | << " (updated from " << raw << ")" ; |
132 | } |
133 | target.push_back(Elt: updated); |
134 | } |
135 | |
136 | if (!isInverted) |
137 | return DiagnosedSilenceableFailure::success(); |
138 | |
139 | result.reserve(N: result.size() + (maxNumber - expanded.size())); |
140 | for (int64_t candidate : llvm::seq<int64_t>(Begin: 0, End: maxNumber)) { |
141 | if (llvm::is_contained(Range&: expanded, Element: candidate)) |
142 | continue; |
143 | result.push_back(Elt: candidate); |
144 | } |
145 | |
146 | return DiagnosedSilenceableFailure::success(); |
147 | } |
148 | |
149 | //===----------------------------------------------------------------------===// |
150 | // Generated interface implementation. |
151 | //===----------------------------------------------------------------------===// |
152 | |
153 | #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc" |
154 | |