1//===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/PatternMatch.h"
10#include "mlir/Config/mlir-config.h"
11#include "mlir/IR/IRMapping.h"
12#include "mlir/IR/Iterators.h"
13#include "mlir/IR/RegionKindInterface.h"
14#include "llvm/ADT/SmallPtrSet.h"
15
16using namespace mlir;
17
18//===----------------------------------------------------------------------===//
19// PatternBenefit
20//===----------------------------------------------------------------------===//
21
22PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
23 assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
24 "This pattern match benefit is too large to represent");
25}
26
27unsigned short PatternBenefit::getBenefit() const {
28 assert(!isImpossibleToMatch() && "Pattern doesn't match");
29 return representation;
30}
31
32//===----------------------------------------------------------------------===//
33// Pattern
34//===----------------------------------------------------------------------===//
35
36//===----------------------------------------------------------------------===//
37// OperationName Root Constructors
38
39Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40 MLIRContext *context, ArrayRef<StringRef> generatedNames)
41 : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
42 RootKind::OperationName, generatedNames, benefit, context) {}
43
44//===----------------------------------------------------------------------===//
45// MatchAnyOpTypeTag Root Constructors
46
47Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
48 MLIRContext *context, ArrayRef<StringRef> generatedNames)
49 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
50
51//===----------------------------------------------------------------------===//
52// MatchInterfaceOpTypeTag Root Constructors
53
54Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
55 PatternBenefit benefit, MLIRContext *context,
56 ArrayRef<StringRef> generatedNames)
57 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
58 generatedNames, benefit, context) {}
59
60//===----------------------------------------------------------------------===//
61// MatchTraitOpTypeTag Root Constructors
62
63Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
64 PatternBenefit benefit, MLIRContext *context,
65 ArrayRef<StringRef> generatedNames)
66 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
67 benefit, context) {}
68
69//===----------------------------------------------------------------------===//
70// General Constructors
71
72Pattern::Pattern(const void *rootValue, RootKind rootKind,
73 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
74 MLIRContext *context)
75 : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
76 contextAndHasBoundedRecursion(context, false) {
77 if (generatedNames.empty())
78 return;
79 generatedOps.reserve(N: generatedNames.size());
80 std::transform(first: generatedNames.begin(), last: generatedNames.end(),
81 result: std::back_inserter(x&: generatedOps), unary_op: [context](StringRef name) {
82 return OperationName(name, context);
83 });
84}
85
86//===----------------------------------------------------------------------===//
87// RewritePattern
88//===----------------------------------------------------------------------===//
89
90void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
91 llvm_unreachable("need to implement either matchAndRewrite or one of the "
92 "rewrite functions!");
93}
94
95LogicalResult RewritePattern::match(Operation *op) const {
96 llvm_unreachable("need to implement either match or matchAndRewrite!");
97}
98
99/// Out-of-line vtable anchor.
100void RewritePattern::anchor() {}
101
102//===----------------------------------------------------------------------===//
103// RewriterBase
104//===----------------------------------------------------------------------===//
105
106bool RewriterBase::Listener::classof(const OpBuilder::Listener *base) {
107 return base->getKind() == OpBuilder::ListenerBase::Kind::RewriterBaseListener;
108}
109
110RewriterBase::~RewriterBase() {
111 // Out of line to provide a vtable anchor for the class.
112}
113
114void RewriterBase::replaceAllOpUsesWith(Operation *from, ValueRange to) {
115 // Notify the listener that we're about to replace this op.
116 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
117 rewriteListener->notifyOperationReplaced(op: from, replacement: to);
118
119 replaceAllUsesWith(from: from->getResults(), to);
120}
121
122void RewriterBase::replaceAllOpUsesWith(Operation *from, Operation *to) {
123 // Notify the listener that we're about to replace this op.
124 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
125 rewriteListener->notifyOperationReplaced(op: from, replacement: to);
126
127 replaceAllUsesWith(from: from->getResults(), to: to->getResults());
128}
129
130/// This method replaces the results of the operation with the specified list of
131/// values. The number of provided values must match the number of results of
132/// the operation. The replaced op is erased.
133void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
134 assert(op->getNumResults() == newValues.size() &&
135 "incorrect # of replacement values");
136
137 // Replace all result uses. Also notifies the listener of modifications.
138 replaceAllOpUsesWith(from: op, to: newValues);
139
140 // Erase op and notify listener.
141 eraseOp(op);
142}
143
144/// This method replaces the results of the operation with the specified new op
145/// (replacement). The number of results of the two operations must match. The
146/// replaced op is erased.
147void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
148 assert(op && newOp && "expected non-null op");
149 assert(op->getNumResults() == newOp->getNumResults() &&
150 "ops have different number of results");
151
152 // Replace all result uses. Also notifies the listener of modifications.
153 replaceAllOpUsesWith(from: op, to: newOp->getResults());
154
155 // Erase op and notify listener.
156 eraseOp(op);
157}
158
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead.
///
/// Everything nested in the regions of `op` is erased as well. Without an
/// attached `RewriterBase::Listener`, `op` is erased in a single step. With
/// such a listener, the nested IR is torn down one operation and one block at
/// a time, so that each `notifyOperationErased` / `notifyBlockErased`
/// notification observes consistent IR:
///  - nested ops are erased before their parent;
///  - within a region, users are erased before their definitions.
void RewriterBase::eraseOp(Operation *op) {
  assert(op->use_empty() && "expected 'op' to have no uses");
  auto *rewriteListener = dyn_cast_if_present<Listener>(listener);

  // Fast path: If no listener is attached, the op can be dropped in one go.
  if (!rewriteListener) {
    op->erase();
    return;
  }

  // Helper function that erases a single op. Its regions must already have
  // been emptied; the listener is notified before the op is destroyed.
  auto eraseSingleOp = [&](Operation *op) {
#ifndef NDEBUG
    // All nested ops should have been erased already.
    assert(
        llvm::all_of(op->getRegions(), [&](Region &r) { return r.empty(); }) &&
        "expected empty regions");
    // All users should have been erased already if the op is in a region with
    // SSA dominance.
    if (!op->use_empty() && op->getParentOp())
      assert(mayBeGraphRegion(*op->getParentRegion()) &&
             "expected that op has no uses");
#endif // NDEBUG
    rewriteListener->notifyOperationErased(op);

    // Explicitly drop all uses in case the op is in a graph region.
    op->dropAllUses();
    op->erase();
  };

  // Nested ops must be erased one-by-one, so that listeners have a consistent
  // view of the IR every time a notification is triggered. Users must be
  // erased before definitions. I.e., post-order, reverse dominance.
  // `eraseTree` is a std::function rather than a plain lambda because it
  // calls itself recursively for nested ops.
  std::function<void(Operation *)> eraseTree = [&](Operation *op) {
    // Erase nested ops, visiting the regions of `op` back to front.
    for (Region &r : llvm::reverse(op->getRegions())) {
      // Erase all blocks in the right order. Successors should be erased
      // before predecessors because successor blocks may use values defined
      // in predecessor blocks. A post-order traversal of blocks within a
      // region visits successors before predecessors. Repeat the traversal
      // until the region is empty. (The block graph could be disconnected.)
      while (!r.empty()) {
        SmallVector<Block *> erasedBlocks;
        // A block's successor list may contain null entries (an invalid or
        // unset successor). Seeding the visited set with nullptr makes the
        // traversal treat such entries as already visited instead of
        // dereferencing them.
        llvm::SmallPtrSet<Block *, 4> visited{nullptr};
        for (Block *b : llvm::post_order_ext(&r.front(), visited)) {
          // Visit ops in reverse order. Early-increment iteration is required
          // because `eraseTree` unlinks the op currently being visited.
          for (Operation &op :
               llvm::make_early_inc_range(ReverseIterator::makeIterable(*b)))
            eraseTree(&op);
          // Do not erase the block immediately. This is not supported by the
          // post_order iterator.
          erasedBlocks.push_back(b);
        }
        for (Block *b : erasedBlocks) {
          // Explicitly drop all uses in case there is a cycle in the block
          // graph.
          for (BlockArgument bbArg : b->getArguments())
            bbArg.dropAllUses();
          b->dropAllUses();
          // The block is empty by now; `eraseBlock` still notifies the
          // listener that the block is going away.
          eraseBlock(b);
        }
      }
    }
    // Then erase the enclosing op.
    eraseSingleOp(op);
  };

  eraseTree(op);
}
232
233void RewriterBase::eraseBlock(Block *block) {
234 assert(block->use_empty() && "expected 'block' to have no uses");
235
236 for (auto &op : llvm::make_early_inc_range(Range: llvm::reverse(C&: *block))) {
237 assert(op.use_empty() && "expected 'op' to have no uses");
238 eraseOp(op: &op);
239 }
240
241 // Notify the listener that the block is about to be removed.
242 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
243 rewriteListener->notifyBlockErased(block);
244
245 block->erase();
246}
247
248void RewriterBase::finalizeOpModification(Operation *op) {
249 // Notify the listener that the operation was modified.
250 if (auto *rewriteListener = dyn_cast_if_present<Listener>(Val: listener))
251 rewriteListener->notifyOperationModified(op);
252}
253
254void RewriterBase::replaceAllUsesExcept(
255 Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
256 return replaceUsesWithIf(from, to, functor: [&](OpOperand &use) {
257 Operation *user = use.getOwner();
258 return !preservedUsers.contains(Ptr: user);
259 });
260}
261
262void RewriterBase::replaceUsesWithIf(Value from, Value to,
263 function_ref<bool(OpOperand &)> functor,
264 bool *allUsesReplaced) {
265 bool allReplaced = true;
266 for (OpOperand &operand : llvm::make_early_inc_range(Range: from.getUses())) {
267 bool replace = functor(operand);
268 if (replace)
269 modifyOpInPlace(root: operand.getOwner(), callable: [&]() { operand.set(to); });
270 allReplaced &= replace;
271 }
272 if (allUsesReplaced)
273 *allUsesReplaced = allReplaced;
274}
275
276void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
277 function_ref<bool(OpOperand &)> functor,
278 bool *allUsesReplaced) {
279 assert(from.size() == to.size() && "incorrect number of replacements");
280 bool allReplaced = true;
281 for (auto it : llvm::zip_equal(t&: from, u&: to)) {
282 bool r;
283 replaceUsesWithIf(from: std::get<0>(t&: it), to: std::get<1>(t&: it), functor,
284 /*allUsesReplaced=*/&r);
285 allReplaced &= r;
286 }
287 if (allUsesReplaced)
288 *allUsesReplaced = allReplaced;
289}
290
291void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
292 Block::iterator before,
293 ValueRange argValues) {
294 assert(argValues.size() == source->getNumArguments() &&
295 "incorrect # of argument replacement values");
296
297 // The source block will be deleted, so it should not have any users (i.e.,
298 // there should be no predecessors).
299 assert(source->hasNoPredecessors() &&
300 "expected 'source' to have no predecessors");
301
302 if (dest->end() != before) {
303 // The source block will be inserted in the middle of the dest block, so
304 // the source block should have no successors. Otherwise, the remainder of
305 // the dest block would be unreachable.
306 assert(source->hasNoSuccessors() &&
307 "expected 'source' to have no successors");
308 } else {
309 // The source block will be inserted at the end of the dest block, so the
310 // dest block should have no successors. Otherwise, the inserted operations
311 // will be unreachable.
312 assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors");
313 }
314
315 // Replace all of the successor arguments with the provided values.
316 for (auto it : llvm::zip(t: source->getArguments(), u&: argValues))
317 replaceAllUsesWith(from: std::get<0>(t&: it), to: std::get<1>(t&: it));
318
319 // Move operations from the source block to the dest block and erase the
320 // source block.
321 if (!listener) {
322 // Fast path: If no listener is attached, move all operations at once.
323 dest->getOperations().splice(where: before, L2&: source->getOperations());
324 } else {
325 while (!source->empty())
326 moveOpBefore(op: &source->front(), block: dest, iterator: before);
327 }
328
329 // Erase the source block.
330 assert(source->empty() && "expected 'source' to be empty");
331 eraseBlock(block: source);
332}
333
334void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
335 ValueRange argValues) {
336 inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues);
337}
338
339void RewriterBase::mergeBlocks(Block *source, Block *dest,
340 ValueRange argValues) {
341 inlineBlockBefore(source, dest, before: dest->end(), argValues);
342}
343
344/// Split the operations starting at "before" (inclusive) out of the given
345/// block into a new block, and return it.
346Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
347 // Fast path: If no listener is attached, split the block directly.
348 if (!listener)
349 return block->splitBlock(splitBefore: before);
350
351 // `createBlock` sets the insertion point at the beginning of the new block.
352 InsertionGuard g(*this);
353 Block *newBlock =
354 createBlock(parent: block->getParent(), insertPt: std::next(x: block->getIterator()));
355
356 // If `before` points to end of the block, no ops should be moved.
357 if (before == block->end())
358 return newBlock;
359
360 // Move ops one-by-one from the end of `block` to the beginning of `newBlock`.
361 // Stop when the operation pointed to by `before` has been moved.
362 while (before->getBlock() != newBlock)
363 moveOpBefore(op: &block->back(), block: newBlock, iterator: newBlock->begin());
364
365 return newBlock;
366}
367
368/// Move the blocks that belong to "region" before the given position in
369/// another region. The two regions must be different. The caller is in
370/// charge to update create the operation transferring the control flow to the
371/// region and pass it the correct block arguments.
372void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
373 Region::iterator before) {
374 // Fast path: If no listener is attached, move all blocks at once.
375 if (!listener) {
376 parent.getBlocks().splice(where: before, L2&: region.getBlocks());
377 return;
378 }
379
380 // Move blocks from the beginning of the region one-by-one.
381 while (!region.empty())
382 moveBlockBefore(block: &region.front(), region: &parent, iterator: before);
383}
384void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
385 inlineRegionBefore(region, parent&: *before->getParent(), before: before->getIterator());
386}
387
388void RewriterBase::moveBlockBefore(Block *block, Block *anotherBlock) {
389 moveBlockBefore(block, region: anotherBlock->getParent(),
390 iterator: anotherBlock->getIterator());
391}
392
393void RewriterBase::moveBlockBefore(Block *block, Region *region,
394 Region::iterator iterator) {
395 Region *currentRegion = block->getParent();
396 Region::iterator nextIterator = std::next(x: block->getIterator());
397 block->moveBefore(region, iterator);
398 if (listener)
399 listener->notifyBlockInserted(block, /*previous=*/currentRegion,
400 /*previousIt=*/nextIterator);
401}
402
403void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
404 moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
405}
406
407void RewriterBase::moveOpBefore(Operation *op, Block *block,
408 Block::iterator iterator) {
409 Block *currentBlock = op->getBlock();
410 Block::iterator nextIterator = std::next(op->getIterator());
411 op->moveBefore(block, iterator);
412 if (listener)
413 listener->notifyOperationInserted(
414 op, /*previous=*/InsertPoint(currentBlock, nextIterator));
415}
416
417void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
418 moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
419}
420
421void RewriterBase::moveOpAfter(Operation *op, Block *block,
422 Block::iterator iterator) {
423 assert(iterator != block->end() && "cannot move after end of block");
424 moveOpBefore(op, block, iterator: std::next(x: iterator));
425}
426

source code of mlir/lib/IR/PatternMatch.cpp