1//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/PatternMatch.h"
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/Iterators.h"
13#include "mlir/IR/RegionKindInterface.h"
14#include "llvm/ADT/SmallPtrSet.h"
15
16using namespace mlir;
17
18//===----------------------------------------------------------------------===//
19// PatternBenefit
20//===----------------------------------------------------------------------===//
21
22PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23 assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24 "This pattern match benefit is too large to represent");
25}
26
27unsigned short PatternBenefit::getBenefit() const {
28 assert(!isImpossibleToMatch() && "Pattern doesn't match");
29 return representation;
30}
31
32//===----------------------------------------------------------------------===//
33// Pattern
34//===----------------------------------------------------------------------===//
35
36//===----------------------------------------------------------------------===//
37// OperationName Root Constructors
38//===----------------------------------------------------------------------===//
39
40Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
41 MLIRContext *context, ArrayRef<StringRef> generatedNames)
42 : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
43 RootKind::OperationName, generatedNames, benefit, context) {}
44
45//===----------------------------------------------------------------------===//
46// MatchAnyOpTypeTag Root Constructors
47//===----------------------------------------------------------------------===//
48
49Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
50 MLIRContext *context, ArrayRef<StringRef> generatedNames)
51 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
52
53//===----------------------------------------------------------------------===//
54// MatchInterfaceOpTypeTag Root Constructors
55//===----------------------------------------------------------------------===//
56
57Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
58 PatternBenefit benefit, MLIRContext *context,
59 ArrayRef<StringRef> generatedNames)
60 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
61 generatedNames, benefit, context) {}
62
63//===----------------------------------------------------------------------===//
64// MatchTraitOpTypeTag Root Constructors
65//===----------------------------------------------------------------------===//
66
67Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
68 PatternBenefit benefit, MLIRContext *context,
69 ArrayRef<StringRef> generatedNames)
70 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
71 benefit, context) {}
72
73//===----------------------------------------------------------------------===//
74// General Constructors
75//===----------------------------------------------------------------------===//
76
77Pattern::Pattern(const void *rootValue, RootKind rootKind,
78 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
79 MLIRContext *context)
80 : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
81 contextAndHasBoundedRecursion(context, false) {
82 if (generatedNames.empty())
83 return;
84 generatedOps.reserve(N: generatedNames.size());
85 std::transform(first: generatedNames.begin(), last: generatedNames.end(),
86 result: std::back_inserter(x&: generatedOps), unary_op: [context](StringRef name) {
87 return OperationName(name, context);
88 });
89}
90
91//===----------------------------------------------------------------------===//
92// RewritePattern
93//===----------------------------------------------------------------------===//
94
95/// Out-of-line vtable anchor.
96void RewritePattern::anchor() {}
97
98//===----------------------------------------------------------------------===//
99// RewriterBase
100//===----------------------------------------------------------------------===//
101
102bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
103 return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
104}
105
106RewriterBase::~RewriterBase() {
107 // Out of line to provide a vtable anchor for the class.
108}
109
110void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
111 // Notify the listener that we're about to replace this op.
112 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
113 rewriteListener->notifyOperationReplaced(op: from, replacement: to);
114
115 replaceAllUsesWith(from: from->getResults(), to);
116}
117
118void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
119 // Notify the listener that we're about to replace this op.
120 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
121 rewriteListener->notifyOperationReplaced(op: from, replacement: to);
122
123 replaceAllUsesWith(from: from->getResults(), to: to->getResults());
124}
125
126/// This method replaces the results of the operation with the specified list of
127/// values. The number of provided values must match the number of results of
128/// the operation. The replaced op is erased.
129void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
130 assert(op->getNumResults() == newValues.size() &&
131 "incorrect # of replacement values");
132
133 // Replace all result uses. Also notifies the listener of modifications.
134 replaceAllOpUsesWith(from: op, to: newValues);
135
136 // Erase op and notify listener.
137 eraseOp(op);
138}
139
140/// This method replaces the results of the operation with the specified new op
141/// (replacement). The number of results of the two operations must match. The
142/// replaced op is erased.
143void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
144 assert(op && newOp && "expected non-null op");
145 assert(op->getNumResults() == newOp->getNumResults() &&
146 "ops have different number of results");
147
148 // Replace all result uses. Also notifies the listener of modifications.
149 replaceAllOpUsesWith(from: op, to: newOp->getResults());
150
151 // Erase op and notify listener.
152 eraseOp(op);
153}
154
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead.
///
/// Without a rewrite listener, `op` is dropped in one go via `op->erase()`.
/// With a rewrite listener, the IR nested under `op` is torn down one op and
/// one block at a time:
///   - innermost IR first;
///   - within a region, blocks in post-order;
///   - within a block, ops back-to-front.
/// Each op is reported via `notifyOperationErased` just before it is erased,
/// and each emptied block is erased through `eraseBlock`.
void RewriterBase::eraseOp(Operation *op) {
  assert(op->use_empty() && "expected 'op' to have no uses");
  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);

  // Fast path: If no listener is attached, the op can be dropped in one go.
  if (!rewriteListener) {
    op->erase();
    return;
  }

  // Helper function that erases a single op. Its regions must already have
  // been emptied by `eraseTree` below.
  auto eraseSingleOp = [&](Operation *op) {
#ifndef NDEBUG
    // All nested ops should have been erased already.
    assert(
        llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
        "expected empty regions");
    // All users should have been erased already if the op is in a region with
    // SSA dominance.
    if (!op->use_empty() && op->getParentOp())
      assert(mayBeGraphRegion(*op->getParentRegion()) &&
             "expected that op has no uses");
#endif // NDEBUG
    rewriteListener->notifyOperationErased(op);

    // Explicitly drop all uses in case the op is in a graph region.
    op->dropAllUses();
    op->erase();
  };

  // Nested ops must be erased one-by-one, so that listeners have a consistent
  // view of the IR every time a notification is triggered. Users must be
  // erased before definitions. I.e., post-order, reverse dominance.
  // A std::function is used because the lambda calls itself recursively.
  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
    // Erase nested ops.
    for (Region &r : llvm::reverse(op->getRegions())) {
      // Erase all blocks in the right order. Successors should be erased
      // before predecessors because successor blocks may use values defined
      // in predecessor blocks. A post-order traversal of blocks within a
      // region visits successors before predecessors. Repeat the traversal
      // until the region is empty. (The block graph could be disconnected.)
      while (!r.empty()) {
        SmallVector<Block *> erasedBlocks;
        // Some blocks may have invalid successor, use a set including nullptr
        // to avoid null pointer. (Pre-seeding `visited` with nullptr makes the
        // traversal treat a null successor as already visited.)
        llvm::SmallPtrSet<Block *, 4> visited{nullptr};
        for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
          // Visit ops in reverse order.
          for (Operation &op :
               llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
            eraseTree(&op);
          // Do not erase the block immediately. This is not supported by the
          // post_order iterator.
          erasedBlocks.push_back(b);
        }
        // All ops of the visited blocks are gone; now remove the blocks.
        for (Block *b : erasedBlocks) {
          // Explicitly drop all uses in case there is a cycle in the block
          // graph.
          for (BlockArgument bbArg : b->getArguments())
            bbArg.dropAllUses();
          b->dropAllUses();
          eraseBlock(b);
        }
      }
    }
    // Then erase the enclosing op.
    eraseSingleOp(op);
  };

  eraseTree(op);
}
228
229void RewriterBase::eraseBlock(Block *block) {
230 assert(block->use_empty() && "expected 'block' to have no uses");
231
232 for (auto &op : llvm::make_early_inc_range(Range: llvm::reverse(C&: *block))) {
233 assert(op.use_empty() && "expected 'op' to have no uses");
234 eraseOp(op: &op);
235 }
236
237 // Notify the listener that the block is about to be removed.
238 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
239 rewriteListener->notifyBlockErased(block);
240
241 block->erase();
242}
243
244void RewriterBase::finalizeOpModification(Operation *op) {
245 // Notify the listener that the operation was modified.
246 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
247 rewriteListener->notifyOperationModified(op);
248}
249
250void RewriterBase::replaceAllUsesExcept(
251 Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
252 return replaceUsesWithIf(from, to, functor: [&](OpOperand &use) {
253 Operation *user = use.getOwner();
254 return !preservedUsers.contains(Ptr: user);
255 });
256}
257
258void RewriterBase::replaceUsesWithIf(Value from, Value to,
259 function_ref<bool(OpOperand &)> functor,
260 bool *allUsesReplaced) {
261 bool allReplaced = true;
262 for (OpOperand &operand : llvm::make_early_inc_range(Range: from.getUses())) {
263 bool replace = functor(operand);
264 if (replace)
265 modifyOpInPlace(root: operand.getOwner(), callable: [&]() { operand.set(to); });
266 allReplaced &= replace;
267 }
268 if (allUsesReplaced)
269 *allUsesReplaced = allReplaced;
270}
271
272void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
273 function_ref<bool(OpOperand &)> functor,
274 bool *allUsesReplaced) {
275 assert(from.size() == to.size() && "incorrect number of replacements");
276 bool allReplaced = true;
277 for (auto it : llvm::zip_equal(t&: from, u&: to)) {
278 bool r;
279 replaceUsesWithIf(from: std::get<0>(t&: it), to: std::get<1>(t&: it), functor,
280 /*allUsesReplaced=*/&r);
281 allReplaced &= r;
282 }
283 if (allUsesReplaced)
284 *allUsesReplaced = allReplaced;
285}
286
287void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
288 Block::iterator before,
289 ValueRange argValues) {
290 assert(argValues.size() == source->getNumArguments() &&
291 "incorrect # of argument replacement values");
292
293 // The source block will be deleted, so it should not have any users (i.e.,
294 // there should be no predecessors).
295 assert(source->hasNoPredecessors() &&
296 "expected 'source' to have no predecessors");
297
298 if (dest->end() != before) {
299 // The source block will be inserted in the middle of the dest block, so
300 // the source block should have no successors. Otherwise, the remainder of
301 // the dest block would be unreachable.
302 assert(source->hasNoSuccessors() &&
303 "expected 'source' to have no successors");
304 } else {
305 // The source block will be inserted at the end of the dest block, so the
306 // dest block should have no successors. Otherwise, the inserted operations
307 // will be unreachable.
308 assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
309 }
310
311 // Replace all of the successor arguments with the provided values.
312 for (auto it : llvm::zip(t: source->getArguments(), u&: argValues))
313 replaceAllUsesWith(from: std::get<0>(t&: it), to: std::get<1>(t&: it));
314
315 // Move operations from the source block to the dest block and erase the
316 // source block.
317 if (!listener) {
318 // Fast path: If no listener is attached, move all operations at once.
319 dest->getOperations().splice(where: before, L2&: source->getOperations());
320 } else {
321 while (!source->empty())
322 moveOpBefore(op: &source->front(), block: dest, iterator: before);
323 }
324
325 // Erase the source block.
326 assert(source->empty() && "expected 'source' to be empty");
327 eraseBlock(block: source);
328}
329
330void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
331 ValueRange argValues) {
332 inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
333}
334
335void RewriterBase::mergeBlocks(Block *source, Block *dest,
336 ValueRange argValues) {
337 inlineBlockBefore(source, dest, before: dest->end(), argValues);
338}
339
340/// Split the operations starting at "before" (inclusive) out of the given
341/// block into a new block, and return it.
342Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
343 // Fast path: If no listener is attached, split the block directly.
344 if (!listener)
345 return block->splitBlock(splitBefore: before);
346
347 // `createBlock` sets the insertion point at the beginning of the new block.
348 InsertionGuard g(*this);
349 Block *newBlock =
350 createBlock(parent: block->getParent(), insertPt: std::next(x: block->getIterator()));
351
352 // If `before` points to end of the block, no ops should be moved.
353 if (before == block->end())
354 return newBlock;
355
356 // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
357 // Stop when the operation pointed to by `before` has been moved.
358 while (before->getBlock() != newBlock)
359 moveOpBefore(op: &block->back(), block: newBlock, iterator: newBlock->begin());
360
361 return newBlock;
362}
363
364/// Move the blocks that belong to "region" before the given position in
365/// another region. The two regions must be different. The caller is in
366/// charge to update create the operation transferring the control flow to the
367/// region and pass it the correct block arguments.
368void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
369 Region::iterator before) {
370 // Fast path: If no listener is attached, move all blocks at once.
371 if (!listener) {
372 parent.getBlocks().splice(where: before, L2&: region.getBlocks());
373 return;
374 }
375
376 // Move blocks from the beginning of the region one-by-one.
377 while (!region.empty())
378 moveBlockBefore(block: &region.front(), region: &parent, iterator: before);
379}
380void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
381 inlineRegionBefore(region, parent&: *before->getParent(), before: before->getIterator());
382}
383
384void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
385 moveBlockBefore(block, region: anotherBlock->getParent(),
386 iterator: anotherBlock->getIterator());
387}
388
389void RewriterBase::moveBlockBefore(Block *block, Region *region,
390 Region::iterator iterator) {
391 Region *currentRegion = block->getParent();
392 Region::iterator nextIterator = std::next(x: block->getIterator());
393 block->moveBefore(region, iterator);
394 if (listener)
395 listener->notifyBlockInserted(block, /*previous=*/currentRegion,
396 /*previousIt=*/nextIterator);
397}
398
399void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
400 moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
401}
402
403void RewriterBase::moveOpBefore(Operation *op, Block *block,
404 Block::iterator iterator) {
405 Block *currentBlock = op->getBlock();
406 Block::iterator nextIterator = std::next(op->getIterator());
407 op->moveBefore(block, iterator);
408 if (listener)
409 listener->notifyOperationInserted(
410 op, /*previous=*/InsertPoint(currentBlock, nextIterator));
411}
412
413void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
414 moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
415}
416
417void RewriterBase::moveOpAfter(Operation *op, Block *block,
418 Block::iterator iterator) {
419 assert(iterator != block->end() && "cannot move after end of block");
420 moveOpBefore(op, block, iterator: std::next(x: iterator));
421}
422

source code of mlir/lib/IR/PatternMatch.cpp