1//===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
10
11using namespace mlir;
12
13//===----------------------------------------------------------------------===//
14// Printing and parsing for match ops.
15//===----------------------------------------------------------------------===//
16
17/// Keyword syntax for positional specification inversion.
18constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
19
20/// Keyword syntax for full inclusion in positional specification.
21constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
22
23ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
24 DenseI64ArrayAttr &rawDimList,
25 UnitAttr &isInverted,
26 UnitAttr &isAll) {
27 Builder &builder = parser.getBuilder();
28 if (parser.parseOptionalKeyword(keyword: kDimAllKeyword).succeeded()) {
29 rawDimList = builder.getDenseI64ArrayAttr({});
30 isInverted = nullptr;
31 isAll = builder.getUnitAttr();
32 return success();
33 }
34
35 isAll = nullptr;
36 isInverted = nullptr;
37 if (parser.parseOptionalKeyword(keyword: kDimExceptKeyword).succeeded()) {
38 isInverted = builder.getUnitAttr();
39 }
40
41 if (isInverted) {
42 if (parser.parseLParen().failed())
43 return failure();
44 }
45
46 SmallVector<int64_t> values;
47 ParseResult listResult = parser.parseCommaSeparatedList(
48 parseElementFn: [&]() { return parser.parseInteger(result&: values.emplace_back()); });
49 if (listResult.failed())
50 return failure();
51
52 rawDimList = builder.getDenseI64ArrayAttr(values);
53
54 if (isInverted) {
55 if (parser.parseRParen().failed())
56 return failure();
57 }
58 return success();
59}
60
61void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
62 DenseI64ArrayAttr rawDimList,
63 UnitAttr isInverted, UnitAttr isAll) {
64 if (isAll) {
65 printer << kDimAllKeyword;
66 return;
67 }
68 if (isInverted) {
69 printer << kDimExceptKeyword << "(";
70 }
71 llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
72 [&](int64_t value) { printer << value; });
73 if (isInverted) {
74 printer << ")";
75 }
76}
77
78LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
79 ArrayRef<int64_t> raw,
80 bool inverted, bool all) {
81 if (all) {
82 if (inverted) {
83 return op->emitOpError()
84 << "cannot request both 'all' and 'inverted' values in the list";
85 }
86 if (!raw.empty()) {
87 return op->emitOpError()
88 << "cannot both request 'all' and specific values in the list";
89 }
90 }
91 if (!all && raw.empty()) {
92 return op->emitOpError() << "must request specific values in the list if "
93 "'all' is not specified";
94 }
95 SmallVector<int64_t> rawVector = llvm::to_vector(Range&: raw);
96 auto *it = std::unique(first: rawVector.begin(), last: rawVector.end());
97 if (it != rawVector.end())
98 return op->emitOpError() << "expected the listed values to be unique";
99
100 return success();
101}
102
103DiagnosedSilenceableFailure transform::expandTargetSpecification(
104 Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
105 int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
106 assert(maxNumber > 0 && "expected size to be positive");
107 assert(!(isAll && isInverted) && "cannot invert all");
108 if (isAll) {
109 result = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: maxNumber));
110 return DiagnosedSilenceableFailure::success();
111 }
112
113 SmallVector<int64_t> expanded;
114 llvm::SmallDenseSet<int64_t> visited;
115 expanded.reserve(N: rawList.size());
116 SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
117 for (int64_t raw : rawList) {
118 int64_t updated = raw < 0 ? maxNumber + raw : raw;
119 if (updated >= maxNumber) {
120 return emitSilenceableFailure(loc)
121 << "position overflow " << updated << " (updated from " << raw
122 << ") for maximum " << maxNumber;
123 }
124 if (updated < 0) {
125 return emitSilenceableFailure(loc) << "position underflow " << updated
126 << " (updated from " << raw << ")";
127 }
128 if (!visited.insert(V: updated).second) {
129 return emitSilenceableFailure(loc) << "repeated position " << updated
130 << " (updated from " << raw << ")";
131 }
132 target.push_back(Elt: updated);
133 }
134
135 if (!isInverted)
136 return DiagnosedSilenceableFailure::success();
137
138 result.reserve(N: result.size() + (maxNumber - expanded.size()));
139 for (int64_t candidate : llvm::seq<int64_t>(Begin: 0, End: maxNumber)) {
140 if (llvm::is_contained(Range&: expanded, Element: candidate))
141 continue;
142 result.push_back(Elt: candidate);
143 }
144
145 return DiagnosedSilenceableFailure::success();
146}
147
148//===----------------------------------------------------------------------===//
149// Generated interface implementation.
150//===----------------------------------------------------------------------===//
151
152#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc"
153

source code of mlir/lib/Dialect/Transform/Interfaces/MatchInterfaces.cpp