1//===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
10
11#include "llvm/Support/InterleavedRange.h"
12
13using namespace mlir;
14
15//===----------------------------------------------------------------------===//
16// Printing and parsing for match ops.
17//===----------------------------------------------------------------------===//
18
19/// Keyword syntax for positional specification inversion.
20constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
21
22/// Keyword syntax for full inclusion in positional specification.
23constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
24
25ParseResult transform::parseTransformMatchDims(OpAsmParser &parser,
26 DenseI64ArrayAttr &rawDimList,
27 UnitAttr &isInverted,
28 UnitAttr &isAll) {
29 Builder &builder = parser.getBuilder();
30 if (parser.parseOptionalKeyword(keyword: kDimAllKeyword).succeeded()) {
31 rawDimList = builder.getDenseI64ArrayAttr({});
32 isInverted = nullptr;
33 isAll = builder.getUnitAttr();
34 return success();
35 }
36
37 isAll = nullptr;
38 isInverted = nullptr;
39 if (parser.parseOptionalKeyword(keyword: kDimExceptKeyword).succeeded()) {
40 isInverted = builder.getUnitAttr();
41 }
42
43 if (isInverted) {
44 if (parser.parseLParen().failed())
45 return failure();
46 }
47
48 SmallVector<int64_t> values;
49 ParseResult listResult = parser.parseCommaSeparatedList(
50 parseElementFn: [&]() { return parser.parseInteger(result&: values.emplace_back()); });
51 if (listResult.failed())
52 return failure();
53
54 rawDimList = builder.getDenseI64ArrayAttr(values);
55
56 if (isInverted) {
57 if (parser.parseRParen().failed())
58 return failure();
59 }
60 return success();
61}
62
63void transform::printTransformMatchDims(OpAsmPrinter &printer, Operation *op,
64 DenseI64ArrayAttr rawDimList,
65 UnitAttr isInverted, UnitAttr isAll) {
66 if (isAll) {
67 printer << kDimAllKeyword;
68 return;
69 }
70 if (isInverted) {
71 printer << kDimExceptKeyword << "(";
72 }
73 printer << llvm::interleaved(rawDimList.asArrayRef());
74 if (isInverted) {
75 printer << ")";
76 }
77}
78
79LogicalResult transform::verifyTransformMatchDimsOp(Operation *op,
80 ArrayRef<int64_t> raw,
81 bool inverted, bool all) {
82 if (all) {
83 if (inverted) {
84 return op->emitOpError()
85 << "cannot request both 'all' and 'inverted' values in the list";
86 }
87 if (!raw.empty()) {
88 return op->emitOpError()
89 << "cannot both request 'all' and specific values in the list";
90 }
91 }
92 if (!all && raw.empty()) {
93 return op->emitOpError() << "must request specific values in the list if "
94 "'all' is not specified";
95 }
96 SmallVector<int64_t> rawVector = llvm::to_vector(Range&: raw);
97 auto *it = llvm::unique(R&: rawVector);
98 if (it != rawVector.end())
99 return op->emitOpError() << "expected the listed values to be unique";
100
101 return success();
102}
103
104DiagnosedSilenceableFailure transform::expandTargetSpecification(
105 Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
106 int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
107 assert(maxNumber > 0 && "expected size to be positive");
108 assert(!(isAll && isInverted) && "cannot invert all");
109 if (isAll) {
110 result = llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: maxNumber));
111 return DiagnosedSilenceableFailure::success();
112 }
113
114 SmallVector<int64_t> expanded;
115 llvm::SmallDenseSet<int64_t> visited;
116 expanded.reserve(N: rawList.size());
117 SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
118 for (int64_t raw : rawList) {
119 int64_t updated = raw < 0 ? maxNumber + raw : raw;
120 if (updated >= maxNumber) {
121 return emitSilenceableFailure(loc)
122 << "position overflow " << updated << " (updated from " << raw
123 << ") for maximum " << maxNumber;
124 }
125 if (updated < 0) {
126 return emitSilenceableFailure(loc) << "position underflow " << updated
127 << " (updated from " << raw << ")";
128 }
129 if (!visited.insert(V: updated).second) {
130 return emitSilenceableFailure(loc) << "repeated position " << updated
131 << " (updated from " << raw << ")";
132 }
133 target.push_back(Elt: updated);
134 }
135
136 if (!isInverted)
137 return DiagnosedSilenceableFailure::success();
138
139 result.reserve(N: result.size() + (maxNumber - expanded.size()));
140 for (int64_t candidate : llvm::seq<int64_t>(Begin: 0, End: maxNumber)) {
141 if (llvm::is_contained(Range&: expanded, Element: candidate))
142 continue;
143 result.push_back(Elt: candidate);
144 }
145
146 return DiagnosedSilenceableFailure::success();
147}
148
149//===----------------------------------------------------------------------===//
150// Generated interface implementation.
151//===----------------------------------------------------------------------===//
152
153#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc"
154

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/Transform/Interfaces/MatchInterfaces.cpp